diff --git a/requirements/curator.txt b/requirements/curator.txt new file mode 100644 index 00000000..33b874c9 --- /dev/null +++ b/requirements/curator.txt @@ -0,0 +1,29 @@ +cd ray-api + +# pip install cosmos-xenna[gpu] +git clone https://github.com/NVIDIA-NeMo/Curator.git +cd Curator +git switch ray-api +pip install . + +# install NeMo +pip install "nemo_toolkit[all]" + +# install nemo_text_processing +pip install nemo_text_processing + +pip install -r requirements/main.txt +pip install -r requirements/tests.txt + +RAY_ADDRESS=10.110.41.40:8265 python -m pytest tests/test_curator.py + +# pip install loguru +# pip install -U "ray[default]" + +# cd ~/workspace/Curator/ray-curator && pip install . +# ray start --include-dashboard=True --head # ray status # ray stop + # import ray + # ray.init() + # RAY_ADDRESS='http://127.0.0.1:8265' ray job submit --working-dir . -- python my_script.py + +RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES="0" RAY_MAX_LIMIT_FROM_API_SERVER=40000 RAY_MAX_LIMIT_FROM_DATA_SOURCE=40000 ray start --include-dashboard=True --dashboard-host=0.0.0.0 --port=8265 --dashboard-port=8266 --head --temp-dir=/tmp diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py index ab0d05ed..37f38216 100644 --- a/sdp/processors/__init__.py +++ b/sdp/processors/__init__.py @@ -22,11 +22,11 @@ TrainDevTestSplitCORAAL, ) from sdp.processors.datasets.earnings import ( - CreateInitialAudioAndManifest, + ApplyEarnings21Normalizations, CreateFullAudioManifestEarnings21, - SpeakerSegmentedManifest, + CreateInitialAudioAndManifest, CreateSentenceSegmentedManifest, - ApplyEarnings21Normalizations, + SpeakerSegmentedManifest, ) from sdp.processors.datasets.fleurs.create_initial_manifest import ( CreateInitialManifestFleurs, @@ -82,23 +82,26 @@ CreateInitialManifestHuggingFace, ) from sdp.processors.huggingface.speech_recognition import ASRTransformers +from sdp.processors.manage_files.convert_audio import FfmpegConvert, SoxConvert +from sdp.processors.manage_files.extract import ExtractTar +from sdp.processors.manage_files.remove import RemoveFiles from sdp.processors.modify_manifest.common import ( AddConstantFields, ApplyInnerJoin, ChangeToRelativePath, CombineSources, + DropSpecifiedFields, DuplicateFields, KeepOnlySpecifiedFields, RenameFields, SortManifest, SplitOnFixedDuration, Subprocess, - DropSpecifiedFields, - ) from sdp.processors.modify_manifest.create_manifest import ( CreateCombinedManifests, CreateInitialManifestByExt, + SaveJsonl, ) from sdp.processors.modify_manifest.data_to_data import ( ASRFileCheck, @@ -109,6 +112,8 @@ GetWER, InsIfASRInsertion, InverseNormalizeText, + LambdaExpression, + ListToEntries, MakeSentence, NormalizeText, ReadDocxLines, @@ -117,8 +122,6 @@ SubIfASRSubstitution, SubMakeLowercase, SubRegex, - ListToEntries, - LambdaExpression, ) from sdp.processors.modify_manifest.data_to_dropbool import ( DropASRError, @@ -141,16 +144,6 @@ from sdp.processors.modify_manifest.make_letters_uppercase_after_period import ( MakeLettersUppercaseAfterPeriod, ) -from sdp.processors.manage_files.convert_audio import ( - FfmpegConvert, - SoxConvert, -) -from sdp.processors.manage_files.extract import ( - ExtractTar, -) -from sdp.processors.manage_files.remove import ( - RemoveFiles, -) from sdp.processors.nemo.asr_inference import ASRInference from sdp.processors.nemo.estimate_bandwidth import EstimateBandwidth from sdp.processors.nemo.lid_inference import AudioLid diff --git a/sdp/processors/base_processor.py b/sdp/processors/base_processor.py index a4257e53..5cad8cea 100644 ---
a/sdp/processors/base_processor.py +++ b/sdp/processors/base_processor.py @@ -21,6 +21,8 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union +from ray_curator.stages.base import ProcessingStage +from ray_curator.tasks import Task from tqdm import tqdm from tqdm.contrib.concurrent import process_map @@ -28,14 +30,23 @@ @dataclass -class DataEntry: +class DataEntry(Task[list]): """A wrapper for data entry + any additional metrics.""" - data: Optional[Dict] # can be None to drop the entry - metrics: Any = None + def __init__(self, data: Dict = None, metrics: Any = None, dataset_name: str = "", task_id: int = 0, **kwargs): + self.data = data # data can be None to drop the entry + self.metrics = metrics + super().__init__(data=data, task_id=task_id, dataset_name=dataset_name, **kwargs) + @property + def num_items(self) -> int: + return 1 -class BaseProcessor(ABC): + def validate(self) -> bool: + return True + + +class BaseProcessor(ProcessingStage[Task, Task]): """Abstract class for SDP processors. All processor classes inherit from the ``BaseProcessor`` class. @@ -57,8 +68,9 @@ class BaseProcessor(ABC): as ``input_manifest_file``. """ - def __init__(self, output_manifest_file: str, input_manifest_file: Optional[str] = None, **kwargs): - + def __init__( + self, output_manifest_file: Optional[str] = None, input_manifest_file: Optional[str] = None, **kwargs + ): if output_manifest_file and input_manifest_file and (output_manifest_file == input_manifest_file): # we cannot have the same input and output manifest file specified because we need to be able to # read from the input_manifest_file and write to the output_manifest_file at the same time @@ -68,7 +80,7 @@ def __init__(self, output_manifest_file: str, input_manifest_file: Optional[str] self.input_manifest_file = input_manifest_file @abstractmethod - def process(self): + def process(self, tasks: Task) -> Task: """Should be overridden by the child classes to implement some data processing.""" pass @@ -82,6 +94,17 @@ def test(self): There are no tests by default. """ + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + @property + def name(self) -> str: + return "BaseProcessor" + + class BaseParallelProcessor(BaseProcessor): """ A processor that performs per-entry processing in parallel (using Dask or multiprocessing). @@ -93,15 +116,15 @@ class BaseParallelProcessor: chunksize (int): Chunk size used for parallel routines. in_memory_chunksize (int): Maximum number of entries to load at once. test_cases (list[dict]): Optional list of test cases. - use_dask (bool): If True, use Dask for parallelization; otherwise, use multiprocessing. - dask_client: (Optional) An existing Dask client. + use_backend (str): Parallelization backend, one of {None, "dask", "curator"}. Use None for multiprocessing. + backend_client: (Optional) An existing backend client.
""" - + def __getstate__(self): state = self.__dict__.copy() # Remove the Dask client from state (it is not picklable) - if 'dask_client' in state: - state['dask_client'] = None + if 'backend_client' in state: + state['backend_client'] = None return state def __init__( @@ -112,11 +135,11 @@ def __init__( chunksize: int = 100, in_memory_chunksize: int = 100000, test_cases: Optional[List[Dict]] = None, - use_dask: bool = True, - dask_client=None, + use_backend: Optional[str] = None, + backend_client=None, **kwargs, ): - kwargs.pop("use_dask", None) # + kwargs.pop("use_backend", None) # super().__init__(input_manifest_file=input_manifest_file, output_manifest_file=output_manifest_file, **kwargs) if max_workers == -1: max_workers = os.cpu_count() @@ -127,41 +150,49 @@ def __init__( self.total_duration = 0 self.start_time = time.time() self.test_cases = test_cases or [] - self.use_dask = use_dask - self.dask_client = dask_client - + self.use_backend = use_backend + self.backend_client = backend_client + def prepare(self): - """Can be used in derived classes to prepare the processing. - - """ + """Can be used in derived classes to prepare the processing.""" pass - def process(self): - """A fork in the road to pick dask or classic processing - - """ + def process(self, tasks: Task) -> Task: + """A fork in the road to pick dask or classic processing""" os.environ.setdefault("PATH", os.defpath) self.prepare() - + os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) metrics = [] - - #Ability to work sa legacy and as dask - if self.use_dask: - self._process_with_dask(metrics) + + # Ability to work sa legacy and as dask + if self.use_backend == "curator": + tasks = self._process_with_ray(metrics) else: - self._process_with_multiprocessing(metrics) + tasks = self._process_with_multiprocessing(metrics) self.finalize(metrics) + return tasks + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return ["data"], [] + + @property + def name(self) -> str: + return "BaseParallelProcessor" def _process_with_dask(self, metrics): import dask.bag as db from dask.distributed import Client - if self.dask_client is None: - self.dask_client = Client() - client = self.dask_client - from sdp.logging import logger + if self.backend_client is None: + self.backend_client = Client() + client = self.backend_client + from sdp.logging import logger + logger.info(f"Using Dask client with dashboard at: {client.dashboard_link}") # Delegate manifest reading to read_manifest() which returns a Dask bag. 
@@ -188,7 +219,29 @@ def _process_with_dask(self, metrics): self.total_duration += entry.data.get("duration", 0) logger.info(f"Processed {total_entries} entries using Dask.") + def _process_with_ray(self, metrics): + if self.output_manifest_file: + fout = open(self.output_manifest_file, "wt", encoding="utf8") + tasks = [] + for manifest_chunk in self._chunk_manifest(): + for row in manifest_chunk: + data = self.process_dataset_entry(row) + for data_entry in tqdm(data): + metrics.append(data_entry.metrics) + if data_entry.data is None: + continue + if self.output_manifest_file: + json.dump(data_entry.data, fout, ensure_ascii=False) + fout.write("\n") + self.number_of_entries += 1 + self.total_duration += data_entry.data.get("duration", 0) + tasks.extend(data) + if self.output_manifest_file: + fout.close() + return tasks + def _process_with_multiprocessing(self, metrics): + data = [] with open(self.output_manifest_file, "wt", encoding="utf8") as fout: for manifest_chunk in self._chunk_manifest(): data = itertools.chain( @@ -207,13 +258,14 @@ fout.write("\n") self.number_of_entries += 1 self.total_duration += data_entry.data.get("duration", 0) + return data def _chunk_manifest(self): """Splits the input manifest into chunks of in_memory_chunksize size. - Only used in non-Dask (multiprocessing) mode. + Only used in the non-Dask modes (multiprocessing and Ray Curator). """ manifest_chunk = [] - # When use_dask is False, read_manifest() returns an iterator. + # When use_backend is not "dask", read_manifest() returns an iterator. for idx, data_entry in enumerate(self.read_manifest(), 1): manifest_chunk.append(data_entry) if idx % self.in_memory_chunksize == 0: @@ -225,38 +277,51 @@ def read_manifest(self): """ Reads entries from the input manifest. - + Behavior depends on the parallelization mode: - - When use_dask is True: + - When use_backend is "dask": If the input_manifest_file exists and is non-empty, returns a Dask bag (reading in 256KB blocks). Otherwise, logs the condition and returns an empty Dask bag. - - When use_dask is False: + - When use_backend is "curator": + Reads line by line (as in the None case); _process_with_ray() consumes the entries via _chunk_manifest(). + - When use_backend is None: If the input_manifest_file does not exist or is empty, logs the condition and returns an empty iterator. Otherwise, opens the file in text mode, strips each line, and yields the parsed JSON from non-empty lines. - + This unified behavior lets the processor run even in manifest-creation mode. """ - from sdp.logging import logger - if self.use_dask: + from sdp.logging import logger + + if self.use_backend == "dask": import dask.bag as db - if self.input_manifest_file and os.path.exists(self.input_manifest_file) and os.path.getsize(self.input_manifest_file) > 0: + + if ( + self.input_manifest_file + and os.path.exists(self.input_manifest_file) + and os.path.getsize(self.input_manifest_file) > 0 + ): bag = db.read_text(self.input_manifest_file, blocksize=2**18).map(json.loads) return bag else: - logger.info("No input manifest file provided or file is empty. Returning an empty Dask bag for manifest creation.") + logger.info( + "No input manifest file provided or file is empty. Returning an empty Dask bag for manifest creation." + ) return db.from_sequence([]) else: if not self.input_manifest_file or not os.path.exists(self.input_manifest_file): - logger.info("No input manifest file provided or file does not exist. Continuing with an empty manifest.") + logger.info( + "No input manifest file provided or file does not exist. Continuing with an empty manifest."
+            ) return iter([]) - else: - #if use_dask = False, we get here - def generator(): #Reading manifest line by line, adding only non emply lines with open(self.input_manifest_file, "rt", encoding="utf8") as fin: for line in fin: - if line: - yield json.loads(line) + else: + # if self.use_backend is not "dask", we get here + def generator(): # Reading manifest line by line, adding only non-empty lines with open(self.input_manifest_file, "rt", encoding="utf8") as fin: for line in fin: + if line: + yield json.loads(line) + return generator() @abstractmethod @@ -270,38 +335,43 @@ def process_dataset_entry(self, data_entry) -> List[Any]: def finalize(self, metrics: List[Any]): """Outputs metrics about the processed data.""" from sdp.logging import logger + logger.info("Total number of entries after processing: %d", self.number_of_entries) if self.total_duration: logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600) else: - logger.info("Unable to calculate total audio duration (hours). Ensure that the manifest file includes a 'duration' key.") + logger.info( + "Unable to calculate total audio duration (hours). Ensure that the manifest file includes a 'duration' key." + ) elapsed = time.time() - self.start_time logger.info("Processor completed in (seconds): %.2f", elapsed) def test(self): - """Applies processing to each test case and raises an error if the output does not match expected output.""" + """Applies processing to each test case and raises an error if the output does not match expected output.""" for test_case in self.test_cases: input_data = test_case["input"].copy() if isinstance(test_case["input"], dict) else test_case["input"] generated_outputs = self.process_dataset_entry(input_data) - expected_outputs = [test_case["output"]] if not isinstance(test_case["output"], list) else test_case["output"] + expected_outputs = ( + [test_case["output"]] if not isinstance(test_case["output"], list) else test_case["output"] + ) for gen_out, exp_out in zip(generated_outputs, expected_outputs): gen_data = gen_out.data if hasattr(gen_out, "data") else gen_out if gen_data != exp_out: raise RuntimeError( - "Runtime test failed.\nTest input: {}\nGenerated output: {}\nExpected output: {}" .format(test_case["input"], gen_data, exp_out) + "Runtime test failed.\nTest input: {}\nGenerated output: {}\nExpected output: {}".format( + test_case["input"], gen_data, exp_out + ) ) - # ------------------ Legacy Parallel Processor ------------------ #Just for reference class LegacyParallelProcessor(BaseProcessor): """ A legacy parallel processor implementation using multiprocessing and process_map. - + This class processes the manifest in chunks (using process_map) and is provided for compatibility. Child classes must implement process_dataset_entry(). - + Args: max_workers (int): maximum number of workers that will be spawned during the parallel processing. @@ -312,12 +382,13 @@ class LegacyParallelProcessor(BaseProcessor): test_cases (list[dict]): an optional list of dicts containing test cases for checking that the processor makes the changes that we are expecting. - + The dicts must have a key ``input``, the value of which is a dictionary containing data which is our test's input manifest line, and a key ``output``, the value of which is a dictionary containing data which is the expected output manifest line.
""" + def __init__( self, max_workers: int = -1, @@ -326,7 +397,7 @@ def __init__( test_cases: Optional[List[Dict]] = None, **kwargs, ): - kwargs.pop("use_dask", None) # + kwargs.pop("use_backend", None) # super().__init__(**kwargs) if max_workers == -1: max_workers = multiprocessing.cpu_count() @@ -478,9 +549,12 @@ def finalize(self, metrics): if self.total_duration: logger.info("Total audio duration (hours) after processing (legacy): %.2f", self.total_duration / 3600) else: - logger.info("Unable to calculate total audio duration (legacy). Please ensure that the manifest file includes a 'duration' key.") + logger.info( + "Unable to calculate total audio duration (legacy). Please ensure that the manifest file includes a 'duration' key." + ) elapsed = time.time() - self.start_time logger.info("Legacy processor completed in (seconds): %.2f", elapsed) + def test(self): """Applies processing to "test_cases" and raises an error in case of mismatch.""" for test_case in self.test_cases: @@ -498,4 +572,4 @@ def test(self): f"Test input: {test_case['input']}\n" f"Generated output: {generated_output}\n" f"Expected output: {expected_output}" - ) \ No newline at end of file + ) diff --git a/sdp/processors/datasets/fleurs/create_initial_manifest.py b/sdp/processors/datasets/fleurs/create_initial_manifest.py index d571593a..7967c5b9 100644 --- a/sdp/processors/datasets/fleurs/create_initial_manifest.py +++ b/sdp/processors/datasets/fleurs/create_initial_manifest.py @@ -20,6 +20,8 @@ import typing from urllib.parse import parse_qs, urlparse +from ray_curator.tasks import _EmptyTask + from sdp.processors.base_processor import BaseProcessor, DataEntry from sdp.utils.common import download_file, extract_archive @@ -145,6 +147,7 @@ def download_extract_files(self, dst_folder: str) -> None: os.remove(file_path) print(f'File {file_name} already exists in {target_folder}, deleted from source.') - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: self.download_extract_files(self.raw_data_dir) self.process_data(self.raw_data_dir, self.output_manifest_file) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/datasets/hifitts2/remove_failed_chapters.py b/sdp/processors/datasets/hifitts2/remove_failed_chapters.py index b4cd5a8b..dff513a3 100644 --- a/sdp/processors/datasets/hifitts2/remove_failed_chapters.py +++ b/sdp/processors/datasets/hifitts2/remove_failed_chapters.py @@ -15,6 +15,8 @@ import json from pathlib import Path + +from ray_curator.tasks import _EmptyTask from tqdm import tqdm from sdp.processors.base_processor import BaseProcessor @@ -49,7 +51,7 @@ def __init__( super().__init__(**kwargs) self.error_file = Path(error_file) - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: chapter_rows = load_manifest(self.error_file) audio_files_to_remove = set() for chapter_row in chapter_rows: @@ -64,3 +66,4 @@ def process(self): output_line = f"{json.dumps(row, ensure_ascii=False)}\n" output_f.write(output_line) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/datasets/librispeech/create_initial_manifest.py b/sdp/processors/datasets/librispeech/create_initial_manifest.py index 83d42bde..86fec1dc 100644 --- a/sdp/processors/datasets/librispeech/create_initial_manifest.py +++ b/sdp/processors/datasets/librispeech/create_initial_manifest.py @@ -18,6 +18,8 @@ import os import typing +from ray_curator.tasks import _EmptyTask + from sdp.processors.base_processor import BaseProcessor 
from sdp.utils.common import download_file, extract_archive @@ -94,7 +96,7 @@ def process_transcript(self, file_path: str) -> list[dict[str, typing.Any]]: entries = [] root = os.path.dirname(file_path) - print(f"Processing transcript file: {file_path}") + print(f"Processing transcript file: {file_path}") with open(file_path, encoding="utf-8") as fin: for line in fin: id, text = line[: line.index(" ")], line[line.index(" ") + 1 :] @@ -135,6 +137,7 @@ def download_extract_files(self, dst_folder: str) -> None: data_file = f'{dst_folder}/{self.split}.tar.gz' extract_archive(str(data_file), str(dst_folder), force_extract=True) - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: self.download_extract_files(self.raw_data_dir) self.process_data(self.raw_data_dir, self.output_manifest_file) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/datasets/mls/restore_pc.py b/sdp/processors/datasets/mls/restore_pc.py index 33ff22b0..a79116fc 100644 --- a/sdp/processors/datasets/mls/restore_pc.py +++ b/sdp/processors/datasets/mls/restore_pc.py @@ -24,6 +24,7 @@ import regex from joblib import Parallel, delayed +from ray_curator.tasks import _EmptyTask from tqdm import tqdm from sdp.logging import logger @@ -454,7 +455,7 @@ def __init__( self.n_jobs = n_jobs self.show_conversion_breakdown = show_conversion_breakdown - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: """Main processing happens here. * Download & extract lv_text. @@ -604,3 +605,4 @@ def process(self): with open(manifest, "r") as fin: for line in fin: fout.write(line) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/datasets/slr140/create_initial_manifest.py b/sdp/processors/datasets/slr140/create_initial_manifest.py index 2da79027..2d701bf5 100644 --- a/sdp/processors/datasets/slr140/create_initial_manifest.py +++ b/sdp/processors/datasets/slr140/create_initial_manifest.py @@ -19,6 +19,7 @@ import numpy as np import sox +from ray_curator.tasks import _EmptyTask from tqdm import tqdm from tqdm.contrib.concurrent import thread_map @@ -145,7 +146,7 @@ def __init__(self, data_split: str, split_audio_dir: str, **kwargs): self.data_split = data_split self.split_audio_dir = split_audio_dir - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: with open(self.input_manifest_file, "rt", encoding="utf8") as fin: manifest_data = [json.loads(line) for line in fin.readlines()] @@ -190,6 +191,7 @@ def process(self): logger.info("Total number of entries after processing: %d", number_of_entries) logger.info("Total audio duration (hours) after processing: %.2f", total_duration / 3600) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) def _accumulate_samples( self, manifest_data: List[dict], sample_idxs: List[int], duration_threshold: int diff --git a/sdp/processors/datasets/slr83/create_initial_manifest.py b/sdp/processors/datasets/slr83/create_initial_manifest.py index 030360f7..f5edd887 100644 --- a/sdp/processors/datasets/slr83/create_initial_manifest.py +++ b/sdp/processors/datasets/slr83/create_initial_manifest.py @@ -19,6 +19,7 @@ import numpy as np import sox +from ray_curator.tasks import _EmptyTask from tqdm import tqdm from sdp.logging import logger @@ -192,7 +193,7 @@ def __init__(self, dialect, data_split, **kwargs): self.dialect = dialect self.data_split = data_split - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: with 
open(self.input_manifest_file, "rt", encoding="utf8") as fin: manifest_data = [json.loads(line) for line in fin.readlines()] @@ -238,6 +239,7 @@ def process(self): logger.info("Total number of entries after processing: %d", number_of_entries) logger.info("Total audio duration (hours) after processing: %.2f", total_duration / 3600) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) def _accumulate_samples( self, manifest_data: List[dict], sample_idxs: List[int], duration_threshold: int diff --git a/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py b/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py index 27117f2a..96ecb787 100644 --- a/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py +++ b/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py @@ -16,11 +16,14 @@ import json import os import typing + import gdown +import pandas as pd +from ray_curator.tasks import DocumentBatch, EmptyTask, Task, _EmptyTask -from sdp.processors.base_processor import BaseProcessor -from sdp.utils.common import extract_archive from sdp.logging import logger +from sdp.processors.base_processor import BaseProcessor, DataEntry +from sdp.utils.common import extract_archive, load_manifest, save_manifest class CreateInitialManifestUzbekvoice(BaseProcessor): @@ -30,7 +33,7 @@ class CreateInitialManifestUzbekvoice(BaseProcessor): Will download all files, extract them, and create a manifest file with the "audio_filepath", "text" and "duration" fields. - Args: + Args: raw_data_dir (str): Path to the folder where the data archive should be downloaded and extracted. Returns: @@ -59,8 +62,10 @@ def download_extract_files(self, dst_folder: str) -> None: # for big files google drive doesn't allow to try downlaoding them more than once # so, in case of receiveing gdown error we need to download them manually - #check if clisp.zip and uzbekvoice-dataset.zip are already in dst_folder + # check if clips.zip and uzbekvoice-dataset.zip are already in dst_folder - if os.path.exists(os.path.join(dst_folder, 'clips.zip')) and os.path.exists(os.path.join(dst_folder, 'uzbekvoice-dataset.zip')): + if os.path.exists(os.path.join(dst_folder, 'clips.zip')) and os.path.exists( + os.path.join(dst_folder, 'uzbekvoice-dataset.zip') + ): print("Files already exist in the folder. Skipping download.") else: print(f"Downloading files from {self.URL}...") @@ -74,7 +79,6 @@ def download_extract_files(self, dst_folder: str) -> None: extract_archive(file, str(dst_folder), force_extract=True) print(f"Extracted {file}") - def process_transcript(self, file_path: str) -> list[dict[str, typing.Any]]: """ Parse transcript JSON file and put it inside manifest.
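The existence check above matters because Google Drive blocks repeated downloads of large files, and a failed gdown fetch then has to be redone manually. The general skip-if-present guard, sketched with an illustrative URL and path (not taken from this patch):

import os

import gdown

def fetch_once(url: str, dst: str) -> str:  # hypothetical helper
    # Skip the download when the archive is already on disk.
    if os.path.exists(dst):
        print(f"{dst} already exists. Skipping download.")
        return dst
    return gdown.download(url, dst)

fetch_once("https://drive.google.com/uc?id=<file-id>", "clips.zip")
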
@@ -93,13 +97,8 @@ def process_transcript(self, file_path: str) -> list[dict[str, typing.Any]]: utter_length = entry["clip_duration"] number_of_entries += 1 entries.append( - { - "audio_filepath": os.path.abspath(audio_file), - "text": transcript, - "duration": utter_length - } + {"audio_filepath": os.path.abspath(audio_file), "text": transcript, "duration": utter_length} ) - logger.info("Total number of entries after processing: %d", number_of_entries) logger.info("Total audio duration (hours) after processing: %.2f", total_duration / 3600) @@ -109,12 +108,11 @@ def process_transcript(self, file_path: str) -> list[dict[str, typing.Any]]: def process_data(self, data_folder: str, manifest_file: str) -> None: entries = self.process_transcript(os.path.join(data_folder, "uzbekvoice-dataset", "voice_dataset.json")) - with open(manifest_file, "w", encoding="utf-8") as fout: - for m in entries: - fout.write(json.dumps(m, ensure_ascii=False) + "\n") - - + if self.output_manifest_file: + save_manifest(entries, manifest_file) + return entries - def process(self): + def process(self, _: Task) -> list[dict]: self.download_extract_files(self.raw_data_dir) - self.process_data(self.raw_data_dir, self.output_manifest_file) + entries = self.process_data(self.raw_data_dir, self.output_manifest_file) + return entries diff --git a/sdp/processors/huggingface/speech_recognition.py b/sdp/processors/huggingface/speech_recognition.py index 2e64e7c4..9db7bb86 100644 --- a/sdp/processors/huggingface/speech_recognition.py +++ b/sdp/processors/huggingface/speech_recognition.py @@ -14,13 +14,15 @@ import json from pathlib import Path +from typing import Optional +from ray_curator.tasks import _EmptyTask from tqdm import tqdm from sdp.logging import logger from sdp.processors.base_processor import BaseProcessor from sdp.utils.common import load_manifest -from typing import Optional + class ASRTransformers(BaseProcessor): """This processor transcribes audio files using HuggingFace ASR Transformer models.
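The Uzbekvoice change above shows the transitional contract for download-style processors: write the manifest when output_manifest_file is set, and return the entries either way so a downstream stage can consume them in memory. Sketched in isolation (the class name is hypothetical; save_manifest is assumed to write one JSON object per line, mirroring load_manifest from sdp.utils.common):

from ray_curator.tasks import Task

from sdp.processors.base_processor import BaseProcessor
from sdp.utils.common import save_manifest

class DownloadAndManifest(BaseProcessor):  # hypothetical example
    def process(self, _: Task) -> list[dict]:
        entries = [{"audio_filepath": "/data/a.wav", "text": "hi", "duration": 1.0}]
        if self.output_manifest_file:
            save_manifest(entries, self.output_manifest_file)  # file output for legacy consumers
        return entries  # in-memory output for the next pipeline stage
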
@@ -99,7 +101,7 @@ def __init__( # Check if using Whisper/Seamless or NVIDIA model based on the model name self.is_whisper_or_seamless = any(x in self.pretrained_model.lower() for x in ['whisper', 'seamless']) - + # Only set language in generation config for Whisper/Seamless models if self.is_whisper_or_seamless and self.generate_language: self.model.generation_config.language = self.generate_language @@ -119,7 +121,7 @@ def __init__( device=self.device, ) - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: json_list = load_manifest(Path(self.input_manifest_file)) json_list_sorted = sorted(json_list, key=lambda d: d[self.input_duration_key], reverse=True) @@ -131,7 +133,7 @@ def process(self): batch = json_list_sorted[start_index : start_index + self.batch_size] start_index += self.batch_size audio_files = [item[self.input_audio_key] for item in batch] - + # Only pass generate_kwargs for Whisper/Seamless models if self.is_whisper_or_seamless and self.generate_language and self.generate_task: results = self.pipe( @@ -143,3 +145,4 @@ def process(self): for i, item in enumerate(batch): item[self.output_text_key] = results[i]["text"] f.write(json.dumps(item, ensure_ascii=False) + "\n") + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/langs/armenian.py b/sdp/processors/langs/armenian.py index 586807ed..319a3bd3 100644 --- a/sdp/processors/langs/armenian.py +++ b/sdp/processors/langs/armenian.py @@ -16,6 +16,7 @@ from pathlib import Path import pandas as pd +from ray_curator.tasks import _EmptyTask from sdp.processors.base_processor import ( BaseParallelProcessor, @@ -62,9 +63,10 @@ class MakeTsv(BaseProcessor): """ - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: df1 = pd.DataFrame.from_records(load_manifest(Path(self.input_manifest_file))) df1.to_csv(self.output_manifest_file, index=None, sep='\t') + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) class RandomTsvPart(BaseProcessor): @@ -88,8 +90,9 @@ def __init__( self.part = part self.random_state = random_state - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: df1 = pd.read_csv(self.input_manifest_file, sep='\t') df1.sample(frac=self.part, random_state=self.random_state).to_csv( self.output_manifest_file, index=None, sep='\t' ) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/modify_manifest/common.py b/sdp/processors/modify_manifest/common.py index c94c72bb..2b2febc0 100644 --- a/sdp/processors/modify_manifest/common.py +++ b/sdp/processors/modify_manifest/common.py @@ -69,7 +69,7 @@ def __init__( self.arg_separator = arg_separator self.cmd = cmd - def process(self): + def process(self, tasks: DataEntry) -> DataEntry: os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) if self.cmd.find(self.input_manifest_file) != -1 or self.cmd.find(self.output_manifest_file) != -1: logger.error( @@ -92,7 +92,7 @@ def process(self): if self.output_manifest_arg: process_args.extend([self.output_manifest_arg + self.arg_separator + self.output_manifest_file]) subprocess.run(" ".join(process_args), shell=True) - + return tasks class CombineSources(BaseParallelProcessor): @@ -401,7 +401,7 @@ def process(self): fout.write(json.dumps(line, ensure_ascii=False) + "\n") -class KeepOnlySpecifiedFields(BaseProcessor): +class KeepOnlySpecifiedFields(BaseParallelProcessor): """Saves a copy of a manifest but only with a subset of the fields. 
Typically will be the final processor to save only relevant fields @@ -421,14 +421,9 @@ def __init__(self, fields_to_keep: List[str], **kwargs): super().__init__(**kwargs) self.fields_to_keep = fields_to_keep - def process(self): - with open(self.input_manifest_file, "rt", encoding="utf8") as fin, open( - self.output_manifest_file, "wt", encoding="utf8" - ) as fout: - for line in tqdm(fin): - line = json.loads(line) - new_line = {field: line[field] for field in self.fields_to_keep} - fout.write(json.dumps(new_line, ensure_ascii=False) + "\n") + def process_dataset_entry(self, data_entry: Dict): + new_data_entry = {field: data_entry[field] for field in self.fields_to_keep} + return [DataEntry(data=new_data_entry)] class ApplyInnerJoin(BaseProcessor): @@ -460,21 +455,43 @@ def __init__( self.right_manifest_file = right_manifest_file self.column_id = column_id - def process(self): - m1 = pd.DataFrame.from_records(load_manifest(Path(self.left_manifest_file))) - m2 = pd.DataFrame.from_records(load_manifest(Path(self.right_manifest_file))) + def process(self, tasks: DataEntry, tasks2: DataEntry = None) -> DataEntry: + if self.left_manifest_file: + m1 = pd.DataFrame.from_records(load_manifest(Path(self.left_manifest_file))) + elif tasks: + logger.warning("batch_size should be at least as large as the dataset size") + m1 = tasks.toDataFrame() + else: + raise ValueError("tasks or self.input_manifest_file or self.left_manifest_file must not be None") + + if self.right_manifest_file: + m2 = pd.DataFrame.from_records(load_manifest(Path(self.right_manifest_file))) + elif tasks2: + logger.warning("batch_size should be at least as large as the dataset size") + m2 = tasks2.toDataFrame() + else: + raise ValueError("tasks2 or self.right_manifest_file must not be None") + m3 = pd.merge(m1, m2, on=self.column_id, how="inner") - with open(self.output_manifest_file, "wt", encoding="utf8") as fout: - for _, line in m3.iterrows(): + if self.output_manifest_file: + fout = open(self.output_manifest_file, "wt", encoding="utf8") + items = [] + for _, line in m3.iterrows(): + item = DataEntry(data=dict(line)) if self.output_manifest_file: fout.write(json.dumps(dict(line), ensure_ascii=False) + "\n") + items.append(item) + if self.output_manifest_file: + fout.close() + return items class DropSpecifiedFields(BaseProcessor): """ A processor that removes specified fields from each data entry in the manifest. - This processor reads an input manifest line by line, drops the fields listed in `fields_to_drop` + This processor reads an input manifest line by line, drops the fields listed in `fields_to_drop` from each JSON entry, and writes the cleaned entries to the output manifest.
Args: @@ -502,4 +517,4 @@ def process(self): # Create a new entry by excluding the specified fields new_line = {field: entry[field] for field in entry if field not in self.fields_to_drop} # Write the cleaned entry to the output manifest - fout.write(json.dumps(new_line, ensure_ascii=False) + "\n") \ No newline at end of file + fout.write(json.dumps(new_line, ensure_ascii=False) + "\n") diff --git a/sdp/processors/modify_manifest/create_manifest.py b/sdp/processors/modify_manifest/create_manifest.py index 1248f3b0..c1905b96 100644 --- a/sdp/processors/modify_manifest/create_manifest.py +++ b/sdp/processors/modify_manifest/create_manifest.py @@ -15,9 +15,39 @@ import json from pathlib import Path -import pandas +from ray_curator.stages.base import ProcessingStage +from ray_curator.stages.resources import Resources +from ray_curator.tasks import Task -from sdp.processors.base_processor import BaseParallelProcessor, DataEntry +from sdp.processors.base_processor import ( + BaseParallelProcessor, + BaseProcessor, + DataEntry, +) + + +class SaveJsonl(BaseProcessor): + """ + Processor for saving tasks as a single JSONL file. + + Args: + **kwargs: Additional keyword arguments to be passed to the base class `BaseProcessor`. + + """ + + def __init__( + self, + **kwargs, + ): + super().__init__(**kwargs) + + def setup_on_node(self, _, __): + open(self.output_manifest_file, 'w').close() + + def process(self, tasks: DataEntry) -> DataEntry: + with open(self.output_manifest_file, 'a', encoding="utf8") as f: + f.write(json.dumps(tasks.data, ensure_ascii=False) + '\n') + return tasks class CreateInitialManifestByExt(BaseParallelProcessor): @@ -44,6 +74,10 @@ def __init__( self.output_file_key = output_file_key self.extension = extension + def setup_on_node(self, _, __): + if self.output_manifest_file: + open(self.output_manifest_file, 'w').close() + def read_manifest(self): # Get all files with the specified extension files = list(self.raw_data_dir.rglob('*.' + self.extension)) @@ -52,7 +86,12 @@ def process_dataset_entry(self, data_entry): data = {self.output_file_key: data_entry} - return [DataEntry(data=data)] + return [ + DataEntry( + data=data, + dataset_name=str(self.raw_data_dir / "*.") + self.extension, + ) + ] class CreateCombinedManifests(BaseParallelProcessor): @@ -85,4 +124,10 @@ def read_manifest(self): yield json.loads(line) def process_dataset_entry(self, data_entry): - return [DataEntry(data=data_entry)] + return [ + DataEntry( + data=data_entry, + task_id=0, + dataset_name=self.__class__.__name__, + ) + ] diff --git a/sdp/processors/modify_manifest/data_to_data.py b/sdp/processors/modify_manifest/data_to_data.py index 35b4d5b0..96fabe5c 100644 --- a/sdp/processors/modify_manifest/data_to_data.py +++ b/sdp/processors/modify_manifest/data_to_data.py @@ -13,6 +13,7 @@ # limitations under the License.
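SaveJsonl above relies on setup_on_node() truncating the output file once before any task runs; each subsequent process() call appends a single line, so many small tasks accumulate into one JSONL file. The same truncate-then-append idea in isolation (file name illustrative; concurrent appends from separate workers are assumed to stay within a single atomic line write):

import json

path = "out.jsonl"
open(path, "w").close()  # setup_on_node: truncate once, before any task runs

for data in ({"text": "first"}, {"text": "second"}):  # one dict per incoming task
    with open(path, "a", encoding="utf8") as f:  # process: append a single line
        f.write(json.dumps(data, ensure_ascii=False) + "\n")
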
import collections +import json import os import re from typing import Dict, List, Optional @@ -20,8 +21,9 @@ import soundfile import torchaudio from docx import Document +from ray_curator.tasks import _EmptyTask +from sox import Transformer from tqdm import tqdm -import json from sdp.logging import logger from sdp.processors.base_processor import ( @@ -29,11 +31,11 @@ BaseProcessor, DataEntry, ) +from sdp.utils.apply_operators import evaluate_expression from sdp.utils.common import ffmpeg_convert from sdp.utils.edit_spaces import add_start_end_spaces, remove_extra_spaces from sdp.utils.get_diff import get_diff_with_subs_grouped from sdp.utils.metrics_computation import get_wer -from sdp.utils.apply_operators import evaluate_expression class GetAudioDuration(BaseParallelProcessor): @@ -105,6 +107,82 @@ def process_dataset_entry(self, data_entry): return data_list +class SoxConvert(BaseParallelProcessor): + """Processor for Sox to convert audio files to specified format. + + Args: + output_manifest_file (str): Path to the output manifest file. + input_audio_file_key (str): Key in the manifest file that contains the path to the input audio file. + output_audio_file_key (str): Key in the manifest file that contains the path to the output audio file. + converted_audio_dir (str): Path to the directory where the converted audio files will be stored. + output_format (str): Format of the output audio file. + rate (int): Sample rate of the output audio file. + channels (int): Number of channels of the output audio file. + workspace_dir (str, Optional): Path to the workspace directory. Defaults to None. + """ + + def __init__( + self, + converted_audio_dir: str, + input_audio_file_key: str = "audio_filepath", + output_audio_file_key: str = "audio_filepath", + output_format: str = "wav", + rate: int = 16000, + channels: int = 1, + workspace_dir: Optional[str] = None, + **kwargs, + ): + # Extract workspace_dir from kwargs to avoid passing it to BaseProcessor + if "workspace_dir" in kwargs: + workspace_dir = kwargs.pop("workspace_dir") + + super().__init__(**kwargs) + self.input_audio_file_key = input_audio_file_key + self.output_audio_file_key = output_audio_file_key + self.converted_audio_dir = converted_audio_dir + self.output_format = output_format + self.workspace_dir = workspace_dir + + # Store the new parameters for later use: + self.rate = rate + self.channels = channels + + def prepare(self): + # Debug print for workspace_dir + logger.info(f"SoxConvert workspace_dir: {self.workspace_dir}") + os.makedirs(self.converted_audio_dir, exist_ok=True) + + def process_dataset_entry(self, data_entry): + audio_path = data_entry[self.input_audio_file_key] + + # If workspace_dir is provided, join it with audio_path to get absolute path + if self.workspace_dir is not None: + full_audio_path = os.path.join(self.workspace_dir, audio_path) + else: + full_audio_path = audio_path + + # Debug print first file path + if not hasattr(self, '_debug_printed'): + logger.info(f"First audio_path from manifest: {audio_path}") + logger.info(f"First full_audio_path: {full_audio_path}") + logger.info(f"Path exists: {os.path.exists(full_audio_path)}") + self._debug_printed = True + + key = os.path.splitext(audio_path)[0].split("/")[-1] + converted_file = os.path.join(self.converted_audio_dir, key) + f".{self.output_format}" + + if not os.path.isfile(converted_file): + transformer = Transformer() + + transformer.rate(self.rate) + transformer.channels(self.channels) + + transformer.build(full_audio_path, converted_file) + + 
data_entry[self.output_audio_file_key] = converted_file + return [DataEntry(data=data_entry)] + + class CountNumWords(BaseParallelProcessor): """ Processor for counting the number of words in the text_key field saving the number in num_words_key. @@ -460,10 +538,10 @@ def __init__( super().__init__(**kwargs) if not regex_params_list and not regex_params_yaml: raise ValueError(f'One of `regex_params_list` or `regex_params_yaml` should be provided.') - + self.regex_params_list = regex_params_list if regex_params_yaml: - with open(regex_params_yaml, 'r') as regex_params_file: + with open(regex_params_yaml, 'r') as regex_params_file: self.regex_params_list = yaml.safe_load(regex_params_file) self.text_key = text_key @@ -555,6 +633,7 @@ def __init__( def prepare(self): from nemo_text_processing.text_normalization.normalize import Normalizer + try: self.normalizer = Normalizer(input_case=self.input_case, lang=self.input_language) except NotImplementedError as e: @@ -603,7 +682,10 @@ def __init__( self.verbose = verbose def prepare(self): - from nemo_text_processing.inverse_text_normalization.inverse_normalize import InverseNormalizer + from nemo_text_processing.inverse_text_normalization.inverse_normalize import ( + InverseNormalizer, + ) + try: self.inverse_normalizer = InverseNormalizer(input_case=self.input_case, lang=self.input_language) except NotImplementedError as e: @@ -624,7 +706,7 @@ class CopyManifestData(BaseParallelProcessor): Args: copy_path (str): The destination directory where files will be copied. - source_filepath (str): The key in the manifest that contains the path to + source_filepath (str): The key in the manifest that contains the path to the file to be copied. Default: "audio_path". Returns: @@ -640,6 +722,7 @@ class CopyManifestData(BaseParallelProcessor): copy_path: ${workspace_dir}/consolidated_data source_filepath: "audio_filepath" """ + def __init__( self, copy_path: str, @@ -808,15 +891,16 @@ class GetWER(BaseParallelProcessor): """This processor calculates Word Error Rate (WER) between predicted text and ground truth text. It computes the WER for each entry in the manifest and adds the result as a new field. - + Args: text_key (str): Key for the ground truth text field in the manifest. Default: "text". pred_text_key (str): Key for the predicted text field in the manifest. Default: "pred_text". - + Returns: - The same data as in the input manifest with an additional 'wer' field containing + The same data as in the input manifest with an additional 'wer' field containing the calculated Word Error Rate between the specified text fields. """ + def __init__( self, text_key: str = "text", @@ -860,6 +944,7 @@ class MakeSentence(BaseParallelProcessor): end_symbol: "." make_uppercase: true """ + def __init__( self, text_key: str = "text", @@ -899,7 +984,14 @@ class ASRFileCheck(BaseProcessor): A manifest with corrupted audio files removed. """ - def __init__(self, audio_filepath_key: str = "audio_filepath", corrupted_audio_dir: str = None, workspace_dir: str = None, **kwargs): + + def __init__( + self, + audio_filepath_key: str = "audio_filepath", + corrupted_audio_dir: str = None, + workspace_dir: str = None, + **kwargs, + ): """ Constructs the necessary attributes for the ASRFileCheck class. @@ -915,31 +1007,33 @@ def __init__(self, audio_filepath_key: str = "audio_filepath", corrupted_audio_d """ super().__init__(**kwargs) self.audio_filepath_key = audio_filepath_key - + if corrupted_audio_dir is None: - raise ValueError("corrupted_audio_dir parameter is required. 
Please specify a directory to move corrupted files.") - + raise ValueError( + "corrupted_audio_dir parameter is required. Please specify a directory to move corrupted files." + ) + self.corrupted_audio_dir = corrupted_audio_dir self.workspace_dir = workspace_dir self.failed_files = [] - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: """ Check each file listed in the manifest to ensure it can be loaded with torchaudio. This method reads through the manifest file, attempts to load each audio file using torchaudio, and moves corrupted files. A new manifest file is created with only the valid entries. - + Specific errors handled: - FileNotFoundError: File doesn't exist - RuntimeError: File format issues or codec problems - Other exceptions: General issues with file loading """ from sdp.logging import logger - + # Debug print to show workspace_dir logger.info(f"ASRFileCheck workspace_dir: {self.workspace_dir}") - + with open(self.input_manifest_file, 'r') as f: lines = f.readlines() @@ -953,22 +1047,22 @@ def process(self): line = lines[idx] entry = json.loads(line) audio_path = entry[self.audio_filepath_key] - + # Debug print first file path if idx == 0: logger.info(f"First audio_path from manifest: {audio_path}") - + # If workspace_dir is provided, join it with audio_path to get absolute path if self.workspace_dir is not None: full_audio_path = os.path.join(self.workspace_dir, audio_path) else: full_audio_path = audio_path - + # Debug print first full path if idx == 0: logger.info(f"First full_audio_path: {full_audio_path}") logger.info(f"Path exists: {os.path.exists(full_audio_path)}") - + try: # Attempt to load the audio file to check if it is corrupted torchaudio.load(full_audio_path) @@ -979,7 +1073,7 @@ def process(self): except RuntimeError as e: logger.warning(f"Audio format error in {audio_path}: {e}") self.failed_files.append(audio_path) - + # Move the corrupted audio file if os.path.exists(full_audio_path): dest_path = os.path.join(self.corrupted_audio_dir, os.path.basename(audio_path)) @@ -988,7 +1082,7 @@ def process(self): except Exception as e: logger.warning(f"Unknown error loading {audio_path}: {e}") self.failed_files.append(audio_path) - + # Move the corrupted audio file if os.path.exists(full_audio_path): dest_path = os.path.join(self.corrupted_audio_dir, os.path.basename(audio_path)) @@ -1004,6 +1098,7 @@ def process(self): if self.failed_files: logger.warning(f"Failed to process {len(self.failed_files)} files.") logger.debug(f"Failed files: {self.failed_files}") + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) class ListToEntries(BaseParallelProcessor): @@ -1023,23 +1118,23 @@ class ListToEntries(BaseParallelProcessor): Raises: TypeError: If the specified list field is not of type list. ValueError: If the list items are not dictionaries and `output_field` is not provided. - + Returns: - A manifest where each entry corresponds to one item in the original list from the input entry. - This effectively transforms a single input entry containing a list of items into multiple standalone + A manifest where each entry corresponds to one item in the original list from the input entry. + This effectively transforms a single input entry containing a list of items into multiple standalone entries, each suitable for further dataset processing. .. admonition:: Example 1 (list of dicts) - + .. 
code-block:: yaml - + - _target_: sdp.processors.ListToEntries input_manifest_file: ${workspace_dir}/input_manifest.json output_manifest_file: ${workspace_dir}/output_manifest.json field_with_list: "segments" - + Input:: - + { "audio_filepath": "sample.wav", "segments": [ @@ -1064,19 +1159,19 @@ class ListToEntries(BaseParallelProcessor): "text": "World" } ] - + .. admonition:: Example 2 (list of primitives) - + .. code-block:: yaml - + - _target_: sdp.processors.ListToEntries input_manifest_file: ${workspace_dir}/input_manifest.json output_manifest_file: ${workspace_dir}/output_manifest.json field_with_list: "text_chunks" output_field: "text" - + Input:: - + { "audio_filepath": "sample.wav", "text_chunks": [ @@ -1100,10 +1195,7 @@ class ListToEntries(BaseParallelProcessor): """ - def __init__(self, - field_with_list: str, - output_field: str = None, - **kwargs): + def __init__(self, field_with_list: str, output_field: str = None, **kwargs): super().__init__(**kwargs) self.field_with_list = field_with_list self.output_field = output_field @@ -1114,14 +1206,16 @@ def process_dataset_entry(self, data_entry): # Check that the target field is actually a list if not isinstance(data_entry[self.field_with_list], list): raise TypeError(f'Values of {self.field_with_list} field should be list type only: {data_entry}') - + # Remove the list field from the entry and get the list of items items_list = data_entry.pop(self.field_with_list) # If items are not dicts, output_field must be specified to store the item if not isinstance(items_list[0], dict) and not self.output_field: - raise ValueError(f'Type of items in items list `{self.field_with_list}` is not dict ({type(items_list[0])}). In this case `output_field` should be provided.') - + raise ValueError( + f'Type of items in items list `{self.field_with_list}` is not dict ({type(items_list[0])}). In this case `output_field` should be provided.' + ) + # Expand the list into multiple entries for item in items_list: _entry = data_entry.copy() @@ -1129,7 +1223,7 @@ def process_dataset_entry(self, data_entry): # If item is a dict, merge its keys; otherwise, store it in `output_field` if isinstance(item, dict): _entry.update(item) - else: + else: _entry[self.output_field] = item _entry = DataEntry(_entry) @@ -1202,6 +1296,7 @@ class LambdaExpression(BaseParallelProcessor): str: A line-delimited JSON manifest, where each line is a processed entry. The result may contain fewer entries than the input if ``filter=True``. """ + def __init__( self, new_field: str, @@ -1223,12 +1318,12 @@ def process_dataset_entry(self, data_entry) -> List[DataEntry]: If `filter` is True, the entry is only retained if the expression evaluates to True. Otherwise, the result is stored in `new_field`. 
""" - value = evaluate_expression(self.expression, data_entry, self.lambda_param_name) + value = evaluate_expression(self.expression, data_entry, self.lambda_param_name) if self.filter: if value is not True: return [] - data_entry[self.new_field] = value + data_entry[self.new_field] = value return [DataEntry(data=data_entry)] def finalize(self, metrics): - super().finalize(metrics) \ No newline at end of file + super().finalize(metrics) diff --git a/sdp/processors/modify_manifest/data_to_dropbool.py b/sdp/processors/modify_manifest/data_to_dropbool.py index eeeebd1e..8c03dc64 100644 --- a/sdp/processors/modify_manifest/data_to_dropbool.py +++ b/sdp/processors/modify_manifest/data_to_dropbool.py @@ -19,6 +19,8 @@ from operator import eq, ge, gt, le, lt, ne from typing import List, Union +from ray_curator.tasks import _EmptyTask + from sdp.logging import logger from sdp.processors.base_processor import ( BaseParallelProcessor, @@ -75,7 +77,7 @@ def __init__( 'Operator must be one from the list: "lt" (less than), "le" (less than or equal to), "eq" (equal to), "ne" (not equal to), "ge" (greater than or equal to), "gt" (greater than)' ) - def process_dataset_entry(self, data_entry): + def process_dataset_entry(self, data_entry): input_value = data_entry[self.input_value_key] target = self.target_value if self.operator(input_value, target): @@ -890,7 +892,7 @@ def __init__(self, drop_key: str = "text", **kwargs): self.drop_key = drop_key self.seen_texts = set() - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: unique_entries = [] with open(self.input_manifest_file, 'r', encoding='utf-8') as file: for line in file: @@ -905,4 +907,4 @@ def process(self): fout.write(json.dumps(entry, ensure_ascii=False) + "\n") logger.info(f"Total number of entries after processing: {len(unique_entries)}") - return unique_entries + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/nemo/asr_inference.py b/sdp/processors/nemo/asr_inference.py index 4359f320..d5672191 100644 --- a/sdp/processors/nemo/asr_inference.py +++ b/sdp/processors/nemo/asr_inference.py @@ -17,6 +17,8 @@ from pathlib import Path from typing import Optional +from ray_curator.tasks import _EmptyTask + from sdp.processors.base_processor import BaseProcessor # Note that we do not re-use base parallel implementation, since the ASR @@ -44,16 +46,16 @@ class ASRInference(BaseProcessor): def __init__( self, - pretrained_model: Optional[str]=None, + pretrained_model: Optional[str] = None, batch_size: int = 32, **kwargs, ): super().__init__(**kwargs) self.script_path = Path(__file__).parents[1] / "nemo" / "transcribe_speech.py" self.pretrained_model = pretrained_model - self.batch_size = batch_size + self.batch_size_asr = batch_size - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: """This will add "pred_text" key into the output manifest.""" os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) if self.pretrained_model.endswith(".nemo"): @@ -62,7 +64,7 @@ def process(self): f"model_path={self.pretrained_model} " f"dataset_manifest={self.input_manifest_file} " f"output_filename={self.output_manifest_file} " - f"batch_size={self.batch_size} ", + f"batch_size={self.batch_size_asr} ", shell=True, check=True, ) @@ -72,7 +74,8 @@ def process(self): f"pretrained_name={self.pretrained_model} " f"dataset_manifest={self.input_manifest_file} " f"output_filename={self.output_manifest_file} " - f"batch_size={self.batch_size} ", + 
f"batch_size={self.batch_size_asr} ", shell=True, check=True, - ) \ No newline at end of file + ) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/nemo/lid_inference.py b/sdp/processors/nemo/lid_inference.py index 0f1b871f..69bd9456 100644 --- a/sdp/processors/nemo/lid_inference.py +++ b/sdp/processors/nemo/lid_inference.py @@ -5,7 +5,7 @@ from tqdm import tqdm from sdp.logging import logger -from sdp.processors.base_processor import BaseProcessor +from sdp.processors.base_processor import BaseProcessor, DataEntry from sdp.utils.common import load_manifest @@ -45,7 +45,7 @@ def __init__( self.random_seed = random_seed self.device = device - def process(self): + def process(self, tasks: DataEntry) -> DataEntry: import nemo.collections.asr as nemo_asr import torch # importing after nemo to make sure users first install nemo, instead of torch, then nemo @@ -59,7 +59,10 @@ def process(self): else: model = model.to(self.device) - manifest = load_manifest(Path(self.input_manifest_file)) + if self.input_manifest_file: + manifest = load_manifest(Path(self.input_manifest_file)) + else: + manifest = tasks.data Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True) with Path(self.output_manifest_file).open('w') as f: @@ -75,3 +78,4 @@ def process(self): if lang: item[self.output_lang_key] = lang f.write(json.dumps(item, ensure_ascii=False) + '\n') + return tasks diff --git a/sdp/processors/nemo/transcribe_speech.py b/sdp/processors/nemo/transcribe_speech.py index bb04047b..05803923 100644 --- a/sdp/processors/nemo/transcribe_speech.py +++ b/sdp/processors/nemo/transcribe_speech.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,34 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This file is copied over from https://github.com/NVIDIA/NeMo/blob/v1.23.0/examples/asr/transcribe_speech.py. -# It is currently only compatible with NeMo v1.23.0. To use a different version of NeMo, please modify the file. +# This file is copied over from https://github.com/NVIDIA/NeMo/blob/r2.4.0/examples/asr/transcribe_speech.py. +# It is currently only compatible with NeMo r2.4.0. To use a different version of NeMo, please modify the file. 
-import contextlib +import json import os -from dataclasses import dataclass, is_dataclass +from dataclasses import dataclass, field, is_dataclass from typing import List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl +import numpy as np import torch -from omegaconf import OmegaConf, open_dict - -from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecMultiTaskModel +from nemo.collections.asr.models import ( + EncDecCTCModel, + EncDecHybridRNNTCTCModel, + EncDecRNNTModel, +) +from nemo.collections.asr.models.aed_multitask_models import parse_multitask_prompt from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig -from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig +from nemo.collections.asr.parts.submodules.multitask_decoding import ( + MultiTaskDecoding, + MultiTaskDecodingConfig, +) from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.asr.parts.utils.transcribe_utils import ( compute_output_filename, prepare_audio_data, + restore_transcription_order, setup_model, - transcribe_partial_audio, write_transcription, ) from nemo.core.config import hydra_runner from nemo.utils import logging +from nemo.utils.timers import SimpleTimer +from omegaconf import OmegaConf, open_dict """ Transcribe audio file on a single CPU/GPU. Useful for transcription of moderate amounts of audio data. @@ -48,21 +57,17 @@ model_path: path to .nemo ASR checkpoint pretrained_name: name of pretrained ASR model (from NGC registry) audio_dir: path to directory with audio files - dataset_manifest: path to dataset JSON manifest file (in NeMo format) - - compute_timestamps: Bool to request greedy time stamp information (if the model supports it) + dataset_manifest: path to dataset JSON manifest file (in NeMo format) compute_langs: Bool to request language ID information (if the model supports it) + timestamps: Bool to request greedy time stamp information (if the model supports it) by default None (Optionally: You can limit the type of timestamp computations using below overrides) - ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word]) - rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word]) - - (Optionally: You can limit the type of timestamp computations using below overrides) - ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word]) - rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word]) + ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word, segment]) + rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word, segment]) output_filename: Output filename where the transcriptions will be written batch_size: batch size during inference + presort_manifest: sorts the provided manifest by audio length for faster inference (default: True) cuda: Optional int to enable or disable execution of model on certain CUDA device.
allow_mps: Bool to allow using MPS (Apple Silicon M-series GPU) device if available @@ -79,6 +84,8 @@ langid: Str used for convert_num_to_words during groundtruth cleaning use_cer: Bool to use Character Error Rate (CER) or Word Error Rate (WER) + calculate_rtfx: Bool to calculate the RTFx throughput to transcribe the input dataset. + # Usage ASR model can be specified by either "model_path" or "pretrained_name". Data for transcription can be defined with either "audio_dir" or "dataset_manifest". @@ -95,7 +102,7 @@ clean_groundtruth_text=True \ langid='en' \ batch_size=32 \ - compute_timestamps=False \ + timestamps=False \ compute_langs=False \ cuda=0 \ amp=True \ @@ -106,13 +113,19 @@ @dataclass class ModelChangeConfig: + """ + Sub-config for changes specific to the Conformer Encoder + """ - # Sub-config for changes specific to the Conformer Encoder - conformer: ConformerChangeConfig = ConformerChangeConfig() + conformer: ConformerChangeConfig = field(default_factory=ConformerChangeConfig) @dataclass class TranscriptionConfig: + """ + Transcription Configuration for audio to text transcription. + """ + # Required configs model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model @@ -123,6 +136,7 @@ class TranscriptionConfig: ] = None # Used to select a single channel from multichannel audio, or use average across channels audio_key: str = 'audio_filepath' # Used to override the default audio key in dataset_manifest eval_config_yaml: Optional[str] = None # Path to a yaml file of config of evaluation + presort_manifest: bool = True # Significant inference speedup on short-form data due to padding reduction # General configs output_filename: Optional[str] = None @@ -132,10 +146,11 @@ class TranscriptionConfig: pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one. random_seed: Optional[int] = None # seed number going to be used in seed_everything() - # Set to True to output greedy timestamp information (only supported models) - compute_timestamps: bool = False - # set to True if need to return full alignment information - preserve_alignment: bool = False + # Set to True to output greedy timestamp information (only supported models) and returns full alignment hypotheses + timestamps: Optional[bool] = None + + # Set to True to return hypotheses instead of text from the transcribe function + return_hypotheses: bool = False # Set to True to output language ID information compute_langs: bool = False @@ -147,19 +162,33 @@ class TranscriptionConfig: allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU) amp: bool = False amp_dtype: str = "float16" # can be set to "float16" or "bfloat16" when using amp + compute_dtype: Optional[ + str + ] = None # "float32", "bfloat16" or "float16"; if None (default): bfloat16 if available else float32 + matmul_precision: str = "high" # Literal["highest", "high", "medium"] audio_type: str = "wav" # Recompute model transcription, even if the output folder exists with scores. 
overwrite_transcripts: bool = True # Decoding strategy for CTC models - ctc_decoding: CTCDecodingConfig = CTCDecodingConfig() + ctc_decoding: CTCDecodingConfig = field(default_factory=CTCDecodingConfig) # Decoding strategy for RNNT models - rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1) + # enable CUDA graphs for transcription + rnnt_decoding: RNNTDecodingConfig = field(default_factory=lambda: RNNTDecodingConfig(fused_batch_size=-1)) # Decoding strategy for AED models - multitask_decoding: MultiTaskDecodingConfig = MultiTaskDecodingConfig() + multitask_decoding: MultiTaskDecodingConfig = field(default_factory=MultiTaskDecodingConfig) + # Prompt slots for prompted models, e.g. Canary-1B. Examples of acceptable prompt inputs: + # Implicit single-turn assuming default role='user' (works with Canary-1B) + # +prompt.source_lang=en +prompt.target_lang=es +prompt.task=asr +prompt.pnc=yes + # Explicit single-turn prompt: + # +prompt.role=user +prompt.slots.source_lang=en +prompt.slots.target_lang=es + # +prompt.slots.task=s2t_translation +prompt.slots.pnc=yes + # Explicit multi-turn prompt: + # +prompt.turns='[{role:user,slots:{source_lang:en,target_lang:es,task:asr,pnc:yes}}]' + prompt: dict = field(default_factory=dict) # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models decoder_type: Optional[str] = None @@ -167,7 +196,7 @@ class TranscriptionConfig: att_context_size: Optional[list] = None # Use this for model-specific changes before transcription - model_change: ModelChangeConfig = ModelChangeConfig() + model_change: ModelChangeConfig = field(default_factory=ModelChangeConfig) # Config for word / character error rate calculation calculate_wer: bool = True @@ -179,20 +208,22 @@ class TranscriptionConfig: # if True, will also skip writing anything to the output file return_transcriptions: bool = False - # Set to False to return text instead of hypotheses from the transcribe function, so as to save memory - return_hypotheses: bool = True - # key for groundtruth text in manifest gt_text_attr_name: str = "text" + gt_lang_attr_name: str = "lang" + + extract_nbest: bool = False # Extract n-best hypotheses from the model - # Use model's transcribe() function instead of transcribe_partial_audio() by default - # Only use transcribe_partial_audio() when the audio is too long to fit in memory - # Your manifest input should have `offset` field to use transcribe_partial_audio() - allow_partial_transcribe: bool = False + calculate_rtfx: bool = False + warmup_steps: int = 0 # by default - no warmup + run_steps: int = 1 # by default - single run @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]: + """ + Transcribes the input audio and can be used to infer with Encoder-Decoder models. 
+ """ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') for key in cfg: @@ -217,6 +248,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis logging.info(f"Will apply on-the-fly augmentation on samples during transcription: {augmentor} ") # setup GPU + torch.set_float32_matmul_precision(cfg.matmul_precision) if cfg.cuda is None: if torch.cuda.is_available(): device = [0] # use 0th CUDA device @@ -247,11 +279,29 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis asr_model.set_trainer(trainer) asr_model = asr_model.eval() + if (cfg.compute_dtype is not None and cfg.compute_dtype != "float32") and cfg.amp: + raise ValueError("amp=true is mutually exclusive with a compute_dtype other than float32") + + amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16 + + compute_dtype: torch.dtype + if cfg.compute_dtype is None: + can_use_bfloat16 = (not cfg.amp) and map_location.type == "cuda" and torch.cuda.is_bf16_supported() + if can_use_bfloat16: + compute_dtype = torch.bfloat16 + else: + compute_dtype = torch.float32 + else: + assert cfg.compute_dtype in {"float32", "bfloat16", "float16"} + compute_dtype = getattr(torch, cfg.compute_dtype) + + asr_model.to(compute_dtype) + # we will adjust this flag if the model does not support it - compute_timestamps = cfg.compute_timestamps compute_langs = cfg.compute_langs - # has to be True if timestamps are required - preserve_alignment = True if cfg.compute_timestamps else cfg.preserve_alignment + + if cfg.timestamps: + cfg.return_hypotheses = True # Check whether model and decoder type match if isinstance(asr_model, EncDecCTCModel): @@ -260,7 +310,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis elif isinstance(asr_model, EncDecHybridRNNTCTCModel): if cfg.decoder_type and cfg.decoder_type not in ['ctc', 'rnnt']: raise ValueError('Hybrid model only support ctc or rnnt decoding!') - else: # rnnt model, there could be other models needs to be addressed. 
+ elif isinstance(asr_model, EncDecRNNTModel): if cfg.decoder_type and cfg.decoder_type != 'rnnt': raise ValueError('RNNT model only support rnnt decoding!') @@ -271,7 +321,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis if hasattr(asr_model, 'change_decoding_strategy') and hasattr(asr_model, 'decoding'): if isinstance(asr_model.decoding, MultiTaskDecoding): cfg.multitask_decoding.compute_langs = cfg.compute_langs - cfg.multitask_decoding.preserve_alignments = cfg.preserve_alignment + if cfg.extract_nbest: + cfg.multitask_decoding.beam.return_best_hypothesis = False + cfg.return_hypotheses = True asr_model.change_decoding_strategy(cfg.multitask_decoding) elif cfg.decoder_type is not None: # TODO: Support compute_langs in CTC eventually @@ -279,9 +331,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis raise ValueError("CTC models do not support `compute_langs` at the moment") decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding - decoding_cfg.compute_timestamps = cfg.compute_timestamps # both ctc and rnnt support it - if 'preserve_alignments' in decoding_cfg: - decoding_cfg.preserve_alignments = preserve_alignment + if cfg.extract_nbest: + decoding_cfg.beam.return_best_hypothesis = False + cfg.return_hypotheses = True if 'compute_langs' in decoding_cfg: decoding_cfg.compute_langs = cfg.compute_langs if hasattr(asr_model, 'cur_decoder'): @@ -291,17 +343,19 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis # Check if ctc or rnnt model elif hasattr(asr_model, 'joint'): # RNNT model + if cfg.extract_nbest: + cfg.rnnt_decoding.beam.return_best_hypothesis = False + cfg.return_hypotheses = True cfg.rnnt_decoding.fused_batch_size = -1 - cfg.rnnt_decoding.compute_timestamps = cfg.compute_timestamps cfg.rnnt_decoding.compute_langs = cfg.compute_langs - if 'preserve_alignments' in cfg.rnnt_decoding: - cfg.rnnt_decoding.preserve_alignments = preserve_alignment asr_model.change_decoding_strategy(cfg.rnnt_decoding) else: if cfg.compute_langs: raise ValueError("CTC models do not support `compute_langs` at the moment.") - cfg.ctc_decoding.compute_timestamps = cfg.compute_timestamps + if cfg.extract_nbest: + cfg.ctc_decoding.beam.return_best_hypothesis = False + cfg.return_hypotheses = True asr_model.change_decoding_strategy(cfg.ctc_decoding) @@ -311,31 +365,16 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis isinstance(asr_model, EncDecHybridRNNTCTCModel) and cfg.decoder_type == "ctc" ): cfg.decoding = cfg.ctc_decoding + elif isinstance(asr_model.decoding, MultiTaskDecoding): + cfg.decoding = cfg.multitask_decoding else: cfg.decoding = cfg.rnnt_decoding - if isinstance(asr_model, EncDecMultiTaskModel): - # Special case for EncDecMultiTaskModel, where the input manifest is directly passed into the model's transcribe() function - partial_audio = False - filepaths = cfg.dataset_manifest - assert cfg.dataset_manifest is not None - else: - # prepare audio filepaths and decide wether it's partial audio - filepaths, partial_audio = prepare_audio_data(cfg) + filepaths, sorted_manifest_path = prepare_audio_data(cfg) - if not cfg.allow_partial_transcribe: - # by defatul, use model's transcribe() function, unless partial audio is required - partial_audio = False - - # setup AMP (optional) - if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP enabled!\n") - autocast = 
torch.cuda.amp.autocast - else: + remove_path_after_done = sorted_manifest_path if sorted_manifest_path is not None else None - @contextlib.contextmanager - def autocast(dtype=None): - yield + filepaths = sorted_manifest_path if sorted_manifest_path is not None else filepaths # Compute output filename cfg = compute_output_filename(cfg, model_name) @@ -350,37 +389,82 @@ def autocast(dtype=None): # transcribe audio - amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16 + if cfg.calculate_rtfx: + total_duration = 0.0 + + with open(cfg.dataset_manifest, "rt") as fh: + for line in fh: + item = json.loads(line) + if "duration" not in item: + raise ValueError( + f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} \ + lacks a 'duration' field." + ) + total_duration += item["duration"] + + if cfg.warmup_steps == 0: + logging.warning( + "RTFx measurement enabled, but warmup_steps=0. " + "At least one warmup step is recommended to measure RTFx" + ) - with autocast(dtype=amp_dtype): + timer = SimpleTimer() + model_measurements = [] + with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu', dtype=amp_dtype, enabled=cfg.amp): with torch.no_grad(): - if partial_audio: - transcriptions = transcribe_partial_audio( - asr_model=asr_model, - path2manifest=cfg.dataset_manifest, - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, - return_hypotheses=cfg.return_hypotheses, - channel_selector=cfg.channel_selector, - augmentor=augmentor, - decoder_type=cfg.decoder_type, - ) - else: + override_cfg = asr_model.get_transcribe_config() + override_cfg.batch_size = cfg.batch_size + override_cfg.num_workers = cfg.num_workers + override_cfg.return_hypotheses = cfg.return_hypotheses + override_cfg.channel_selector = cfg.channel_selector + override_cfg.augmentor = augmentor + override_cfg.text_field = cfg.gt_text_attr_name + override_cfg.lang_field = cfg.gt_lang_attr_name + override_cfg.timestamps = cfg.timestamps + if hasattr(override_cfg, "prompt"): + override_cfg.prompt = parse_multitask_prompt(OmegaConf.to_container(cfg.prompt)) + + device = next(asr_model.parameters()).device + for run_step in range(cfg.warmup_steps + cfg.run_steps): + if run_step < cfg.warmup_steps: + logging.info(f"Running warmup step {run_step}") + # reset timer + timer.reset() + timer.start(device=device) + # call transcribe transcriptions = asr_model.transcribe( - paths2audio_files=filepaths, - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, - return_hypotheses=cfg.return_hypotheses, - channel_selector=cfg.channel_selector, - augmentor=augmentor, + audio=filepaths, + override_config=override_cfg, + timestamps=cfg.timestamps, ) + # stop timer, log time + timer.stop(device=device) + logging.info(f"Model time for iteration {run_step}: {timer.total_sec():.3f}") + if run_step >= cfg.warmup_steps: + model_measurements.append(timer.total_sec()) + + model_measurements_np = np.asarray(model_measurements) + logging.info( + f"Model time avg: {model_measurements_np.mean():.3f}" + + (f" (std: {model_measurements_np.std():.3f})" if cfg.run_steps > 1 else "") + ) - logging.info(f"Finished transcribing {len(filepaths)} files !") + if cfg.dataset_manifest is not None: + logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}") + if cfg.presort_manifest: + transcriptions = restore_transcription_order(cfg.dataset_manifest, transcriptions) + else: + logging.info(f"Finished transcribing {len(filepaths)} files !") logging.info(f"Writing transcriptions into 
file: {cfg.output_filename}") - # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis + # if transcriptions form a tuple of (best_hypotheses, all_hypotheses) if type(transcriptions) == tuple and len(transcriptions) == 2: - transcriptions = transcriptions[0] + if cfg.extract_nbest: + # extract all hypotheses if they exist + transcriptions = transcriptions[1] + else: + # extract just the best hypothesis + transcriptions = transcriptions[0] if cfg.return_transcriptions: return transcriptions @@ -392,10 +476,15 @@ def autocast(dtype=None): model_name, filepaths=filepaths, compute_langs=compute_langs, - compute_timestamps=compute_timestamps, + timestamps=cfg.timestamps, ) logging.info(f"Finished writing predictions to {output_filename}!") + # clean-up + if cfg.presort_manifest: + if remove_path_after_done is not None: + os.unlink(remove_path_after_done) + if cfg.calculate_wer: output_manifest_w_wer, total_res, _ = cal_write_wer( pred_manifest=output_filename, @@ -410,8 +499,15 @@ def autocast(dtype=None): logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!") logging.info(f"{total_res}") + if cfg.calculate_rtfx: + rtfx_measurements = total_duration / model_measurements_np + logging.info( + f"Model RTFx on the dataset: {rtfx_measurements.mean():.3f}" + + (f" (std: {rtfx_measurements.std():.3f})" if cfg.run_steps > 1 else "") + ) + return cfg if __name__ == '__main__': - main() # noqa pylint: disable=no-value-for-parameter \ No newline at end of file + main() # noqa pylint: disable=no-value-for-parameter diff --git a/sdp/run_processors.py b/sdp/run_processors.py index 6ddf27f4..46707344 100644 --- a/sdp/run_processors.py +++ b/sdp/run_processors.py @@ -12,19 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import os import tempfile import uuid from typing import List, Optional -import psutil -import json import hydra +import psutil from omegaconf import OmegaConf, open_dict +from ray_curator.backends.xenna import XennaExecutor +from ray_curator.pipeline import Pipeline +from ray_curator.tasks import _EmptyTask from sdp.logging import logger - from sdp.utils.import_manager import ImportManager # registering new resolvers to simplify config files @@ -46,16 +48,18 @@ logger.addHandler(handler) logger.propagate = False + def update_processor_imports(config_path: str, init_file: str = None): """ Update processor imports based on config file. - + Args: config_path: Path to the YAML config file init_file: Optional path to __init__.py file to update """ try: import yaml + manager = ImportManager() manager.sync_with_config(config_path, init_file) logger.info(f"Successfully updated imports for config: {config_path}") @@ -120,13 +124,14 @@ def run_processors(cfg): if cfg.get("use_import_manager", False): try: import yaml + yaml_path = cfg.get("config_path") if not yaml_path: raise ValueError("No configuration path provided in 'config_path'. 
Please specify the path.") if not os.path.exists(yaml_path): raise FileNotFoundError(f"Configuration file not found: {yaml_path}") - + logger.info(f"Managing imports for config: {yaml_path}") manager = ImportManager() manager.sync_with_config(yaml_path) @@ -141,22 +146,26 @@ def run_processors(cfg): except Exception as e: logger.error(f"An unexpected error occurred during management of imports: {e}") - # Detecting dask + # Detecting ray try: - from dask.distributed import Client - dask_available = True + import ray + + ray_available = True except ImportError: - logger.warning("Dask not installed; using multiprocessing for all processors") + logger.warning("Ray not installed; using multiprocessing for all processors") - dask_available = False - + ray_available = False + - # look for global directions in cfg for dask usage + # look for global directions in cfg for backend usage - global_use_dask = bool(cfg.get("use_dask", True)) and dask_available + global_use_backend = cfg.get("use_backend", None) + if global_use_backend == "ray" and not ray_available: + logger.warning("use_backend='ray' requested, but Ray is unavailable; falling back to multiprocessing") + global_use_backend = None processors_to_run = cfg.get("processors_to_run", "all") if processors_to_run == "all": processors_to_run = ":" selected_cfgs = select_subset(cfg.processors, processors_to_run) - + # filtering out any processors that have should_run=False processors_cfgs = [] for processor_cfg in selected_cfgs: @@ -169,9 +178,7 @@ def run_processors(cfg): "Specified to run the following processors: %s ", [proc_cfg["_target_"] for proc_cfg in processors_cfgs], ) - - - + processors = [] # Create a temporary directory to hold intermediate files if needed. with tempfile.TemporaryDirectory() as tmp_dir: @@ -187,7 +194,7 @@ def run_processors(cfg): with open_dict(processors_cfgs[0]): processors_cfgs[0]["input_manifest_file"] = cfg.processors[idx - 1]["output_manifest_file"] break - + for idx, processor_cfg in enumerate(processors_cfgs): logger.info('=> Building processor "%s"', processor_cfg["_target_"]) @@ -205,48 +212,59 @@ def run_processors(cfg): if idx != len(processors_cfgs) - 1 and "input_manifest_file" not in processors_cfgs[idx + 1]: with open_dict(processors_cfgs[idx + 1]): processors_cfgs[idx + 1]["input_manifest_file"] = processor_cfg["output_manifest_file"] - - #check if we have processor level directions of using dask - flag=processor_cfg.get("use_dask", None) + + # check if we have processor-level directions for which backend to use + flag = processor_cfg.get("use_backend", None) # if no processor-specific flag, fallback to global; otherwise use provided value if flag is None: - use_dask_flag = global_use_dask + use_backend_flag = global_use_backend else: - use_dask_flag = flag + use_backend_flag = flag + processor = hydra.utils.instantiate(processor_cfg) - processor.use_dask = use_dask_flag + processor.use_backend = use_backend_flag # running runtime tests to fail right-away if something is not # matching users expectations processor.test() processors.append(processor) - # Start Dask client if any processor requires it + # Start a backend client if any processor requires it - dask_client = None - if any(p.use_dask for p in processors): - + backend_client = None + if any(p.use_backend for p in processors): try: num_cpus = psutil.cpu_count(logical=False) or 4 - logger.info(f"Starting Dask client with {num_cpus} workers") - dask_client = Client(n_workers=num_cpus, processes=True) - logger.info(f"Dask dashboard at: {dask_client.dashboard_link}") + logger.info(f"Starting Ray with {num_cpus} CPUs") + backend_client = ray.init(num_cpus=num_cpus) # replaces the old Dask Client(n_workers=num_cpus, processes=True) + logger.info(f"Ray dashboard at: {backend_client.dashboard_url}") except Exception as e: logger.warning(f"Failed to start Ray 
client: {e}") - dask_client = None + backend_client = None # Run processors in order try: - for proc in processors: - if proc.use_dask and dask_client is not None: - proc.dask_client = dask_client - logger.info('=> Running processor "%s" with Dask', proc) - else: - logger.info('=> Running processor "%s" with Multiprocessing', proc) - proc.process() + if global_use_backend == "curator": + pipeline = Pipeline(name="processing", description="Process data from JSONL files") + # build stages from the selected configs so processors_to_run/should_run filtering applies + for p in processors_cfgs: + stage = hydra.utils.instantiate(p, use_backend="curator", backend_client=backend_client) + pipeline.add_stage(stage) + + executor = XennaExecutor() + pipeline.run(executor) + else: + t = _EmptyTask(task_id="empty", dataset_name="empty", data=None) + for proc in processors: + if proc.use_backend == "ray" and backend_client is not None: + proc.backend_client = backend_client + logger.info('=> Running processor "%s" with Ray', proc) + else: + logger.info('=> Running processor "%s" with Multiprocessing', proc) + t = proc.process(t) finally: - if dask_client is not None: + if backend_client is not None: - logger.info("Shutting down Dask client...") + logger.info("Shutting down Ray...") - dask_client.close(timeout="60s") + ray.shutdown() - logger.info("Dask client shutdown complete") + logger.info("Ray shutdown complete") -#tmp_dir is removed here after all processing finishes. !!! + +# tmp_dir is removed here after all processing finishes. !!! diff --git a/tests/test_curator.py b/tests/test_curator.py new file mode 100644 index 00000000..2b0685cb --- /dev/null +++ b/tests/test_curator.py @@ -0,0 +1,87 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
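+ +# End-to-end checks for backend selection in run_processors(): the same one-processor config +# (CreateInitialManifestByExt over $TEST_DATA_ROOT/armenian/audio_books/mp3) is run once with +# use_backend="curator" and once with the default multiprocessing path, and the resulting +# manifest is compared against the expected single entry.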
+ +import json +import os +import tempfile +from pathlib import Path + +import yaml +from omegaconf import OmegaConf +from ray_curator.stages.base import ProcessingStage +from ray_curator.tasks import Task, _EmptyTask + +from sdp.run_processors import run_processors +from sdp.utils.common import load_manifest, save_manifest + + +def _write_config(file_path: Path, dict_conf): + with file_path.open("w") as file: + yaml.dump(dict_conf, file) + + +def _make_dict(output_manifest_file, use_backend=None): + workspace_dir = os.path.join(os.getenv('TEST_DATA_ROOT'), "armenian/audio_books/mp3") + return { + "processors_to_run": "0:", + "use_backend": use_backend, + "processors": [ + { + "_target_": "sdp.processors.CreateInitialManifestByExt", + "raw_data_dir": workspace_dir, + "extension": "mp3", + "output_file_key": "audio_filepath", + "output_manifest_file": output_manifest_file, + }, + ], + } + + +def _make_expected_output(): + workspace_dir = os.path.join(os.getenv('TEST_DATA_ROOT'), "armenian/audio_books/mp3") + return [{'audio_filepath': os.path.join(workspace_dir, "Eleonora/Eleonora30s.mp3")}] + + +def test_curator(): + tmpdir = tempfile.TemporaryDirectory() + output_path = os.path.join(tmpdir.name, "output_manifest_file.jsonl") + dict_conf = _make_dict(output_manifest_file=output_path, use_backend="curator") + conf_path = Path(tmpdir.name) / "config.yaml" + _write_config(conf_path, dict_conf) + + cfg = OmegaConf.load(conf_path) + + run_processors(cfg) + + output = load_manifest(output_path) + tmpdir.cleanup() + expected_output = _make_expected_output() + assert output == expected_output, f"Expected {expected_output}, but got {output}" + + +def test_multiprocessing(): + tmpdir = tempfile.TemporaryDirectory() + output_path = os.path.join(tmpdir.name, "output_manifest_file.jsonl") + dict_conf = _make_dict(output_manifest_file=output_path, use_backend=None) + conf_path = Path(tmpdir.name) / "config.yaml" + _write_config(conf_path, dict_conf) + + cfg = OmegaConf.load(conf_path) + + run_processors(cfg) + + output = load_manifest(output_path) + tmpdir.cleanup() + expected_output = _make_expected_output() + assert output == expected_output, f"Expected {expected_output}, but got {output}" diff --git a/tests/test_manifest_chunking.py b/tests/test_manifest_chunking.py index ae1aa394..f836c6df 100644 --- a/tests/test_manifest_chunking.py +++ b/tests/test_manifest_chunking.py @@ -21,97 +21,93 @@ import json import pytest +from ray_curator.tasks import _EmptyTask -from sdp.processors import DropNonAlphabet -from sdp.processors import SubMakeLowercase +from sdp.processors import DropNonAlphabet, SubMakeLowercase -def test_submakelowercase_with_chunking(tmp_path): - input_lines = [ - {"text": "ABC"}, - {"text": "DEF"}, - {"text": "GHI"}, - {"text": "JKL"}, - {"text": "MNO"}, - {"text": "PQR"}, - {"text": "STU"}, - {"text": "VWX"}, - {"text": "YZ"}, - ] - - expected_output_lines = [ - {"text": "abc"}, - {"text": "def"}, - {"text": "ghi"}, - {"text": "jkl"}, - {"text": "mno"}, - {"text": "pqr"}, - {"text": "stu"}, - {"text": "vwx"}, - {"text": "yz"}, - ] - - - # save input lines to manifest: - input_manifest_file = tmp_path / "input_manifest.json" - with open(input_manifest_file, "w") as f: - for line in input_lines: - f.write(json.dumps(line) + "\n") - - # run make_lowercase processor: - output_manifest_file = tmp_path / "output_manifest_make_lowercase.json" - processor = SubMakeLowercase( - input_manifest_file=input_manifest_file, - output_manifest_file=output_manifest_file, - in_memory_chunksize=2 - ) - - 
processor.process() - - # check that output manifest matches expected lines: - with open(output_manifest_file, "r") as f: - output_lines = [json.loads(line) for line in f] - - assert output_lines == expected_output_lines +def test_submakelowercase_with_chunking(tmp_path): + input_lines = [ + {"text": "ABC"}, + {"text": "DEF"}, + {"text": "GHI"}, + {"text": "JKL"}, + {"text": "MNO"}, + {"text": "PQR"}, + {"text": "STU"}, + {"text": "VWX"}, + {"text": "YZ"}, + ] + + expected_output_lines = [ + {"text": "abc"}, + {"text": "def"}, + {"text": "ghi"}, + {"text": "jkl"}, + {"text": "mno"}, + {"text": "pqr"}, + {"text": "stu"}, + {"text": "vwx"}, + {"text": "yz"}, + ] + + # save input lines to manifest: + input_manifest_file = tmp_path / "input_manifest.json" + with open(input_manifest_file, "w") as f: + for line in input_lines: + f.write(json.dumps(line) + "\n") + + # run make_lowercase processor: + output_manifest_file = tmp_path / "output_manifest_make_lowercase.json" + processor = SubMakeLowercase( + input_manifest_file=input_manifest_file, output_manifest_file=output_manifest_file, in_memory_chunksize=2 + ) + + processor.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + + # check that output manifest matches expected lines: + with open(output_manifest_file, "r") as f: + output_lines = [json.loads(line) for line in f] + + assert output_lines == expected_output_lines def test_dropnonalphabet_with_chunking(tmp_path): - - input_lines = [ - {"text": "ABC"}, - {"text": "DEF"}, - {"text": "GHI"}, - {"text": "JKL"}, - {"text": "MNO"}, - {"text": "PQR"}, - {"text": "STU"}, - {"text": "VWX"}, - {"text": "YZ"}, - ] - - expected_output_lines = [ - {"text": "ABC"}, - ] - - # save input lines to manifest: - input_manifest_file = tmp_path / "input_manifest.json" - with open(input_manifest_file, "w") as f: - for line in input_lines: - f.write(json.dumps(line) + "\n") - - # run make_lowercase processor: - output_manifest_file = tmp_path / "output_manifest_make_lowercase.json" - processor = DropNonAlphabet( - input_manifest_file=input_manifest_file, - output_manifest_file=output_manifest_file, - in_memory_chunksize=2, - alphabet="ABC" - ) - - processor.process() - - # check that output manifest matches expected lines: - with open(output_manifest_file, "r") as f: - output_lines = [json.loads(line) for line in f] - - assert output_lines == expected_output_lines + input_lines = [ + {"text": "ABC"}, + {"text": "DEF"}, + {"text": "GHI"}, + {"text": "JKL"}, + {"text": "MNO"}, + {"text": "PQR"}, + {"text": "STU"}, + {"text": "VWX"}, + {"text": "YZ"}, + ] + + expected_output_lines = [ + {"text": "ABC"}, + ] + + # save input lines to manifest: + input_manifest_file = tmp_path / "input_manifest.json" + with open(input_manifest_file, "w") as f: + for line in input_lines: + f.write(json.dumps(line) + "\n") + + # run make_lowercase processor: + output_manifest_file = tmp_path / "output_manifest_make_lowercase.json" + processor = DropNonAlphabet( + input_manifest_file=input_manifest_file, + output_manifest_file=output_manifest_file, + in_memory_chunksize=2, + alphabet="ABC", + ) + + processor.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + + # check that output manifest matches expected lines: + with open(output_manifest_file, "r") as f: + output_lines = [json.loads(line) for line in f] + + assert output_lines == expected_output_lines diff --git a/tests/test_modify_manifest.py b/tests/test_modify_manifest.py index 99583c26..3542b54e 100644 --- a/tests/test_modify_manifest.py +++ 
b/tests/test_modify_manifest.py @@ -18,6 +18,7 @@ from typing import Dict, List, Union import pytest +from ray_curator.tasks import _EmptyTask from sdp.processors import ApplyInnerJoin, DropNonAlphabet @@ -161,7 +162,7 @@ def test_apply_inner_join( output_manifest_file=manifest_out, ) - processor.process() + processor.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) with open(manifest_out, "r") as f: output_lines = [json.loads(line) for line in f] diff --git a/tests/test_tts_sdp_end_to_end.py b/tests/test_tts_sdp_end_to_end.py index c3517e95..c593bf77 100644 --- a/tests/test_tts_sdp_end_to_end.py +++ b/tests/test_tts_sdp_end_to_end.py @@ -1,37 +1,38 @@ -import pytest -import boto3 import json import os import tarfile from pathlib import Path + +import boto3 +import pytest from omegaconf import OmegaConf + from sdp.run_processors import run_processors from sdp.utils.common import load_manifest DATASET_CONFIGS_ROOT = Path(__file__).parents[1] / "dataset_configs" + @pytest.fixture def get_tts_ytc_data(tmpdir: str): # Download the data from S3 s3 = boto3.client( - 's3', - aws_access_key_id=os.getenv("AWS_ACCESS_KEY"), - aws_secret_access_key=os.getenv("AWS_SECRET_KEY") + 's3', aws_access_key_id=os.getenv("AWS_ACCESS_KEY"), aws_secret_access_key=os.getenv("AWS_SECRET_KEY") ) s3.download_file( - "sdp-test-data", - "test_data/tts/ytc/test_data_reference.json", - tmpdir/"test_data_reference.json", + "sdp-test-data", + "test_data/tts/ytc/test_data_reference.json", + tmpdir / "test_data_reference.json", ) s3.download_file( - "sdp-test-data", - "test_data/tts/ytc/ytc.en.tar.gz", - tmpdir/"ytc.en.tar.gz", + "sdp-test-data", + "test_data/tts/ytc/ytc.en.tar.gz", + tmpdir / "ytc.en.tar.gz", ) # Extract the tar.gz file - with tarfile.open(tmpdir/"ytc.en.tar.gz", "r:gz") as tar: + with tarfile.open(tmpdir / "ytc.en.tar.gz", "r:gz") as tar: tar.extractall(tmpdir) audio_files = Path(tmpdir).glob("audios/*") @@ -45,6 +46,7 @@ def get_tts_ytc_data(tmpdir: str): return tmpdir + def test_tts_sdp_end_to_end(get_tts_ytc_data): data_dir = get_tts_ytc_data assert os.path.exists(data_dir) @@ -71,13 +73,13 @@ def test_tts_sdp_end_to_end(get_tts_ytc_data): output_data = load_manifest(cfg.final_manifest, encoding="utf8") for item in output_data: output_file_data[item["audio_item_id"]] = item - + reference_file_data = {} reference_data = load_manifest(reference_manifest_file, encoding="utf8") for item in reference_data: reference_file_data[item["audio_item_id"]] = item - + assert len(output_file_data) == len(reference_file_data) assert len(output_file_data) == 2 for audio_item_id in output_file_data: - assert output_file_data[audio_item_id]["segments"] == reference_file_data[audio_item_id]["segments"] \ No newline at end of file + assert output_file_data[audio_item_id]["segments"] == reference_file_data[audio_item_id]["segments"]
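
# For reference, the "curator" branch of run_processors() above reduces to the standalone
# sketch below. run_with_curator is a hypothetical helper name and processor_cfgs stands in
# for the filtered list of Hydra processor configs; the ray_curator calls themselves are
# taken from the patched sdp/run_processors.py.

import hydra
from ray_curator.backends.xenna import XennaExecutor
from ray_curator.pipeline import Pipeline


def run_with_curator(processor_cfgs, backend_client=None):
    # Each SDP processor is instantiated as a ray_curator stage and chained
    # into a single pipeline, which Xenna executes on the running Ray cluster.
    pipeline = Pipeline(name="processing", description="Process data from JSONL files")
    for processor_cfg in processor_cfgs:
        stage = hydra.utils.instantiate(processor_cfg, use_backend="curator", backend_client=backend_client)
        pipeline.add_stage(stage)
    pipeline.run(XennaExecutor())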