diff --git a/docs/gen_docs.py b/docs/gen_docs.py index 856b1948..edce3ea3 100644 --- a/docs/gen_docs.py +++ b/docs/gen_docs.py @@ -53,5 +53,6 @@ def gen_docs(): with open(destination_path, "wt", encoding="utf-8") as fout: fout.write(docs + link) + if __name__ == '__main__': gen_docs() diff --git a/docs/src/conf.py b/docs/src/conf.py index 11964c06..3eea6b15 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -49,8 +49,8 @@ "webvtt_py", "python_docx", "webvtt", - "docx", - "pyannote" + "docx", + "pyannote", ] _skipped_autodoc_mock_imports = [] @@ -189,9 +189,9 @@ def setup(app): ] # nitpick_ignore_regex = [('py:class', '*')] -#adding this especially for coraal, temporary +# adding this especially for coraal, temporary linkcheck_ignore = [ r'https://lingtools\.uoregon\.edu/coraal/coraal_download_list\.txt', - r'https://ieeexplore\.ieee\.org/document/1326009' + r'https://ieeexplore\.ieee\.org/document/1326009', ] -# https://lingtools.uoregon.edu/coraal/coraal_download_list.txt \ No newline at end of file +# https://lingtools.uoregon.edu/coraal/coraal_download_list.txt diff --git a/main.py b/main.py index b9116dc3..a69feca7 100644 --- a/main.py +++ b/main.py @@ -13,6 +13,7 @@ # limitations under the License. import sys + import hydra from omegaconf import DictConfig, open_dict @@ -23,20 +24,20 @@ def main(cfg: DictConfig): """ Main entry point for the Speech Data Processor (SDP). - + Args: cfg: Hydra configuration object containing processing settings """ # Check if running in import manager mode if hasattr(cfg, 'mode') and cfg.mode == 'update_imports': update_processor_imports(cfg.config_path) - + # Check arg for using Dask if not hasattr(cfg, 'use_dask'): with open_dict(cfg): # Default to using Dask cfg.use_dask = True - + # Run the processors run_processors(cfg) diff --git a/requirements/huggingface.txt b/requirements/huggingface.txt index e603c631..f1f39bcf 100644 --- a/requirements/huggingface.txt +++ b/requirements/huggingface.txt @@ -1,3 +1,3 @@ accelerate -transformers>=0.2.1 huggingface_hub>=0.20.3,<0.24.0 # https://github.com/NVIDIA/NeMo/issues/9793 +transformers>=0.2.1 diff --git a/requirements/main.txt b/requirements/main.txt index 99c030b4..87df1a99 100644 --- a/requirements/main.txt +++ b/requirements/main.txt @@ -1,26 +1,26 @@ +dask +datasets>=2.14.0,<3.0.0 diff_match_patch +distributed editdistance ffmpeg +gdown hydra-core +jiwer>=3.1.0,<4.0.0 joblib librosa>=0.10.0 # specify >=0.10.0 so that librosa.get_duration(path=...) 
will work numpy>=1.26, <2.0 # code was written for numpy 1.x and may crash under 2.x omegaconf pandas +pyarrow>=8.0.0,<14.0.0 +pydub +python-docx rarfile regex sox tqdm -gdown webvtt-py wget -python-docx -pydub -dask -distributed -jiwer>=3.1.0,<4.0.0 -pyarrow>=8.0.0,<14.0.0 -datasets>=2.14.0,<3.0.0 # toloka-kit # Temporarily disabled due to Toloka's technical pause; keep as reference for past and future API support # for some processors, additionally https://github.com/NVIDIA/NeMo is required # for some processors, additionally nemo_text_processing is required diff --git a/requirements/tests.txt b/requirements/tests.txt index 0f8b8675..3474c8e8 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,8 +1,8 @@ boto3 +# lhotse requires torch and torchaudio to be present +lhotse # additional packages required to run tests pytest pytest-cov -# lhotse requires torch and torchaudio to be present -lhotse torch -torchaudio \ No newline at end of file +torchaudio diff --git a/requirements/tts.txt b/requirements/tts.txt index 7f24349c..26ba9b98 100644 --- a/requirements/tts.txt +++ b/requirements/tts.txt @@ -1,6 +1,6 @@ -transformers accelerate -torchaudio -pyannote-audio ffmpeg-python +pyannote-audio +torchaudio +transformers whisperx==3.3.1 diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py index ab0d05ed..118524bf 100644 --- a/sdp/processors/__init__.py +++ b/sdp/processors/__init__.py @@ -22,11 +22,11 @@ TrainDevTestSplitCORAAL, ) from sdp.processors.datasets.earnings import ( - CreateInitialAudioAndManifest, + ApplyEarnings21Normalizations, CreateFullAudioManifestEarnings21, - SpeakerSegmentedManifest, + CreateInitialAudioAndManifest, CreateSentenceSegmentedManifest, - ApplyEarnings21Normalizations, + SpeakerSegmentedManifest, ) from sdp.processors.datasets.fleurs.create_initial_manifest import ( CreateInitialManifestFleurs, @@ -82,19 +82,21 @@ CreateInitialManifestHuggingFace, ) from sdp.processors.huggingface.speech_recognition import ASRTransformers +from sdp.processors.manage_files.convert_audio import FfmpegConvert, SoxConvert +from sdp.processors.manage_files.extract import ExtractTar +from sdp.processors.manage_files.remove import RemoveFiles from sdp.processors.modify_manifest.common import ( AddConstantFields, ApplyInnerJoin, ChangeToRelativePath, CombineSources, + DropSpecifiedFields, DuplicateFields, KeepOnlySpecifiedFields, RenameFields, SortManifest, SplitOnFixedDuration, Subprocess, - DropSpecifiedFields, - ) from sdp.processors.modify_manifest.create_manifest import ( CreateCombinedManifests, @@ -109,6 +111,8 @@ GetWER, InsIfASRInsertion, InverseNormalizeText, + LambdaExpression, + ListToEntries, MakeSentence, NormalizeText, ReadDocxLines, @@ -117,8 +121,6 @@ SubIfASRSubstitution, SubMakeLowercase, SubRegex, - ListToEntries, - LambdaExpression, ) from sdp.processors.modify_manifest.data_to_dropbool import ( DropASRError, @@ -141,16 +143,6 @@ from sdp.processors.modify_manifest.make_letters_uppercase_after_period import ( MakeLettersUppercaseAfterPeriod, ) -from sdp.processors.manage_files.convert_audio import ( - FfmpegConvert, - SoxConvert, -) -from sdp.processors.manage_files.extract import ( - ExtractTar, -) -from sdp.processors.manage_files.remove import ( - RemoveFiles, -) from sdp.processors.nemo.asr_inference import ASRInference from sdp.processors.nemo.estimate_bandwidth import EstimateBandwidth from sdp.processors.nemo.lid_inference import AudioLid diff --git a/sdp/processors/base_processor.py b/sdp/processors/base_processor.py index
a4257e53..3fb84414 100644 --- a/sdp/processors/base_processor.py +++ b/sdp/processors/base_processor.py @@ -58,7 +58,6 @@ class BaseProcessor(ABC): """ def __init__(self, output_manifest_file: str, input_manifest_file: Optional[str] = None, **kwargs): - if output_manifest_file and input_manifest_file and (output_manifest_file == input_manifest_file): # we cannot have the same input and output manifest file specified because we need to be able to # read from the input_manifest_file and write to the output_manifest_file at the same time @@ -82,6 +81,7 @@ def test(self): There are no tests by default. """ + class BaseParallelProcessor(BaseProcessor): """ A processor that performs per-entry processing in parallel (using Dask or multiprocessing). @@ -96,7 +96,7 @@ class BaseParallelProcessor(BaseProcessor): use_dask (bool): If True, use Dask for parallelization; otherwise, use multiprocessing. dask_client: (Optional) An existing Dask client. """ - + def __getstate__(self): state = self.__dict__.copy() # Remove the Dask client from state (it is not picklable) @@ -116,7 +116,7 @@ def __init__( dask_client=None, **kwargs, ): - kwargs.pop("use_dask", None) # + kwargs.pop("use_dask", None) # make sure use_dask is not passed on to BaseProcessor super().__init__(input_manifest_file=input_manifest_file, output_manifest_file=output_manifest_file, **kwargs) if max_workers == -1: max_workers = os.cpu_count() @@ -129,25 +129,21 @@ def __init__( self.test_cases = test_cases or [] self.use_dask = use_dask self.dask_client = dask_client - + def prepare(self): - """Can be used in derived classes to prepare the processing. - - """ + """Can be used in derived classes to prepare the processing.""" pass def process(self): - """A fork in the road to pick dask or classic processing - - """ + """Dispatch processing to Dask or to classic multiprocessing.""" os.environ.setdefault("PATH", os.defpath) self.prepare() - + os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) metrics = [] - - #Ability to work sa legacy and as dask + + # Ability to work as legacy (multiprocessing) and as Dask if self.use_dask: self._process_with_dask(metrics) else: @@ -161,7 +157,8 @@ def _process_with_dask(self, metrics): if self.dask_client is None: self.dask_client = Client() client = self.dask_client - from sdp.logging import logger + from sdp.logging import logger + logger.info(f"Using Dask client with dashboard at: {client.dashboard_link}") # Delegate manifest reading to read_manifest() which returns a Dask bag. @@ -210,7 +207,7 @@ def _process_with_multiprocessing(self, metrics): def _chunk_manifest(self): """Splits the input manifest into chunks of in_memory_chunksize size. - Only used in non-Dask (multiprocessing) mode. + Only used in non-Dask (multiprocessing) mode. """ manifest_chunk = [] # When use_dask is False, read_manifest() returns an iterator. @@ -225,7 +222,7 @@ def _chunk_manifest(self): def read_manifest(self): """ Reads entries from the input manifest. - + Behavior depends on the parallelization mode: - When use_dask is True: If the input_manifest_file exists and is non-empty, returns a Dask bag (reading in 256KB blocks). - Otherwise, logs that no manifest was found and returns an empty Dask bag. - When use_dask is False: If the input_manifest_file does not exist or is empty, logs the condition and returns an empty iterator. Otherwise, opens the file in text mode, strips each line, and yields the parsed JSON from non-empty lines. - + This unified behavior lets the processor run even in manifest-creation mode.
""" - from sdp.logging import logger + from sdp.logging import logger + if self.use_dask: import dask.bag as db - if self.input_manifest_file and os.path.exists(self.input_manifest_file) and os.path.getsize(self.input_manifest_file) > 0: + + if ( + self.input_manifest_file + and os.path.exists(self.input_manifest_file) + and os.path.getsize(self.input_manifest_file) > 0 + ): bag = db.read_text(self.input_manifest_file, blocksize=2**18).map(json.loads) return bag else: - logger.info("No input manifest file provided or file is empty. Returning an empty Dask bag for manifest creation.") + logger.info( + "No input manifest file provided or file is empty. Returning an empty Dask bag for manifest creation." + ) return db.from_sequence([]) else: if not self.input_manifest_file or not os.path.exists(self.input_manifest_file): - logger.info("No input manifest file provided or file does not exist. Continuing with an empty manifest.") + logger.info( + "No input manifest file provided or file does not exist. Continuing with an empty manifest." + ) return iter([]) - else: - #if use_dask = False, we get here - def generator(): #Reading manifest line by line, adding only non emply lines + else: + # if use_dask = False, we get here + def generator(): # Reading manifest line by line, adding only non emply lines with open(self.input_manifest_file, "rt", encoding="utf8") as fin: for line in fin: - if line: - yield json.loads(line) + if line: + yield json.loads(line) + return generator() @abstractmethod @@ -270,38 +278,43 @@ def process_dataset_entry(self, data_entry) -> List[Any]: def finalize(self, metrics: List[Any]): """Outputs metrics about the processed data.""" from sdp.logging import logger + logger.info("Total number of entries after processing: %d", self.number_of_entries) if self.total_duration: logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600) else: - logger.info("Unable to calculate total audio duration (hours). Ensure that the manifest file includes a 'duration' key.") + logger.info( + "Unable to calculate total audio duration (hours). Ensure that the manifest file includes a 'duration' key." 
+ ) elapsed = time.time() - self.start_time logger.info("Processor completed in (seconds): %.2f", elapsed) def test(self): - """Applies processing to each test case and raises an error if the output does not match expected output.""" for test_case in self.test_cases: input_data = test_case["input"].copy() if isinstance(test_case["input"], dict) else test_case["input"] generated_outputs = self.process_dataset_entry(input_data) - expected_outputs = [test_case["output"]] if not isinstance(test_case["output"], list) else test_case["output"] + expected_outputs = ( + [test_case["output"]] if not isinstance(test_case["output"], list) else test_case["output"] + ) for gen_out, exp_out in zip(generated_outputs, expected_outputs): gen_data = gen_out.data if hasattr(gen_out, "data") else gen_out if gen_data != exp_out: raise RuntimeError( - "Runtime test failed.\nTest input: {}\nGenerated output: {}\nExpected output: {}" - .format(test_case["input"], gen_data, exp_out) + "Runtime test failed.\nTest input: {}\nGenerated output: {}\nExpected output: {}".format( + test_case["input"], gen_data, exp_out + ) ) - # ------------------ Legacy Parallel Processor ------------------ # Just for reference class LegacyParallelProcessor(BaseProcessor): """ A legacy parallel processor implementation using multiprocessing and process_map. - + This class processes the manifest in chunks (using process_map) and is provided for compatibility. Child classes must implement process_dataset_entry(). - + Args: max_workers (int): maximum number of workers that will be spawned during the parallel processing. @@ -312,12 +325,13 @@ class LegacyParallelProcessor(BaseProcessor): test_cases (list[dict]): an optional list of dicts containing test cases for checking that the processor makes the changes that we are expecting. - + The dicts must have a key ``input``, the value of which is a dictionary containing data which is our test's input manifest line, and a key ``output``, the value of which is a dictionary containing data which is the expected output manifest line. """ + def __init__( self, max_workers: int = -1, @@ -326,7 +340,7 @@ def __init__( test_cases: Optional[List[Dict]] = None, **kwargs, ): - kwargs.pop("use_dask", None) # + kwargs.pop("use_dask", None) # legacy processor does not support Dask; drop the flag if present super().__init__(**kwargs) if max_workers == -1: max_workers = multiprocessing.cpu_count() @@ -478,9 +492,12 @@ def finalize(self, metrics): if self.total_duration: logger.info("Total audio duration (hours) after processing (legacy): %.2f", self.total_duration / 3600) else: - logger.info("Unable to calculate total audio duration (legacy). Please ensure that the manifest file includes a 'duration' key.") + logger.info( + "Unable to calculate total audio duration (legacy). Please ensure that the manifest file includes a 'duration' key."
+ ) elapsed = time.time() - self.start_time logger.info("Legacy processor completed in (seconds): %.2f", elapsed) + def test(self): """Applies processing to "test_cases" and raises an error in case of mismatch.""" for test_case in self.test_cases: @@ -498,4 +515,4 @@ def test(self): f"Test input: {test_case['input']}\n" f"Generated output: {generated_output}\n" f"Expected output: {expected_output}" - ) \ No newline at end of file + ) diff --git a/sdp/processors/datasets/commoncrawl/commoncrawl.py b/sdp/processors/datasets/commoncrawl/commoncrawl.py index 8a5cc2c6..f9eed63c 100644 --- a/sdp/processors/datasets/commoncrawl/commoncrawl.py +++ b/sdp/processors/datasets/commoncrawl/commoncrawl.py @@ -2,38 +2,38 @@ from typing import List import soundfile as sf + from sdp.processors.base_processor import BaseParallelProcessor, DataEntry from sdp.processors.datasets.commoncrawl.harv_utils import split_by_vtt - class SplitByVttSentence(BaseParallelProcessor): """ - A class for splitting audio files based on VTT (WebVTT) sentence-level segmentation in a dataset. + A class for splitting audio files based on VTT (WebVTT) sentence-level segmentation in a dataset. - Args: - splited_audio_dir (str): The directory to store the split audio files. - source_audio_key (str): The field in the dataset containing the path to the source audio files. - target_audio_key (str): The field to store the paths of the split audio files. - duration_key (str): The field to store the duration of each split audio segment. - text_key (str): The field to store the transcriptions corresponding to each split audio segment. - caption_file_key (str): The field in the dataset containing the path to the VTT (WebVTT) files for segmentation. - additional_fields (List[str], optional): List of additional fields to copy from the original data entry to the split entries. - Defaults to an empty list. - duration_threshold (float, optional): The duration threshold in seconds for each split audio segment. Defaults to 10.0. + Args: + splited_audio_dir (str): The directory to store the split audio files. + source_audio_field (str): The field in the dataset containing the path to the source audio files. + target_audio_field (str): The field to store the paths of the split audio files. + duration_field (str): The field to store the duration of each split audio segment. + text_field (str): The field to store the transcriptions corresponding to each split audio segment. + vtt_field (str): The field in the dataset containing the path to the VTT (WebVTT) files for segmentation. + additional_fields (List[str], optional): List of additional fields to copy from the original data entry to the split entries. + Defaults to an empty list. + duration_threshold (float, optional): The duration threshold in seconds for each split audio segment. Defaults to 10.0.
""" def __init__( - self, - splited_audio_dir: str, - source_audio_field: str, - target_audio_field: str, - duration_field: str, - text_field: str, - vtt_field: str, - additional_fields: List[str] = [], - duration_threshold: float = 10.0, - **kwargs, + self, + splited_audio_dir: str, + source_audio_field: str, + target_audio_field: str, + duration_field: str, + text_field: str, + vtt_field: str, + additional_fields: List[str] = [], + duration_threshold: float = 10.0, + **kwargs, ): super().__init__(**kwargs) self.splited_audio_dir = splited_audio_dir @@ -67,10 +67,13 @@ def process_dataset_entry(self, data_entry): pass end_c = end_sr if len(text_c) > 0 and ( - end_c - start_c > self.duration_threshold * samplerate or - text_c[-1] == "." or text_c[-1] == "?"): + end_c - start_c > self.duration_threshold * samplerate + or text_c[-1] == "." + or text_c[-1] == "?" + ): res_list.append( - self.makeDataEntry(data_entry, data, vtt_file, samplerate, text_c, start_c, end_c)) + self.makeDataEntry(data_entry, data, vtt_file, samplerate, text_c, start_c, end_c) + ) text_c = '' start_c, end_c = 0, 0 else: @@ -82,18 +85,20 @@ def process_dataset_entry(self, data_entry): def makeDataEntry(self, data_entry, data, vtt_file, samplerate, text_c, start_c, end_c): data_sample = data[start_c:end_c] - wav_save_file = os.path.join(self.splited_audio_dir, '/'.join(os.path.splitext(vtt_file)[0].split('/')[-2:]), - str(int(start_c / (samplerate / 1000))) + "-" + str( - int(end_c / (samplerate / 1000))) + ".wav") + wav_save_file = os.path.join( + self.splited_audio_dir, + '/'.join(os.path.splitext(vtt_file)[0].split('/')[-2:]), + str(int(start_c / (samplerate / 1000))) + "-" + str(int(end_c / (samplerate / 1000))) + ".wav", + ) if not os.path.isfile(wav_save_file): os.makedirs(os.path.split(wav_save_file)[0], exist_ok=True) sf.write(wav_save_file, data_sample, samplerate) - data = {self.target_audio_field: wav_save_file, - self.duration_field: data_sample.shape[0] / samplerate, - self.text_field: text_c.strip(), - } + data = { + self.target_audio_field: wav_save_file, + self.duration_field: data_sample.shape[0] / samplerate, + self.text_field: text_c.strip(), + } for field in self.additional_fields: data[field] = data_entry[field] return DataEntry(data=data) - diff --git a/sdp/processors/datasets/commoncrawl/harv_utils.py b/sdp/processors/datasets/commoncrawl/harv_utils.py index 24efe80e..5b1ae48e 100644 --- a/sdp/processors/datasets/commoncrawl/harv_utils.py +++ b/sdp/processors/datasets/commoncrawl/harv_utils.py @@ -1,5 +1,7 @@ -import webvtt # pip install webvtt-py from datetime import datetime + +import webvtt # pip install webvtt-py + from sdp.logging import logger @@ -42,4 +44,3 @@ def split_by_vtt(vtt_file, samplerate): except Exception as e: logger.warning(str(e) + vtt_file) return None, None, None - diff --git a/sdp/processors/datasets/coraa/create_initial_manifest.py b/sdp/processors/datasets/coraa/create_initial_manifest.py index 5be1a8ab..7a348b70 100644 --- a/sdp/processors/datasets/coraa/create_initial_manifest.py +++ b/sdp/processors/datasets/coraa/create_initial_manifest.py @@ -2,49 +2,51 @@ import os from pathlib import Path from typing import List -import pandas as pd -import rarfile #Needs to be installed +import pandas as pd +import rarfile # Needs to be installed import sox from sox import Transformer from sdp.processors.base_processor import BaseParallelProcessor, DataEntry from sdp.utils.common import extract_archive + class CreateInitialManifestCORAA(BaseParallelProcessor): """ - Processor to 
create initial manifest file fo CORAA ASR dataset - - Dataset link: https://github.com/nilc-nlp/CORAA - - Args: - raw_data_dir (str): the path to the directory in which all the data will be downloaded. - extract_archive_dir (str): directory where the extracted data will be saved. - data_split (str): "train", "dev" or "test". - resampled_audio_dir (str): the directory where the resampled wav files will be stored. - already_extracted (bool): if True, we will not try to extract the raw data. - Defaults to False. - already_downloaded (bool): if True, we will not try to download files. - target_samplerate (int): sample rate (Hz) to use for resampling. This parameter will - Defaults to 16000. - target_nchannels (int): number of channels to create during resampling process. - Defaults to 1. - exclude_dataset: list: list of the dataset names that will be excluded when creating initial manifest. - Options 'SP2010', 'C-ORAL-BRASIL I', 'NURC-Recife', 'TEDx Talks', 'ALIP' + Processor to create initial manifest file for the CORAA ASR dataset. + + Dataset link: https://github.com/nilc-nlp/CORAA + + Args: + raw_data_dir (str): the path to the directory in which all the data will be downloaded. + extract_archive_dir (str): directory where the extracted data will be saved. + data_split (str): "train", "dev" or "test". + resampled_audio_dir (str): the directory where the resampled wav files will be stored. + already_extracted (bool): if True, we will not try to extract the raw data. + Defaults to False. + already_downloaded (bool): if True, we will not try to download files. + target_samplerate (int): sample rate (Hz) to use for resampling. + Defaults to 16000. + target_nchannels (int): number of channels to create during resampling process. + Defaults to 1. + exclude_dataset (list): list of dataset names to exclude when creating the initial manifest. + Options: 'SP2010', 'C-ORAL-BRASIL I', 'NURC-Recife', 'TEDx Talks', 'ALIP'. """ + def __init__( - self, - raw_data_dir: str, - extract_archive_dir: str, - data_split: str, - resampled_audio_dir: str, - already_extracted: bool = False, - already_downloaded: bool = False, - target_samplerate: int = 16000, - target_nchannels: int = 1, - exclude_dataset: list = [], - **kwargs, + self, + raw_data_dir: str, + extract_archive_dir: str, + data_split: str, + resampled_audio_dir: str, + already_extracted: bool = False, + already_downloaded: bool = False, + target_samplerate: int = 16000, + target_nchannels: int = 1, + exclude_dataset: list = [], + **kwargs, ): super().__init__(**kwargs) self.raw_data_dir = Path(raw_data_dir) @@ -65,13 +67,15 @@ def prepare(self): if not self.already_downloaded: try: from huggingface_hub import snapshot_download + snapshot_download(repo_id="gabrielrstan/CORAA-v1.1", repo_type='dataset', local_dir=self.raw_data_dir) except ImportError: - raise ImportError("huggingface_hub is required to download the dataset. Please install it with pip install huggingface_hub") + raise ImportError( + "huggingface_hub is required to download the dataset. 
Please install it with pip install huggingface_hub" + ) if not self.already_extracted: - if self.data_split == 'train': - first_rar_file = glob.glob(str(self.raw_data_dir) + "/train_dividido"+f"/*{self.data_split}*1.rar") + first_rar_file = glob.glob(str(self.raw_data_dir) + "/train_dividido" + f"/*{self.data_split}*1.rar") if first_rar_file and not isinstance(first_rar_file, str): first_rar_file = first_rar_file[0] @@ -79,8 +83,7 @@ def prepare(self): rar = rarfile.RarFile(first_rar_file) rar.extractall(path=self.extract_archive_dir) else: - - zip_files =glob.glob(str(self.raw_data_dir) + f"/*{self.data_split}.zip") + zip_files = glob.glob(str(self.raw_data_dir) + f"/*{self.data_split}.zip") if not zip_files: raise RuntimeError( f"Did not find any file matching {self.raw_data_dir}/*.zip. " @@ -97,12 +100,11 @@ def prepare(self): def read_manifest(self): self.df = pd.read_csv(self.transcription_file) - data_entries = self.df[~self.df['dataset'].isin(self.exclude_dataset)][['file_path','text']] + data_entries = self.df[~self.df['dataset'].isin(self.exclude_dataset)][['file_path', 'text']] res = [tuple(row[1]) for row in data_entries.iterrows()] return res def process_dataset_entry(self, data_entry) -> List[DataEntry]: - file_path, text = data_entry file_name = os.path.splitext(os.path.basename(file_path))[0] transcript_text = text.strip() diff --git a/sdp/processors/datasets/earnings/__init__.py b/sdp/processors/datasets/earnings/__init__.py index d71e41d8..10cb9c07 100644 --- a/sdp/processors/datasets/earnings/__init__.py +++ b/sdp/processors/datasets/earnings/__init__.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from sdp.processors.datasets.earnings.apply_normalizations import ( + ApplyEarnings21Normalizations, +) from sdp.processors.datasets.earnings.create_initial_manifest import ( - CreateInitialAudioAndManifest, CreateFullAudioManifestEarnings21, - SpeakerSegmentedManifest, + CreateInitialAudioAndManifest, CreateSentenceSegmentedManifest, NeMoForcedAligner, + SpeakerSegmentedManifest, ) -from sdp.processors.datasets.earnings.apply_normalizations import ( - ApplyEarnings21Normalizations, -) \ No newline at end of file diff --git a/sdp/processors/datasets/earnings/apply_normalizations.py b/sdp/processors/datasets/earnings/apply_normalizations.py index b81f245c..6959a89d 100644 --- a/sdp/processors/datasets/earnings/apply_normalizations.py +++ b/sdp/processors/datasets/earnings/apply_normalizations.py @@ -14,7 +14,7 @@ import json from pathlib import Path -from typing import Dict, List, Any +from typing import Any, Dict, List from sdp.processors.base_processor import BaseProcessor, DataEntry @@ -44,7 +44,7 @@ class ApplyEarnings21Normalizations(BaseProcessor): fallback_to_original: true preserve_entity_tags: true """ - + def __init__( self, earnings21_root: str, @@ -58,46 +58,46 @@ def __init__( self.use_top_candidate = use_top_candidate self.fallback_to_original = fallback_to_original self.preserve_entity_tags = preserve_entity_tags - + def process_dataset_entry(self, data_entry: DataEntry) -> List[DataEntry]: """Process a single dataset entry to apply normalizations.""" data = data_entry.data - + # Extract file_id to load corresponding normalization file file_id = data.get('file_id') if not file_id: # If no file_id, return original entry return [data_entry] - + # Load normalization data for this file norm_file = self.earnings21_root / "transcripts" / "normalizations" / f"{file_id}.norm.json" - + if not 
norm_file.exists(): # If no normalization file, return original entry return [data_entry] - + try: with open(norm_file, 'r', encoding='utf-8') as f: normalizations = json.load(f) except (json.JSONDecodeError, FileNotFoundError): # If can't load normalization file, return original entry return [data_entry] - + # Apply normalizations to text normalized_text = self._apply_normalizations(data.get('text', ''), normalizations) - + # Create new data entry with normalized text new_data = data.copy() new_data['text'] = normalized_text - + return [DataEntry(data=new_data)] - + def _apply_normalizations(self, text: str, normalizations: Dict[str, Any]) -> str: """Apply normalizations to text based on normalization data.""" # This is a simplified implementation # In practice, you would need to map tokens to normalization IDs # and apply the appropriate normalizations - + # For now, just return the original text # This can be extended to implement actual normalization logic - return text \ No newline at end of file + return text diff --git a/sdp/processors/datasets/earnings/create_initial_manifest.py b/sdp/processors/datasets/earnings/create_initial_manifest.py index fa10ca45..dae9e320 100644 --- a/sdp/processors/datasets/earnings/create_initial_manifest.py +++ b/sdp/processors/datasets/earnings/create_initial_manifest.py @@ -17,13 +17,17 @@ import os import re from pathlib import Path -from typing import List, Dict, Any, Optional +from typing import Any, Dict, List, Optional -import pandas as pd import librosa +import pandas as pd from sdp.logging import logger -from sdp.processors.base_processor import BaseParallelProcessor, BaseProcessor, DataEntry +from sdp.processors.base_processor import ( + BaseParallelProcessor, + BaseProcessor, + DataEntry, +) from sdp.utils.common import extract_archive @@ -32,7 +36,7 @@ class CreateInitialAudioAndManifest(BaseParallelProcessor): """Create initial audio manifest from Earnings21/22 dataset files. This processor creates the initial manifest for Earnings21/22 datasets by discovering - audio files and creating manifest entries with duration information. Audio format + audio files and creating manifest entries with duration information. Audio format conversion should be handled by a separate FfmpegConvert processor in the pipeline. Args: @@ -94,7 +98,7 @@ def prepare(self): metadata_file = self.dataset_root / "earnings21-file-metadata.csv" else: # earnings22 metadata_file = self.dataset_root / "metadata.csv" - + # If metadata file doesn't exist, discover files from audio directory if not metadata_file.exists(): logger.warning(f"Metadata file not found: {metadata_file}. 
Discovering files from audio directory.") @@ -111,13 +115,11 @@ def prepare(self): self.file_ids = file_metadata_df['File ID'].astype(str).tolist() else: raise ValueError(f"Neither 'file_id' nor 'File ID' column found in {metadata_file}") - + if self.test_mode: self.file_ids = self.file_ids[:2] - - logger.info(f"Loaded {len(self.file_ids)} file IDs for {self.dataset_type} subset {self.subset}.") - + logger.info(f"Loaded {len(self.file_ids)} file IDs for {self.dataset_type} subset {self.subset}.") def read_manifest(self): """Read and process all files to create manifest entries.""" @@ -126,7 +128,7 @@ def read_manifest(self): def process_dataset_entry(self, file_id: str) -> List[DataEntry]: """Process a single file to create full audio manifest entry.""" file_id = str(file_id) - + # Find audio file audio_file = None for ext in ['.mp3', '.wav']: @@ -134,7 +136,7 @@ def process_dataset_entry(self, file_id: str) -> List[DataEntry]: if potential_path.exists(): audio_file = potential_path break - + if not audio_file: logger.warning(f"Audio file not found for {file_id}") return [] @@ -142,7 +144,7 @@ def process_dataset_entry(self, file_id: str) -> List[DataEntry]: try: # Get audio duration from the original audio file duration = librosa.get_duration(path=str(audio_file)) - + # Create manifest entry entry_data = { "audio_filepath": str(audio_file), @@ -150,9 +152,9 @@ def process_dataset_entry(self, file_id: str) -> List[DataEntry]: "text": "", # Placeholder text "file_id": file_id, } - + return [DataEntry(data=entry_data)] - + except Exception as e: logger.error(f"Error processing audio file {file_id}: {e}") return [] @@ -215,8 +217,10 @@ def _get_nlp_file_path(self, file_id: str) -> Path: else: # earnings22 # Check both possible locations for earnings22 nlp_path1 = self.dataset_root / "transcripts" / "nlp_references" / f"{file_id}.nlp" - nlp_path2 = self.dataset_root / "subset10" / "nonverbatim_transcripts" / "nlp_references" / f"{file_id}.nlp" - + nlp_path2 = ( + self.dataset_root / "subset10" / "nonverbatim_transcripts" / "nlp_references" / f"{file_id}.nlp" + ) + if nlp_path1.exists(): return nlp_path1 elif nlp_path2.exists(): @@ -227,11 +231,11 @@ def _get_nlp_file_path(self, file_id: str) -> Path: def _load_nlp_file(self, file_id: str) -> List[Dict[str, Any]]: """Load NLP file containing tokens and metadata.""" nlp_file = self._get_nlp_file_path(file_id) - + if not nlp_file.exists(): logger.warning(f"NLP file not found: {nlp_file}") return [] - + tokens_list = [] try: with open(nlp_file, 'r', encoding='utf-8') as f: @@ -245,7 +249,7 @@ def _load_nlp_file(self, file_id: str) -> List[Dict[str, Any]]: for i, row_values in enumerate(reader): if len(row_values) == len(header): token_data = dict(zip(header, row_values)) - + # Parse 'tags' and 'wer_tags' fields if they are string representations of lists for key_to_parse in ['tags', 'wer_tags']: if key_to_parse in token_data: @@ -255,12 +259,14 @@ def _load_nlp_file(self, file_id: str) -> List[Dict[str, Any]]: token_data[key_to_parse] = json.loads(field_value) except json.JSONDecodeError: if field_value and field_value != "[]": - logger.debug(f"Field '{key_to_parse}' in {nlp_file} non-JSON: {field_value}") + logger.debug( + f"Field '{key_to_parse}' in {nlp_file} non-JSON: {field_value}" + ) tokens_list.append(token_data) else: logger.warning(f"Skipping malformed row in {nlp_file} (row {i+2})") return tokens_list - + except Exception as e: logger.error(f"Error processing NLP file {nlp_file}: {e}") return [] @@ -269,13 +275,13 @@ def 
_reconstruct_text(self, tokens: List[Dict[str, Any]]) -> str: """Reconstruct text from tokens with proper spacing and punctuation.""" if not tokens: return "" - + text_parts = [] for token in tokens: token_text = token.get('token', '').strip() if not token_text: continue - + text_parts.append(token_text) # Add punctuation if preserving and it exists if self.preserve_punctuation and token.get('punctuation'): @@ -289,7 +295,7 @@ def _reconstruct_text(self, tokens: List[Dict[str, Any]]) -> str: if not self.preserve_capitalization: text = text.lower() - + # Final cleanup of multiple spaces text = re.sub(r'\s+', ' ', text).strip() return text @@ -298,13 +304,13 @@ def process_dataset_entry(self, data_entry: Dict[str, Any]) -> List[DataEntry]: """Process a single manifest entry to add full text.""" file_id = data_entry['file_id'] tokens = self._load_nlp_file(file_id) - + if not tokens: logger.warning(f"No NLP tokens for {file_id}, text will be empty.") data_entry['text'] = data_entry.get('text', '') else: data_entry['text'] = self._reconstruct_text(tokens) - + return [DataEntry(data=data_entry)] @@ -328,7 +334,7 @@ class SpeakerSegmentedManifest(BaseParallelProcessor): use_speaker_metadata_csv (bool): Whether to use speaker metadata CSV for name mapping. Defaults to False. Returns: - Manifest entries segmented by speaker with audio_filepath, duration (set to 0), + Manifest entries segmented by speaker with audio_filepath, duration (set to 0), text, file_id, segment_id, and optionally speaker and tags fields. Example: @@ -380,7 +386,7 @@ def _load_speaker_metadata(self): if not metadata_file.exists(): logger.warning(f"Speaker metadata file not found: {metadata_file}") return - + try: df = pd.read_csv(metadata_file) for _, row in df.iterrows(): @@ -400,8 +406,10 @@ def _get_nlp_file_path(self, file_id: str) -> Path: else: # earnings22 # Check both possible locations for earnings22 nlp_path1 = self.dataset_root / "transcripts" / "nlp_references" / f"{file_id}.nlp" - nlp_path2 = self.dataset_root / "subset10" / "nonverbatim_transcripts" / "nlp_references" / f"{file_id}.nlp" - + nlp_path2 = ( + self.dataset_root / "subset10" / "nonverbatim_transcripts" / "nlp_references" / f"{file_id}.nlp" + ) + if nlp_path1.exists(): return nlp_path1 elif nlp_path2.exists(): @@ -412,11 +420,11 @@ def _get_nlp_file_path(self, file_id: str) -> Path: def _load_nlp_file(self, file_id: str) -> List[Dict[str, Any]]: """Load NLP file containing tokens and metadata.""" nlp_file = self._get_nlp_file_path(file_id) - + if not nlp_file.exists(): logger.warning(f"NLP file not found: {nlp_file}") return [] - + tokens_list = [] try: with open(nlp_file, 'r', encoding='utf-8') as f: @@ -430,7 +438,7 @@ def _load_nlp_file(self, file_id: str) -> List[Dict[str, Any]]: for i, row_values in enumerate(reader): if len(row_values) == len(header): token_data = dict(zip(header, row_values)) - + # Parse 'tags' and 'wer_tags' fields if they are string representations of lists for key_to_parse in ['tags', 'wer_tags']: if key_to_parse in token_data: @@ -440,12 +448,14 @@ def _load_nlp_file(self, file_id: str) -> List[Dict[str, Any]]: token_data[key_to_parse] = json.loads(field_value) except json.JSONDecodeError: if field_value and field_value != "[]": - logger.debug(f"Field '{key_to_parse}' in {nlp_file} non-JSON: {field_value}") + logger.debug( + f"Field '{key_to_parse}' in {nlp_file} non-JSON: {field_value}" + ) tokens_list.append(token_data) else: logger.warning(f"Skipping malformed row in {nlp_file} (row {i+2})") return tokens_list - + except 
Exception as e: logger.error(f"Error processing NLP file {nlp_file}: {e}") return [] @@ -454,11 +464,11 @@ def _load_entity_tags(self, file_id: str) -> Dict[str, Dict[str, str]]: """Load entity tags file (earnings21 only).""" if self.dataset_type != "earnings21": return {} - + tags_file = self.dataset_root / "transcripts" / "tags" / f"{file_id}.tags.json" if not tags_file.exists(): return {} - + try: with open(tags_file, 'r', encoding='utf-8') as f: return json.load(f) @@ -470,13 +480,13 @@ def _reconstruct_text(self, tokens: List[Dict[str, Any]]) -> str: """Reconstruct text from tokens with proper spacing and punctuation.""" if not tokens: return "" - + text_parts = [] for token in tokens: token_text = token.get('token', '').strip() if not token_text: continue - + text_parts.append(token_text) # Add punctuation if preserving and it exists if self.preserve_punctuation and token.get('punctuation'): @@ -490,7 +500,7 @@ def _reconstruct_text(self, tokens: List[Dict[str, Any]]) -> str: if not self.preserve_capitalization: text = text.lower() - + # Final cleanup of multiple spaces text = re.sub(r'\s+', ' ', text).strip() return text @@ -499,43 +509,47 @@ def _create_segments(self, tokens: List[Dict[str, Any]], file_id: str) -> List[D """Create segments based on speaker changes.""" if not tokens: return [] - + segments = [] current_segment_tokens = [] current_speaker_id = tokens[0].get('speaker', 'unknown_speaker_0') if tokens else 'unknown_speaker_0' for token in tokens: token_speaker_id = token.get('speaker', current_speaker_id) - + # Check for speaker change if token_speaker_id != current_speaker_id and current_segment_tokens: # Finalize current segment segment_text = self._reconstruct_text(current_segment_tokens) if segment_text.strip(): - segments.append({ - 'tokens': current_segment_tokens, - 'text': segment_text, - 'speaker_id': current_speaker_id, - 'file_id': file_id, - }) - + segments.append( + { + 'tokens': current_segment_tokens, + 'text': segment_text, + 'speaker_id': current_speaker_id, + 'file_id': file_id, + } + ) + # Start new segment current_segment_tokens = [token] current_speaker_id = token_speaker_id else: current_segment_tokens.append(token) - + # Handle last segment if current_segment_tokens: segment_text = self._reconstruct_text(current_segment_tokens) if segment_text.strip(): - segments.append({ - 'tokens': current_segment_tokens, - 'text': segment_text, - 'speaker_id': current_speaker_id, - 'file_id': file_id, - }) - + segments.append( + { + 'tokens': current_segment_tokens, + 'text': segment_text, + 'speaker_id': current_speaker_id, + 'file_id': file_id, + } + ) + return segments def process_dataset_entry(self, full_audio_manifest_entry: Dict[str, Any]) -> List[DataEntry]: @@ -563,7 +577,7 @@ def process_dataset_entry(self, full_audio_manifest_entry: Dict[str, Any]) -> Li for idx, segment_dict in enumerate(segments): segment_text = segment_dict['text'] speaker_id = segment_dict['speaker_id'] - + # Create manifest entry manifest_entry_data = { "audio_filepath": audio_filepath, # Point to original audio file @@ -572,23 +586,25 @@ def process_dataset_entry(self, full_audio_manifest_entry: Dict[str, Any]) -> Li "file_id": file_id, "segment_id": idx, "start_time": None, # No timing information - "end_time": None, # No timing information + "end_time": None, # No timing information } # Add speaker information if self.include_speaker_info: speaker_name = speaker_id # Default to ID - if (self.use_speaker_metadata_csv and - file_id in self.speaker_name_map and - speaker_id in 
self.speaker_name_map[file_id]): + if ( + self.use_speaker_metadata_csv + and file_id in self.speaker_name_map + and speaker_id in self.speaker_name_map[file_id] + ): speaker_name = self.speaker_name_map[file_id][speaker_id] manifest_entry_data["speaker"] = speaker_name - + # Add tags if requested if self.include_tags: segment_tags = [] segment_entities = [] - + # Extract basic tags from tokens for token in segment_dict.get('tokens', []): if token.get('tags') and str(token['tags']).strip(): @@ -596,12 +612,12 @@ def process_dataset_entry(self, full_audio_manifest_entry: Dict[str, Any]) -> Li tag_type = tag_val.split(':', 1)[1].strip() if ':' in tag_val else tag_val if tag_type and tag_type not in segment_tags: segment_tags.append(tag_type) - + manifest_entry_data["tags"] = segment_tags manifest_entry_data["entities"] = segment_entities output_entries.append(DataEntry(data=manifest_entry_data)) - + logger.info(f"Successfully processed {len(output_entries)} segments for file {file_id}") return output_entries @@ -660,23 +676,25 @@ def _parse_ctm_file(self, ctm_path: str) -> List[Dict[str, Any]]: duration = float(parts[3]) word = parts[4] end_time = start_time + duration - - alignments.append({ - 'word': word, - 'start': round(start_time, 3), - 'end': round(end_time, 3), - 'utt_id': utt_id, - 'channel': channel - }) + + alignments.append( + { + 'word': word, + 'start': round(start_time, 3), + 'end': round(end_time, 3), + 'utt_id': utt_id, + 'channel': channel, + } + ) except Exception as e: logger.error(f"Error parsing CTM file {ctm_path}: {e}") - + return alignments def _is_sentence_end(self, word: str, next_word: str = None) -> bool: """ Check if a word marks the end of a sentence. - + Rules: - Word ends with !, ?, or . - Exclude numbers like 42.12 (. within numbers) @@ -685,15 +703,15 @@ def _is_sentence_end(self, word: str, next_word: str = None) -> bool: """ if not word: return False - + # Check if word ends with sentence-ending punctuation if not word.endswith(('.', '!', '?')): return False - + # Handle exclamation and question marks - these are always sentence endings if word.endswith(('!', '?')): return True - + # For words ending with '.', do additional checks if word.endswith('.'): # Remove the final '.' 
and check if what remains is a number @@ -705,25 +723,100 @@ def _is_sentence_end(self, word: str, next_word: str = None) -> bool: except ValueError: # Not a number, continue with other checks pass - + # Check for common abbreviations (case-insensitive) common_abbreviations = { - 'mr', 'mrs', 'ms', 'dr', 'prof', 'sr', 'jr', 'vs', 'etc', 'inc', 'corp', 'ltd', 'co', - 'st', 'ave', 'blvd', 'rd', 'ln', 'ct', 'pl', 'sq', 'ft', 'in', 'cm', 'mm', 'kg', 'lb', - 'oz', 'pt', 'qt', 'gal', 'mph', 'rpm', 'vol', 'no', 'pg', 'pp', 'ch', 'sec', 'min', - 'hr', 'hrs', 'am', 'pm', 'est', 'pst', 'cst', 'mst', 'utc', 'gmt', 'jan', 'feb', 'mar', - 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct', 'nov', 'dec', 'mon', 'tue', 'wed', - 'thu', 'fri', 'sat', 'sun', 'dept', 'div', 'mgr', 'dir', 'pres', 'vp', 'ceo', 'cfo', - 'cto', 'coo', 'evp', 'svp', 'avp' + 'mr', + 'mrs', + 'ms', + 'dr', + 'prof', + 'sr', + 'jr', + 'vs', + 'etc', + 'inc', + 'corp', + 'ltd', + 'co', + 'st', + 'ave', + 'blvd', + 'rd', + 'ln', + 'ct', + 'pl', + 'sq', + 'ft', + 'in', + 'cm', + 'mm', + 'kg', + 'lb', + 'oz', + 'pt', + 'qt', + 'gal', + 'mph', + 'rpm', + 'vol', + 'no', + 'pg', + 'pp', + 'ch', + 'sec', + 'min', + 'hr', + 'hrs', + 'am', + 'pm', + 'est', + 'pst', + 'cst', + 'mst', + 'utc', + 'gmt', + 'jan', + 'feb', + 'mar', + 'apr', + 'may', + 'jun', + 'jul', + 'aug', + 'sep', + 'oct', + 'nov', + 'dec', + 'mon', + 'tue', + 'wed', + 'thu', + 'fri', + 'sat', + 'sun', + 'dept', + 'div', + 'mgr', + 'dir', + 'pres', + 'vp', + 'ceo', + 'cfo', + 'cto', + 'coo', + 'evp', + 'svp', + 'avp', } - + if word_without_dot.lower() in common_abbreviations: return False - + # If we have a next word, check if it starts with capital letter if next_word: return next_word[0].isupper() - + # If no next word, assume it's sentence end return True @@ -731,13 +824,13 @@ def _create_sentence_segments(self, alignments: List[Dict[str, Any]]) -> List[Di """Create sentence-level segments from word alignments.""" if not alignments: return [] - + segments = [] current_segment_words = [] - + for i, alignment in enumerate(alignments): current_segment_words.append(alignment) - + # Check if this word ends a sentence next_word = alignments[i + 1]['word'] if i + 1 < len(alignments) else None if self._is_sentence_end(alignment['word'], next_word): @@ -747,55 +840,59 @@ def _create_sentence_segments(self, alignments: List[Dict[str, Any]]) -> List[Di segment_start = current_segment_words[0]['start'] segment_end = current_segment_words[-1]['end'] segment_duration = round(segment_end - segment_start, 3) - - segments.append({ - 'text': segment_text, - 'start_time': segment_start, - 'end_time': segment_end, - 'duration': segment_duration, - 'alignment': current_segment_words.copy() - }) - + + segments.append( + { + 'text': segment_text, + 'start_time': segment_start, + 'end_time': segment_end, + 'duration': segment_duration, + 'alignment': current_segment_words.copy(), + } + ) + current_segment_words = [] - + # Handle any remaining words if current_segment_words: segment_text = ' '.join([w['word'] for w in current_segment_words]) segment_start = current_segment_words[0]['start'] segment_end = current_segment_words[-1]['end'] segment_duration = round(segment_end - segment_start, 3) - - segments.append({ - 'text': segment_text, - 'start_time': segment_start, - 'end_time': segment_end, - 'duration': segment_duration, - 'alignment': current_segment_words.copy() - }) - + + segments.append( + { + 'text': segment_text, + 'start_time': segment_start, + 'end_time': segment_end, + 'duration': 
segment_duration, + 'alignment': current_segment_words.copy(), + } + ) + return segments def process_dataset_entry(self, aligned_manifest_entry: Dict[str, Any]) -> List[DataEntry]: """Process a single aligned manifest entry to create sentence-level segments.""" file_id = aligned_manifest_entry['file_id'] audio_filepath = aligned_manifest_entry['audio_filepath'] - + # Find corresponding CTM file ctm_file = self.ctm_dir / f"{file_id}.ctm" if not ctm_file.exists(): logger.warning(f"CTM file not found: {ctm_file}") return [] - + # Parse CTM file alignments = self._parse_ctm_file(str(ctm_file)) if not alignments: logger.warning(f"No alignments found in CTM file: {ctm_file}") return [] - + # Create sentence segments segments = self._create_sentence_segments(alignments) logger.info(f"Created {len(segments)} sentence segments for file {file_id}") - + # Create manifest entries output_entries = [] for idx, segment in enumerate(segments): @@ -807,11 +904,11 @@ def process_dataset_entry(self, aligned_manifest_entry: Dict[str, Any]) -> List[ "segment_id": idx, "offset": segment['start_time'], # Use offset instead of start_time "end_time": segment['end_time'], - "alignment": segment['alignment'] + "alignment": segment['alignment'], } - + output_entries.append(DataEntry(data=manifest_entry_data)) - + logger.info(f"Successfully processed {len(output_entries)} sentence segments for file {file_id}") return output_entries @@ -863,15 +960,15 @@ def __init__( self.pretrained_name = pretrained_name self.device = device self.nemo_path = nemo_path - + # Create output directory self.output_dir.mkdir(parents=True, exist_ok=True) def process(self): """Process the manifest using NeMo Forced Aligner script.""" - import subprocess import json - + import subprocess + try: # Find NeMo forced aligner script if self.nemo_path: @@ -880,23 +977,24 @@ def process(self): # Try to find NeMo installation try: import nemo + nemo_dir = Path(nemo.__file__).parent.parent align_script = nemo_dir / "tools" / "nemo_forced_aligner" / "align.py" except ImportError: raise ImportError("NeMo not found. 
Please install NeMo or specify nemo_path.") - + if not align_script.exists(): raise FileNotFoundError(f"NeMo Forced Aligner script not found at {align_script}") - + logger.info(f"Using NeMo Forced Aligner script at: {align_script}") - + # Prepare manifest for forced alignment input_manifest = [] with open(self.input_manifest_file, 'r') as f: for line in f: if line.strip(): input_manifest.append(json.loads(line)) - + # Create temporary manifest with absolute paths temp_manifest_path = self.output_dir / "temp_manifest_for_alignment.json" with open(temp_manifest_path, 'w') as f: @@ -906,13 +1004,10 @@ def process(self): audio_path = Path(entry['audio_filepath']) if not audio_path.is_absolute(): audio_path = audio_path.resolve() - - alignment_entry = { - "audio_filepath": str(audio_path), - "text": entry['text'].strip() - } + + alignment_entry = {"audio_filepath": str(audio_path), "text": entry['text'].strip()} f.write(json.dumps(alignment_entry) + '\n') - + # Run NeMo Forced Aligner # Determine if we should use pretrained_name or model_path if self.pretrained_name.endswith('.nemo'): @@ -921,25 +1016,26 @@ def process(self): else: # Pretrained model name - use pretrained_name model_param = f"pretrained_name={self.pretrained_name}" - + cmd = [ - "python", str(align_script), + "python", + str(align_script), model_param, f"manifest_filepath={temp_manifest_path}", f"output_dir={self.output_dir}", f"transcribe_device={self.device}", f"viterbi_device={self.device}", "batch_size=1", - 'save_output_file_formats=["ctm"]' + 'save_output_file_formats=["ctm"]', ] - + logger.info(f"Running NeMo Forced Aligner: {' '.join(cmd)}") result = subprocess.run(cmd, capture_output=True, text=True, check=True) logger.info("NeMo Forced Aligner completed successfully") - + # Process the output and merge with original manifest output_manifest_path = self.output_dir / f"{temp_manifest_path.stem}_with_output_file_paths.json" - + if output_manifest_path.exists(): # Load alignment results alignment_results = [] @@ -947,30 +1043,30 @@ def process(self): for line in f: if line.strip(): alignment_results.append(json.loads(line)) - + # Create mapping from audio filepath to alignment results alignment_map = {} for result in alignment_results: audio_path = result['audio_filepath'] alignment_map[audio_path] = result - + # Merge alignments with original manifest output_entries = [] for entry in input_manifest: output_entry = entry.copy() - + if entry.get('text', '').strip(): # Find corresponding alignment audio_path = str(Path(entry['audio_filepath']).resolve()) if audio_path in alignment_map: alignment_result = alignment_map[audio_path] - + # Load word-level CTM file if available if 'word_level_ctm_filepath' in alignment_result: ctm_path = alignment_result['word_level_ctm_filepath'] word_alignments = self._parse_ctm_file(ctm_path) output_entry['alignment'] = word_alignments - + # Calculate duration from alignments if word_alignments: output_entry['duration'] = round( @@ -987,23 +1083,23 @@ def process(self): else: output_entry['alignment'] = [] output_entry['duration'] = 0.0 - + output_entries.append(output_entry) - + # Save final manifest with open(self.output_manifest_file, 'w') as f: for entry in output_entries: f.write(json.dumps(entry) + '\n') - + logger.info(f"Saved aligned manifest to {self.output_manifest_file}") - + # Clean up temporary files temp_manifest_path.unlink(missing_ok=True) - + else: logger.error(f"Expected output file not found: {output_manifest_path}") raise FileNotFoundError(f"NeMo Forced Aligner did not 
produce expected output") - + except subprocess.CalledProcessError as e: logger.error(f"NeMo Forced Aligner failed: {e}") logger.error(f"stdout: {e.stdout}") @@ -1026,13 +1122,9 @@ def _parse_ctm_file(self, ctm_path: str) -> List[Dict[str, Any]]: duration = float(parts[3]) word = parts[4] end_time = start_time + duration - - alignments.append({ - 'word': word, - 'start': round(start_time, 3), - 'end': round(end_time, 3) - }) + + alignments.append({'word': word, 'start': round(start_time, 3), 'end': round(end_time, 3)}) except Exception as e: logger.error(f"Error parsing CTM file {ctm_path}: {e}") - - return alignments \ No newline at end of file + + return alignments diff --git a/sdp/processors/datasets/hifitts2/download_dataset.py b/sdp/processors/datasets/hifitts2/download_dataset.py index 493fdf97..f09008ed 100644 --- a/sdp/processors/datasets/hifitts2/download_dataset.py +++ b/sdp/processors/datasets/hifitts2/download_dataset.py @@ -14,12 +14,13 @@ import json -import librosa -from pathlib import Path -import soundfile as sf import time import urllib.error import urllib.request +from pathlib import Path + +import librosa +import soundfile as sf from sdp.logging import logger from sdp.processors.base_processor import BaseParallelProcessor, DataEntry @@ -138,8 +139,10 @@ def process_dataset_entry(self, data_entry): original_duration = data_entry["duration"] duration_diff = abs(chapter_duration - original_duration) if duration_diff > 0.1: - error_msg = f"Duration mismatch for {url}: original duration={original_duration}; " \ - f"downloaded duration={round(chapter_duration, 2)}" + error_msg = ( + f"Duration mismatch for {url}: original duration={original_duration}; " + f"downloaded duration={round(chapter_duration, 2)}" + ) logger.warning(error_msg) if self.exit_on_error: diff --git a/sdp/processors/datasets/hifitts2/remove_failed_chapters.py b/sdp/processors/datasets/hifitts2/remove_failed_chapters.py index b4cd5a8b..18aead86 100644 --- a/sdp/processors/datasets/hifitts2/remove_failed_chapters.py +++ b/sdp/processors/datasets/hifitts2/remove_failed_chapters.py @@ -15,6 +15,7 @@ import json from pathlib import Path + from tqdm import tqdm from sdp.processors.base_processor import BaseProcessor diff --git a/sdp/processors/datasets/lhotse.py b/sdp/processors/datasets/lhotse.py index 01f54d44..338711c7 100644 --- a/sdp/processors/datasets/lhotse.py +++ b/sdp/processors/datasets/lhotse.py @@ -63,12 +63,8 @@ def process(self): def check_entry(self, cut) -> None: from lhotse import MonoCut - assert isinstance( - cut, MonoCut - ), f"Currently, only MonoCut import is supported. Received: {cut}" - assert ( - cut.has_recording - ), f"Currently, we only support cuts with recordings. Received: {cut}" + assert isinstance(cut, MonoCut), f"Currently, only MonoCut import is supported. Received: {cut}" + assert cut.has_recording, f"Currently, we only support cuts with recordings. Received: {cut}" assert ( cut.recording.num_channels == 1 ), f"Currently, we only supports recordings with a single channel. 
Received: {cut}" diff --git a/sdp/processors/datasets/librispeech/create_initial_manifest.py b/sdp/processors/datasets/librispeech/create_initial_manifest.py index 83d42bde..42dd6662 100644 --- a/sdp/processors/datasets/librispeech/create_initial_manifest.py +++ b/sdp/processors/datasets/librispeech/create_initial_manifest.py @@ -94,7 +94,7 @@ def process_transcript(self, file_path: str) -> list[dict[str, typing.Any]]: entries = [] root = os.path.dirname(file_path) - print(f"Processing transcript file: {file_path}") + print(f"Processing transcript file: {file_path}") with open(file_path, encoding="utf-8") as fin: for line in fin: id, text = line[: line.index(" ")], line[line.index(" ") + 1 :] diff --git a/sdp/processors/datasets/masc/__init__.py b/sdp/processors/datasets/masc/__init__.py index 82fd7b35..c719f0c3 100644 --- a/sdp/processors/datasets/masc/__init__.py +++ b/sdp/processors/datasets/masc/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .create_initial_manifest import CreateInitialManifestMASC from .aggregate_segments import AggregateSegments from .apply_reg_exp_on_vtt_entries import RegExpVttEntries +from .create_initial_manifest import CreateInitialManifestMASC from .get_caption_file_segments import GetCaptionFileSegments diff --git a/sdp/processors/datasets/masc/aggregate_segments.py b/sdp/processors/datasets/masc/aggregate_segments.py index 8db51046..d5cac60e 100644 --- a/sdp/processors/datasets/masc/aggregate_segments.py +++ b/sdp/processors/datasets/masc/aggregate_segments.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import logging +import os + from pydub import AudioSegment + from sdp.processors.base_processor import BaseParallelProcessor, DataEntry from sdp.processors.datasets.masc.utils import save_audio_segment + class AggregateSegments(BaseParallelProcessor): """ Aggregates short segments into segments with duration not longer than `max_duration`. The algorithm works by iterating from left to right, merging consecutive segments into the current segment until the total duration reaches `max_duration`. - + Args: output_audio_dir (str): Directory where aggregated audio segments will be stored, if `save_aggregated_audio_segments` is True. If `save_aggregated_audio_segments` is False, this path is used to create the audio file paths in the manifest. input_segments_key (str): The field name that contains the list of segments in the input manifest. Defaults to "segments". @@ -37,6 +40,7 @@ class AggregateSegments(BaseParallelProcessor): Defaults to True. verbose (bool): Set to True to enable more detailed logging. Defaults to False. 
""" + def __init__( self, output_audio_dir: str, @@ -74,17 +78,17 @@ def process_dataset_entry(self, data_entry: dict): segments = data_entry[self.input_segments_key] if len(segments) == 0: return [] - + audio = AudioSegment.from_wav(data_entry[self.input_audio_filepath_key]) - + audio_basename = os.path.basename(data_entry[self.input_audio_filepath_key]).split(".")[0] agg_segments = [] aggregated_segment = {**segments[0]} for segment in segments[1:]: # checking if adding another segment will cause the total duration to exceed max_duration - if (segment["end_time"] > audio.duration_seconds or segment["start_time"] > audio.duration_seconds): + if segment["end_time"] > audio.duration_seconds or segment["start_time"] > audio.duration_seconds: continue - + start_time = min(segment["start_time"], aggregated_segment["start_time"]) end_time = max(segment["end_time"], aggregated_segment["end_time"]) if end_time - start_time >= self.max_duration: @@ -96,36 +100,37 @@ def process_dataset_entry(self, data_entry: dict): aggregated_segment["text"] += f" {segment['text']}".strip() else: aggregated_segment["text"] = f"{segment['text']} {aggregated_segment['text']}" - - aggregated_segment["start"] = start_time # updating aggregated segment start time - aggregated_segment["end_time"] = end_time # updating aggregated segment end time + + aggregated_segment["start"] = start_time # updating aggregated segment start time + aggregated_segment["end_time"] = end_time # updating aggregated segment end time else: - # adding the last aggregated segment + # adding the last aggregated segment if aggregated_segment not in agg_segments: agg_segments.append(aggregated_segment) - + valid_segments = [] for aggregated_segment in agg_segments: aggregated_segment.update(data_entry) - + start_time = aggregated_segment.pop("start_time") end_time = aggregated_segment.pop("end_time") - + aggregated_segment[self.output_duration_key] = end_time - start_time - aggregated_segment[self.output_splitted_audio_filepath_key] = os.path.join(self.output_audio_dir, f"{audio_basename}_{start_time}_{end_time}.wav") - + aggregated_segment[self.output_splitted_audio_filepath_key] = os.path.join( + self.output_audio_dir, f"{audio_basename}_{start_time}_{end_time}.wav" + ) + if self.save_aggregated_audio_segments: try: save_audio_segment( audio=audio, start_time=start_time, end_time=end_time, - output_audio_filepath=aggregated_segment[self.output_splitted_audio_filepath_key] + output_audio_filepath=aggregated_segment[self.output_splitted_audio_filepath_key], ) valid_segments.append(aggregated_segment) except IndexError as e: if self.verbose: logging.warning(f"Invalid segment boundaries in {audio_basename}. 
Skipping...") - + return [DataEntry(data=segment) for segment in valid_segments] - \ No newline at end of file diff --git a/sdp/processors/datasets/masc/apply_reg_exp_on_vtt_entries.py b/sdp/processors/datasets/masc/apply_reg_exp_on_vtt_entries.py index 541e98eb..16126247 100644 --- a/sdp/processors/datasets/masc/apply_reg_exp_on_vtt_entries.py +++ b/sdp/processors/datasets/masc/apply_reg_exp_on_vtt_entries.py @@ -14,8 +14,10 @@ import os import re -import webvtt # pip install webvtt-py from typing import Dict + +import webvtt # pip install webvtt-py + from sdp.processors.base_processor import BaseParallelProcessor, DataEntry diff --git a/sdp/processors/datasets/masc/create_initial_manifest.py b/sdp/processors/datasets/masc/create_initial_manifest.py index 9563f723..af242def 100644 --- a/sdp/processors/datasets/masc/create_initial_manifest.py +++ b/sdp/processors/datasets/masc/create_initial_manifest.py @@ -12,21 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import logging +import os from pathlib import Path + import pandas as pd from sox import Transformer from sdp.processors.base_processor import BaseParallelProcessor, DataEntry from sdp.utils.common import extract_archive + class CreateInitialManifestMASC(BaseParallelProcessor): """ Processor for creating initial manifest for Massive Arabic Speech Corpus (MASC). \n Dataset link: https://ieee-dataport.org/open-access/masc-massive-arabic-speech-corpus. Prior to calling processor download the tarred dataset and store it under `raw_dataset_dir/masc.tar.gz`. - + Creates manifest from samples in . `dataset_dir/subsets/data_split.csv`. All meta information is kept. Args: @@ -78,7 +80,7 @@ def __init__( super().__init__(**kwargs) self.raw_dataset_dir = Path(raw_data_dir) self.data_split = data_split - + # in original dataset there are no train, dev and test splits. These are added to support end-to-end tests. if self.data_split == "train": self.data_split = "clean_train" @@ -86,21 +88,31 @@ def __init__( self.data_split = "clean_dev" if self.data_split == "test": self.data_split = "clean_test" - + self.extract_archive_dir = extract_archive_dir self.resampled_audios_dir = Path(resampled_audios_dir) self.already_extracted = already_extracted - + self.target_samplerate = target_samplerate self.target_nchannels = target_nchannels self.output_manifest_sample_id_key = output_manifest_sample_id_key self.output_manifest_vtt_filapath_key = output_manifest_vtt_filapath_key self.output_manifest_audio_filapath_key = output_manifest_audio_filapath_key - + self.verbose = verbose - data_split_values = ["train", "dev", "test", "clean_train", "clean_dev", "clean_test", "noisy_train", "noisy_dev", "noisy_test"] + data_split_values = [ + "train", + "dev", + "test", + "clean_train", + "clean_dev", + "clean_test", + "noisy_train", + "noisy_dev", + "noisy_test", + ] if self.data_split not in data_split_values: raise ValueError(f'Data split value must be from {data_split_values}. 
"{self.data_split}" was given.') @@ -120,7 +132,7 @@ def prepare(self): else: logging.info("Skipping dataset untarring...") self.dataset_dir = Path(self.extract_archive_dir) / "masc" - + self.vtts_dir = self.dataset_dir / "subtitles" self.audios_dir = self.dataset_dir / "audios" if self.data_split == "clean_train" or self.data_split == "noisy_train": @@ -136,9 +148,9 @@ def prepare(self): if not self.audios_dir.exists(): raise FileNotFoundError(f"{self.audios_dir} not found.") - + os.makedirs(self.resampled_audios_dir, exist_ok=True) - + def read_manifest(self): csv = pd.read_csv(self.csv_filepath) return [row.to_dict() for _, row in csv.iterrows()] @@ -148,11 +160,11 @@ def process_dataset_entry(self, sample_data): source_audio_filepath = self.audios_dir / f"{sample_id}.wav" target_audio_filepath = self.resampled_audios_dir / f"{sample_id}.wav" vtt_filepath = self.vtts_dir / f"{sample_id}.ar.vtt" - + # if source audio or vtt file do not exist skip if not (os.path.exists(source_audio_filepath) and os.path.exists(vtt_filepath)): return [] - + # if target audio exists skip resampling if not os.path.exists(target_audio_filepath): tfm = Transformer() @@ -170,5 +182,5 @@ def process_dataset_entry(self, sample_data): self.output_manifest_audio_filapath_key: str(target_audio_filepath), } ) - + return [DataEntry(data=sample_data)] diff --git a/sdp/processors/datasets/masc/get_caption_file_segments.py b/sdp/processors/datasets/masc/get_caption_file_segments.py index 745c6548..d9f5460f 100644 --- a/sdp/processors/datasets/masc/get_caption_file_segments.py +++ b/sdp/processors/datasets/masc/get_caption_file_segments.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import logging +import os + from sdp.processors.base_processor import BaseParallelProcessor, DataEntry from sdp.processors.datasets.masc.utils import parse_captions + class GetCaptionFileSegments(BaseParallelProcessor): """ This class extracts subtitle information from .vtt (WebVTT) files. @@ -26,7 +28,7 @@ class GetCaptionFileSegments(BaseParallelProcessor): input_caption_file_key (str): The field name in the input manifest containing path to the caption file. output_segments_key (str): The field name to store segment information. Defaults to "segments". verbose (bool): Set true for outputing logging information. - + Returns: This processor adds an output_segments field to the input manifest with a list of segments. Each segment has a structure: @@ -37,6 +39,7 @@ class GetCaptionFileSegments(BaseParallelProcessor): "text": } """ + def __init__( self, input_caption_file_key: str, @@ -51,7 +54,7 @@ def __init__( def process_dataset_entry(self, data_entry): caption_file = data_entry[self.caption_file_key] - + if not os.path.exists(caption_file): if self.verbose: logging.info(f"File {caption_file} does not exist.") diff --git a/sdp/processors/datasets/masc/utils.py b/sdp/processors/datasets/masc/utils.py index e3adf646..5e57a637 100644 --- a/sdp/processors/datasets/masc/utils.py +++ b/sdp/processors/datasets/masc/utils.py @@ -12,36 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import webvtt # pip install webvtt-py +from datetime import datetime from typing import Optional + +import webvtt # pip install webvtt-py + from sdp.processors.datasets.commoncrawl.harv_utils import parse_hours -from datetime import datetime + def save_audio_segment(audio, start_time: float, end_time: float, output_audio_filepath: Optional[str]): """ Extracts a segment from audio. - + Args: audio: input audio start_time (float): segment start time in seconds. end_time (float): segment end time in seconds. audio_filepath (Optional[str]): filepath to store the segment. - + Returns: audio_segment: audio segment - + IndexError: Raised if segment boundaries are out of range. """ start_time = start_time * 1000 end_time = end_time * 1000 - + if start_time >= len(audio) or end_time >= len(audio): raise IndexError("Segment boundaries are out of range.") - + audio_segment = audio[start_time:end_time] if output_audio_filepath: audio_segment.export(output_audio_filepath, format="wav") - + return audio_segment @@ -55,7 +58,7 @@ def parse_captions(captions_filepath: str): "end_time": float, # End time of the segment (in seconds) "text": str # Text content of the segment } - + Args: captions_filepath (str): path to .vtt file. """ @@ -65,13 +68,13 @@ def parse_captions(captions_filepath: str): text = ' '.join([text.strip() for text in caption.text.split('\n')]) start_time = parse_hours(caption.start) - initial_timestamp end_time = parse_hours(caption.end) - initial_timestamp - + segment = { "segment_id": index, "start_time": start_time.total_seconds(), "end_time": end_time.total_seconds(), - "text": text + "text": text, } srt_segments.append(segment) - - return srt_segments \ No newline at end of file + + return srt_segments diff --git a/sdp/processors/datasets/mediaspeech/create_initial_manifest.py b/sdp/processors/datasets/mediaspeech/create_initial_manifest.py index cfba28d1..2514d44c 100644 --- a/sdp/processors/datasets/mediaspeech/create_initial_manifest.py +++ b/sdp/processors/datasets/mediaspeech/create_initial_manifest.py @@ -18,8 +18,7 @@ from sdp.logging import logger from sdp.processors.base_processor import BaseParallelProcessor, DataEntry -from sdp.utils.common import ffmpeg_convert -from sdp.utils.common import extract_archive +from sdp.utils.common import extract_archive, ffmpeg_convert class CreateInitialManifestMediaSpeech(BaseParallelProcessor): @@ -27,7 +26,7 @@ class CreateInitialManifestMediaSpeech(BaseParallelProcessor): Processor for creating initial manifest for MediaSpeech Arabic dataset. Dataset link: https://www.openslr.org/108/. Prior to calling processor download the tarred dataset and store it under `raw_dataset_dir/AR.tgz` or `raw_dataset_dir/AR.tar.gz`. - + Args: raw_data_dir (str): The root directory of the dataset. extract_archive_dir (str): Directory where the extracted data will be saved. 
@@ -42,12 +41,13 @@ class CreateInitialManifestMediaSpeech(BaseParallelProcessor): Returns: This processor generates an initial manifest file with the following fields:: - + { "audio_filepath": , "text": , } """ + def __init__( self, raw_data_dir: str, @@ -66,10 +66,10 @@ def __init__( self.extract_archive_dir = extract_archive_dir self.resampled_audios_dir = Path(resampled_audios_dir) self.already_extracted = already_extracted - + self.target_samplerate = target_samplerate self.target_nchannels = target_nchannels - + self.output_manifest_sample_id_key = output_manifest_sample_id_key self.output_manifest_audio_filapath_key = output_manifest_audio_filapath_key self.output_manifest_text_key = output_manifest_text_key @@ -94,7 +94,7 @@ def prepare(self): else: logger.info("Skipping dataset untarring...") self.dataset_dir = Path(self.extract_archive_dir) / "AR" - + os.makedirs(self.resampled_audios_dir, exist_ok=True) def read_manifest(self): @@ -106,9 +106,7 @@ def read_manifest(self): text_filepath = f"{self.dataset_dir}/{sample_id}.txt" if not os.path.exists(text_filepath): - logger.warning( - f'Sample "{sample_id}" has no related .txt files. Skipping' - ) + logger.warning(f'Sample "{sample_id}" has no related .txt files. Skipping') continue data_entries.append( @@ -142,4 +140,4 @@ def process_dataset_entry(self, data_entry: DataEntry): text_file = open(data_entry["text_filepath"], "r") data[self.output_manifest_text_key] = text_file.read() - return [DataEntry(data=data)] \ No newline at end of file + return [DataEntry(data=data)] diff --git a/sdp/processors/datasets/mls/create_initial_manifest.py b/sdp/processors/datasets/mls/create_initial_manifest.py index 72d9c00d..c9fdbd79 100644 --- a/sdp/processors/datasets/mls/create_initial_manifest.py +++ b/sdp/processors/datasets/mls/create_initial_manifest.py @@ -103,8 +103,6 @@ def __init__( " specified `target_samplerate` or `target_nchannels`, they will be ignored." ) - - # will be initialized in self.prepare method self.audio_path_prefix = None self.transcription_file = None diff --git a/sdp/processors/datasets/mtedx/create_initial_manifest.py b/sdp/processors/datasets/mtedx/create_initial_manifest.py index 8c39257b..54d40419 100644 --- a/sdp/processors/datasets/mtedx/create_initial_manifest.py +++ b/sdp/processors/datasets/mtedx/create_initial_manifest.py @@ -1,45 +1,49 @@ import os from pathlib import Path from typing import List + import librosa + from sdp.processors.base_processor import BaseParallelProcessor, DataEntry from sdp.utils.common import download_file, extract_archive MTEDX_URL = "https://www.openslr.org/resources/100/mtedx_{language_id}.tgz" + class CreateInitialManifestMTEDX(BaseParallelProcessor): """Processor to create initial manifest for the Multilingual TEDx (MTedX dataset. - Dataset link: https://www.openslr.org/100/ - - Downloads dataset for the specified language and creates initial manifest with the provided - audio and vtt files. - - Args: - raw_data_dir (str): the directory where the downloaded data will be/is saved. - This is also where the extracted and processed data will be. - data_split (str): "train", "dev" or "test". - language_id (str): the ID of the language of the data. E.g., "en", "es", "it", etc. - target_samplerate (int): sample rate (Hz) to use for resampling. - already_extracted: (bool): if True, we will not try to extract the raw data. - Defaults to False. 
- - Returns: - This processor generates an initial manifest file with the following fields:: - - { - "audio_filepath": , - "vtt_filepath": - "duration": - } - """ + Dataset link: https://www.openslr.org/100/ + + Downloads dataset for the specified language and creates initial manifest with the provided + audio and vtt files. + + Args: + raw_data_dir (str): the directory where the downloaded data will be/is saved. + This is also where the extracted and processed data will be. + data_split (str): "train", "dev" or "test". + language_id (str): the ID of the language of the data. E.g., "en", "es", "it", etc. + target_samplerate (int): sample rate (Hz) to use for resampling. + already_extracted: (bool): if True, we will not try to extract the raw data. + Defaults to False. + + Returns: + This processor generates an initial manifest file with the following fields:: + + { + "audio_filepath": , + "vtt_filepath": + "duration": + } + """ + def __init__( - self, - raw_data_dir: str, - language_id: str, - data_split: str, - already_extracted: bool = False, - **kwargs, + self, + raw_data_dir: str, + language_id: str, + data_split: str, + already_extracted: bool = False, + **kwargs, ): super().__init__(**kwargs) self.raw_data_dir = Path(raw_data_dir) @@ -51,15 +55,14 @@ def prepare(self): """Downloading and extracting data (unless already done).""" os.makedirs(self.raw_data_dir, exist_ok=True) - url = MTEDX_URL.format(language_id=self.language_id) if not (self.raw_data_dir / f"mtedx_{self.language_id}.tgz").exists(): download_file(url, str(self.raw_data_dir)) if not self.already_extracted: extract_archive(str(self.raw_data_dir / os.path.basename(url)), str(self.raw_data_dir)) - - data_folder = Path(self.raw_data_dir) / f"{self.language_id}-{self.language_id}"/ "data"/ self.data_split + + data_folder = Path(self.raw_data_dir) / f"{self.language_id}-{self.language_id}" / "data" / self.data_split self.audio_path_prefix = Path(data_folder) / "wav" self.vtt_path_prefix = Path(data_folder) / "vtt" @@ -67,7 +70,9 @@ def read_manifest(self): """Creating entries of initial manifest with flac and vtt files""" audio_filepaths = [] for audio_file in os.listdir(self.audio_path_prefix): - vtt_filepath = os.path.join(self.vtt_path_prefix, audio_file.split('.')[0] + "." + self.language_id + ".vtt") + vtt_filepath = os.path.join( + self.vtt_path_prefix, audio_file.split('.')[0] + "." + self.language_id + ".vtt" + ) audio_filepath = os.path.join(self.audio_path_prefix, audio_file) audio_filepaths.append((audio_filepath, vtt_filepath)) return audio_filepaths diff --git a/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py b/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py index 27117f2a..51b95120 100644 --- a/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py +++ b/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py @@ -16,11 +16,12 @@ import json import os import typing + import gdown +from sdp.logging import logger from sdp.processors.base_processor import BaseProcessor from sdp.utils.common import extract_archive -from sdp.logging import logger class CreateInitialManifestUzbekvoice(BaseProcessor): @@ -30,7 +31,7 @@ class CreateInitialManifestUzbekvoice(BaseProcessor): Will download all files, extract them, and create a manifest file with the "audio_filepath", "text" and "duration" fields. - Args: + Args: raw_data_dir (str): Path to the folder where the data archive should be downloaded and extracted. 
Returns: @@ -59,8 +60,10 @@ def download_extract_files(self, dst_folder: str) -> None: # for big files google drive doesn't allow to try downlaoding them more than once # so, in case of receiveing gdown error we need to download them manually - #check if clisp.zip and uzbekvoice-dataset.zip are already in dst_folder - if os.path.exists(os.path.join(dst_folder, 'clips.zip')) and os.path.exists(os.path.join(dst_folder, 'uzbekvoice-dataset.zip')): + # check if clisp.zip and uzbekvoice-dataset.zip are already in dst_folder + if os.path.exists(os.path.join(dst_folder, 'clips.zip')) and os.path.exists( + os.path.join(dst_folder, 'uzbekvoice-dataset.zip') + ): print("Files already exist in the folder. Skipping download.") else: print(f"Downloading files from {self.URL}...") @@ -74,7 +77,6 @@ def download_extract_files(self, dst_folder: str) -> None: extract_archive(file, str(dst_folder), force_extract=True) print(f"Extracted {file}") - def process_transcript(self, file_path: str) -> list[dict[str, typing.Any]]: """ Parse transcript JSON file and put it inside manifest. @@ -93,13 +95,8 @@ def process_transcript(self, file_path: str) -> list[dict[str, typing.Any]]: utter_length = entry["clip_duration"] number_of_entries += 1 entries.append( - { - "audio_filepath": os.path.abspath(audio_file), - "text": transcript, - "duration": utter_length - } + {"audio_filepath": os.path.abspath(audio_file), "text": transcript, "duration": utter_length} ) - logger.info("Total number of entries after processing: %d", number_of_entries) logger.info("Total audio duration (hours) after processing: %.2f", total_duration / 3600) @@ -113,8 +110,6 @@ def process_data(self, data_folder: str, manifest_file: str) -> None: for m in entries: fout.write(json.dumps(m, ensure_ascii=False) + "\n") - - def process(self): self.download_extract_files(self.raw_data_dir) self.process_data(self.raw_data_dir, self.output_manifest_file) diff --git a/sdp/processors/datasets/ytc/create_initial_manifest.py b/sdp/processors/datasets/ytc/create_initial_manifest.py index 77b07c83..7ef4f468 100644 --- a/sdp/processors/datasets/ytc/create_initial_manifest.py +++ b/sdp/processors/datasets/ytc/create_initial_manifest.py @@ -19,9 +19,10 @@ from sdp.processors.base_processor import BaseParallelProcessor, DataEntry from sdp.utils.common import load_manifest + class CreateInitialManifestYTC(BaseParallelProcessor): """A processor class for creating initial manifest files for a TTS dataset. - + It takes a manifest file containing audio file paths and resamples them to a target sample rate and format, while creating a new manifest file with the updated paths. 
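As context for the hunks below, a hedged sketch of the ffmpeg resampling that CreateInitialManifestYTC performs; the paths and 16 kHz target are illustrative assumptions, and the command is a simplified form of the one in the diff:

    import subprocess

    def resample(input_path: str, output_path: str, sample_rate: int = 16000) -> None:
        # Resample, downmix to mono, and let ffmpeg report only errors.
        cmd = (
            f'ffmpeg -i "{input_path}" -ar {sample_rate} -ac 1 '
            f'"{output_path}" -v error'
        )
        subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)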
@@ -43,15 +44,16 @@ class CreateInitialManifestYTC(BaseParallelProcessor): input_manifest_file: ${workspace_dir}/manifest.json output_manifest_file: ${workspace_dir}/manifest_resampled.json """ + def __init__( - self, - input_format: str, - resampled_audio_dir: str, - target_sample_rate: int, - target_format: str, - target_nchannels: int, - **kwargs - ): + self, + input_format: str, + resampled_audio_dir: str, + target_sample_rate: int, + target_format: str, + target_nchannels: int, + **kwargs, + ): super().__init__(**kwargs) self.input_format = input_format self.resampled_audio_dir = resampled_audio_dir @@ -64,9 +66,9 @@ def prepare(self): os.makedirs(self.resampled_audio_dir, exist_ok=True) def read_manifest(self): - """ Reads metadata from JSONL file in the input manifest - Returns: - list: A list of dataset entries parsed from the JSONL manifest file + """Reads metadata from JSONL file in the input manifest + Returns: + list: A list of dataset entries parsed from the JSONL manifest file """ dataset_entries = load_manifest(self.input_manifest_file, encoding="utf8") @@ -74,13 +76,13 @@ def read_manifest(self): def process_dataset_entry(self, metadata: DataEntry): """Processes a single dataset entry by resampling the audio file and updating metadata. - + Args: metadata (DataEntry): The metadata entry containing information about the audio file - + Returns: list[DataEntry]: A list containing the processed DataEntry with updated metadata - + Note: This method: 1. Resamples the audio file to the target format and sample rate if needed @@ -88,8 +90,11 @@ def process_dataset_entry(self, metadata: DataEntry): 3. Uses either sox or ffmpeg for audio conversion depending on input format """ import soundfile as sf + input_audio_path = metadata['audio_filepath'] - output_audio_path = os.path.join(self.resampled_audio_dir, metadata['audio_item_id'] + '.' + self.target_format) + output_audio_path = os.path.join( + self.resampled_audio_dir, metadata['audio_item_id'] + '.' + self.target_format + ) # Convert audio file to target sample rate and format if not os.path.exists(output_audio_path): @@ -98,21 +103,19 @@ def process_dataset_entry(self, metadata: DataEntry): else: cmd = f'ffmpeg -i "{input_audio_path}" -ar {self.target_sample_rate} -ac 1 -ab 16 "{output_audio_path}" -v error' try: - subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True) # Ensures output is in string formats) + subprocess.run( + cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) # Ensures output is in string formats) except subprocess.CalledProcessError as e: print("Exception occurred while converting audio file: ", e, e.stderr) print(f'Error converting {input_audio_path} to {output_audio_path}. Hence skipping this entry.') exit(1) - + metadata['audio_filepath'] = input_audio_path metadata['resampled_audio_filepath'] = output_audio_path try: metadata['duration'] = sf.info(output_audio_path).duration except Exception as e: print(f'Error getting duration of {output_audio_path}. 
Hence not adding duration to metadata.') - - return [DataEntry(data=metadata)] - + return [DataEntry(data=metadata)] diff --git a/sdp/processors/huggingface/create_initial_manifest.py b/sdp/processors/huggingface/create_initial_manifest.py index e8abfb55..8d392f26 100644 --- a/sdp/processors/huggingface/create_initial_manifest.py +++ b/sdp/processors/huggingface/create_initial_manifest.py @@ -1,11 +1,12 @@ -import os import glob +import os +from typing import Optional import soundfile as sf -from sdp.processors.base_processor import BaseParallelProcessor, DataEntry from sdp.logging import logger -from typing import Optional +from sdp.processors.base_processor import BaseParallelProcessor, DataEntry + class CreateInitialManifestHuggingFace(BaseParallelProcessor): """Processor to create initial manifest for HuggingFace dataset. @@ -24,7 +25,7 @@ class CreateInitialManifestHuggingFace(BaseParallelProcessor): Returns: This processor generates an initial manifest file with the following fields:: - + { "audio_filepath": , "duration": , @@ -55,7 +56,7 @@ def prepare(self): def read_manifest(self): import datasets - + # checking if dataset should be loaded from disk if self.already_downloaded: if os.path.exists(self.raw_data_dir): @@ -89,4 +90,4 @@ def process_dataset_entry(self, data_id): "text": text, } ) - ] \ No newline at end of file + ] diff --git a/sdp/processors/huggingface/speech_recognition.py b/sdp/processors/huggingface/speech_recognition.py index 2e64e7c4..3e943df2 100644 --- a/sdp/processors/huggingface/speech_recognition.py +++ b/sdp/processors/huggingface/speech_recognition.py @@ -14,13 +14,14 @@ import json from pathlib import Path +from typing import Optional from tqdm import tqdm from sdp.logging import logger from sdp.processors.base_processor import BaseProcessor from sdp.utils.common import load_manifest -from typing import Optional + class ASRTransformers(BaseProcessor): """This processor transcribes audio files using HuggingFace ASR Transformer models. @@ -99,7 +100,7 @@ def __init__( # Check if using Whisper/Seamless or NVIDIA model based on the model name self.is_whisper_or_seamless = any(x in self.pretrained_model.lower() for x in ['whisper', 'seamless']) - + # Only set language in generation config for Whisper/Seamless models if self.is_whisper_or_seamless and self.generate_language: self.model.generation_config.language = self.generate_language @@ -131,7 +132,7 @@ def process(self): batch = json_list_sorted[start_index : start_index + self.batch_size] start_index += self.batch_size audio_files = [item[self.input_audio_key] for item in batch] - + # Only pass generate_kwargs for Whisper/Seamless models if self.is_whisper_or_seamless and self.generate_language and self.generate_task: results = self.pipe( diff --git a/sdp/processors/ipl/ipl_processors.py b/sdp/processors/ipl/ipl_processors.py index 0a29c656..c9240a11 100644 --- a/sdp/processors/ipl/ipl_processors.py +++ b/sdp/processors/ipl/ipl_processors.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
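Stepping back to the ASRTransformers hunk above: a hedged sketch of batched transcription with a HuggingFace pipeline. The model name, file paths, and batch size are illustrative assumptions, not values from this diff:

    from transformers import pipeline

    # Hypothetical model and inputs for illustration only.
    pipe = pipeline("automatic-speech-recognition", model="openai/whisper-small")
    audio_files = ["a.wav", "b.wav", "c.wav"]
    batch_size = 2

    for start in range(0, len(audio_files), batch_size):
        batch = audio_files[start : start + batch_size]
        # As in the processor, generate_kwargs (language/task) only applies
        # to Whisper/Seamless-style models.
        results = pipe(batch, generate_kwargs={"language": "en", "task": "transcribe"})
        for path, result in zip(batch, results):
            print(path, result["text"])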
+import json +import logging + # Standard library imports import os import subprocess @@ -20,8 +23,7 @@ # Third-party imports from omegaconf import DictConfig, OmegaConf, open_dict -import logging -import json + # Local imports from sdp.processors.base_processor import BaseProcessor @@ -43,13 +45,13 @@ class TrainingCommandGenerator(BaseProcessor): def __init__( self, - training_config_local: str, # Local machine config path - training_config_cluster: str, # Cluster config path - training_script_path: str, # Path to training script - nemo_directory: str, # Base directory for NeMo + training_config_local: str, # Local machine config path + training_config_cluster: str, # Cluster config path + training_script_path: str, # Path to training script + nemo_directory: str, # Base directory for NeMo new_manifest_files: str = None, # New manifest files to add new_tarred_audio_filepaths: str = None, # New tarred audio paths - **kwargs + **kwargs, ): super().__init__(**kwargs) @@ -61,11 +63,7 @@ def __init__( self.new_manifest_files = new_manifest_files self.new_tarred_audio_filepaths = new_tarred_audio_filepaths - def process( - self, - new_manifest_files=None, - new_tarred_audio_filepaths=None - ) -> str: + def process(self, new_manifest_files=None, new_tarred_audio_filepaths=None) -> str: """ Generates the training command based on the processor's configuration. If new manifest files are provided, updates the training configuration accordingly. @@ -77,20 +75,20 @@ def process( cmd = self.get_execution_script( cluster_script_path=self.training_script_path, local_config=self.training_config_local, - cluster_config_path=self.training_config_cluster + cluster_config_path=self.training_config_cluster, ) else: updated_manifest_filepaths, updated_tarred_audio_filepaths = self.update_training_sets( config=self.training_config_local, updated_manifest_filepaths=new_manifest_files, - updated_tarred_audio_filepaths=new_tarred_audio_filepaths + updated_tarred_audio_filepaths=new_tarred_audio_filepaths, ) cmd = self.get_execution_script( cluster_script_path=self.training_script_path, local_config=self.training_config_local, cluster_config_path=self.training_config_cluster, updated_manifest_filepaths=updated_manifest_filepaths, - updated_tarred_filepaths=updated_tarred_audio_filepaths + updated_tarred_filepaths=updated_tarred_audio_filepaths, ) return cmd @@ -100,7 +98,7 @@ def get_execution_script( local_config: DictConfig, cluster_config_path: str, updated_manifest_filepaths: Optional[str] = None, - updated_tarred_filepaths: Optional[str] = None + updated_tarred_filepaths: Optional[str] = None, ) -> str: """ Create the command to run the script on the cluster. @@ -127,7 +125,6 @@ def get_execution_script( "Please set WANDB_API_KEY to enable WANDB logging." ) - config_path = os.path.dirname(cluster_config_path) config_name = os.path.basename(cluster_config_path) cmd = ( @@ -147,8 +144,8 @@ def get_execution_script( # with open(self.output_manifest_file, 'w') as f: # json.dump(output_data, f, indent=4) return cmd - - def get_transcribed_names(self, manifest_filepaths: List[str], is_tarred: bool=False) -> List[List[str]]: + + def get_transcribed_names(self, manifest_filepaths: List[str], is_tarred: bool = False) -> List[List[str]]: """ Generates a list of modified file paths by prepending 'transcribed_' to the filenames. 
The use case is for non AIStore datasets @@ -174,11 +171,8 @@ def get_transcribed_names(self, manifest_filepaths: List[str], is_tarred: bool=F for file_path in manifest_filepaths: directory, filename = os.path.split(file_path) - - new_filename = ( - f"transcribed_{filename}" if is_tarred - else f"transcribed_manifest.json" - ) + + new_filename = f"transcribed_{filename}" if is_tarred else f"transcribed_manifest.json" transcribed_paths.append([os.path.join(directory, new_filename)]) return transcribed_paths @@ -187,7 +181,7 @@ def update_training_sets( self, config: DictConfig, updated_manifest_filepaths: List[str], - updated_tarred_audio_filepaths: Optional[List[str]] = None + updated_tarred_audio_filepaths: Optional[List[str]] = None, ) -> Tuple[str, str]: """ Updates the training dataset configuration by adding pseudo-labeled datasets @@ -203,7 +197,9 @@ def update_training_sets( - Updated manifest file paths as a string, formatted for Omegaconf - Updated tarred audio file paths as a string, formatted for Omegaconf """ - updated_manifest_filepaths = self.get_transcribed_names(updated_manifest_filepaths,is_tarred=config.model.train_ds.get("is_tarred", False)) + updated_manifest_filepaths = self.get_transcribed_names( + updated_manifest_filepaths, is_tarred=config.model.train_ds.get("is_tarred", False) + ) manifest_filepath = config.model.train_ds.manifest_filepath if updated_tarred_audio_filepaths: updated_tarred_audio_filepaths = [[path] for path in updated_tarred_audio_filepaths] @@ -255,13 +251,13 @@ class InferenceCommandGenerator(BaseProcessor): def __init__( self, - nemo_directory: str, + nemo_directory: str, inference_config_paths: str, - manifests: str, + manifests: str, p_cache: float, num_gpus: int, is_tarred: bool = False, - **kwargs + **kwargs, ): super().__init__(**kwargs) @@ -274,7 +270,7 @@ def __init__( self.num_gpus = num_gpus self.is_tarred = is_tarred - def process(self, first_run=False): + def process(self, first_run=False): """ Generate the pseudo-labeling command for the given configuration and training parameters. 
@@ -286,10 +282,14 @@ def process(self, first_run=False): """ cmd = "" prediction_directories_str = " ".join([os.path.dirname(path) for path in self.manifests]) - inference_config_paths_str = " ".join(self.inference_config_paths) - write_transcription_path = os.path.join(self.nemo_directory, "scripts/pseudo_labeling/write_transcribed_files.py") - update_inference_config_path = os.path.join(self.nemo_directory, "scripts/pseudo_labeling/update_inference_config.py") - + inference_config_paths_str = " ".join(self.inference_config_paths) + write_transcription_path = os.path.join( + self.nemo_directory, "scripts/pseudo_labeling/write_transcribed_files.py" + ) + update_inference_config_path = os.path.join( + self.nemo_directory, "scripts/pseudo_labeling/update_inference_config.py" + ) + if first_run: cmd += f" && {self.get_pl_inference_command(self.inference_config_paths, shuffle=False)}" cmd += ( @@ -304,10 +304,7 @@ def process(self, first_run=False): ) else: cmd += f" && {self.get_pl_inference_command(self.inference_config_paths, shuffle=True)}" - cmd += ( - f" && python {write_transcription_path} " - f"--prediction_filepaths {prediction_directories_str} " - ) + cmd += f" && python {write_transcription_path} " f"--prediction_filepaths {prediction_directories_str} " if self.is_tarred: cmd += " --is_tarred" @@ -317,7 +314,6 @@ def process(self, first_run=False): return cmd - def get_pl_inference_command(self, inference_configs, shuffle=None): """ Generate a command to run PL inference with multiple configuration files. @@ -338,4 +334,3 @@ def get_pl_inference_command(self, inference_configs, shuffle=None): cmd_list.append(cmd) return " && ".join(cmd_list) - diff --git a/sdp/processors/ipl/nemo_run_processor.py b/sdp/processors/ipl/nemo_run_processor.py index 529a128c..71198f9c 100644 --- a/sdp/processors/ipl/nemo_run_processor.py +++ b/sdp/processors/ipl/nemo_run_processor.py @@ -1,4 +1,3 @@ - # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,34 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -from sdp.processors.ipl.ipl_processors import TrainingCommandGenerator, InferenceCommandGenerator -from sdp.processors.base_processor import BaseProcessor -from omegaconf import OmegaConf, open_dict +import datetime +import logging import os from pathlib import Path -import logging -import datetime + import nemo_run as run +from omegaconf import OmegaConf, open_dict + +from sdp.processors.base_processor import BaseProcessor +from sdp.processors.ipl.ipl_processors import ( + InferenceCommandGenerator, + TrainingCommandGenerator, +) from sdp.utils import nemo_run_utils + class NemoRunIPLProcessor(BaseProcessor): """ A processor that handles Iterative Pseudo-Labeling (IPL) training workflow. - + Args: config_path (str): Path to the YAML configuration file containing IPL settings output_manifest_file (str): Path where the output manifest file will be written input_manifest_file (str, Optional): Path to the input manifest file """ - - def __init__( - self, - config_path: str, - **kwargs - ): + + def __init__(self, config_path: str, **kwargs): super().__init__(**kwargs) self.config_path = config_path - + def process(self): """ Main processing method that implements the IPL workflow. 
@@ -51,7 +52,7 @@ def process(self): """ # Load the cluster config from YAML cluster_cfg = OmegaConf.load(self.config_path) - + # Process the required arguments from the cluster config script_path = cluster_cfg.script script_config_path = cluster_cfg.script_config @@ -82,7 +83,7 @@ def process(self): if "ipl_training" not in script_config.model: raise KeyError("Parameters for `IPL` training are not provided.") # Check all paths in configs are properly mounted - + self.check_config_mount_paths(script_config, cluster_cfg) # Resolve experiment name exp_name = cluster_cfg.exp_name @@ -102,7 +103,9 @@ def process(self): # Copy the merged config file to remote location's /results/configs directory config_dir = os.path.join(results_dir, 'configs') - train_config_cluster = nemo_run_utils.create_remote_config(script_config, config_name, config_dir, cluster_cfg) + train_config_cluster = nemo_run_utils.create_remote_config( + script_config, config_name, config_dir, cluster_cfg + ) # Get run parameters from the config num_runs = cluster_cfg.num_runs @@ -119,7 +122,7 @@ def process(self): os.path.join(script_config.exp_manager.exp_dir, script_config.exp_manager.name), "checkpoints" ) checkpoint_name = os.path.join(checkpoint_dir, script_config.exp_manager.name + ".nemo") - + # Create remote inference config if do_average: avg_cmd, averaged_checkpoint = self.average_checkpoints(checkpoint_name, nemo_root) @@ -131,13 +134,13 @@ def process(self): ) self.check_config_mount_paths(inference_config, cluster_cfg) # Configure command generators - train_command_generator_config = { + train_command_generator_config = { "nemo_directory": nemo_root, "training_config_local": script_config, "training_config_cluster": train_config_cluster, "training_script_path": script_path, "output_manifest_file": "./train_output_manifest_filepath.json", - } + } inference_command_generator_config = { "nemo_directory": nemo_root, "inference_config_paths": inference_config_paths, @@ -157,7 +160,7 @@ def process(self): new_manifest_files=manifests, new_tarr_files=tarr_paths, first_run=True, - avg_cmd=avg_cmd + avg_cmd=avg_cmd, ) # Cast the cluster config to a dictionary for compatibility with NeMo Run @@ -175,7 +178,7 @@ def process(self): num_ipl_epochs=cluster_cfg['num_ipl_epochs'], new_manifest_files=manifests, new_tarr_files=tarr_paths, - first_run=False + first_run=False, ) task = [task] @@ -199,7 +202,7 @@ def process(self): def gather_mounts(self, cluster_cfg): """ Gather all mounts from the cluster config including ones which are disjoint from the cluster_cfg.mounts list. - + Args: cluster_cfg: Cluster config dictionary """ @@ -210,7 +213,9 @@ def gather_mounts(self, cluster_cfg): with open_dict(cluster_cfg): for k in keys: if k.startswith("mount_"): - logging.info(f"Found additional mount flag in the cluster config `{k}`. Adding it to the mounts list.") + logging.info( + f"Found additional mount flag in the cluster config `{k}`. Adding it to the mounts list." + ) mounts.append(cluster_cfg[k]) del cluster_cfg[k] @@ -220,11 +225,12 @@ def gather_mounts(self, cluster_cfg): def check_config_mount_paths(self, script_config, cluster_config): """ Check if all path-like strings in the script config are mounted paths in the cluster config. 
- + Args: script_config: Script config dictionary cluster_config: Cluster config dictionary """ + def filepath_check(v, cluster_cfg): if v.startswith(os.path.sep): logging.info(f"Checking if {v} is a mounted path") @@ -254,7 +260,7 @@ def get_pseudo_labeling_command( new_manifest_files, new_tarr_files, first_run: bool = False, - avg_cmd: str = None + avg_cmd: str = None, ) -> str: """ Generate the pseudo-labeling command for the given configuration and training parameters. @@ -273,12 +279,14 @@ def get_pseudo_labeling_command( train_proc = TrainingCommandGenerator(**train_command_config) infer_proc = InferenceCommandGenerator(**inference_command_config) - exec_cmd = self.get_export_variables_cmd(train_command_config["training_config_local"], train_command_config["nemo_directory"]) + exec_cmd = self.get_export_variables_cmd( + train_command_config["training_config_local"], train_command_config["nemo_directory"] + ) exec_cmd += train_proc.process() exec_cmd += " && sleep 10" if avg_cmd: exec_cmd += " && " + avg_cmd - + exec_cmd += " " + infer_proc.process(first_run=first_run) for _ in range(num_ipl_epochs): @@ -290,7 +298,7 @@ def get_pseudo_labeling_command( return exec_cmd - def get_export_variables_cmd(self, merged_cfg , nemo_root): + def get_export_variables_cmd(self, merged_cfg, nemo_root): """Generate command to export required environment variables.""" wandb_key = os.environ.get("WANDB_API_KEY") or os.environ.get("WANDB") or os.environ.get("WANDB_KEY", "") if not wandb_key: @@ -306,29 +314,30 @@ def get_export_variables_cmd(self, merged_cfg , nemo_root): "nvidia-smi && " f"export PYTHONPATH={nemo_root} && " f"export HF_TOKEN={os.getenv('HF_TOKEN', '')} && " - f"export WANDB_API_KEY={wandb_key} && ") - + f"export WANDB_API_KEY={wandb_key} && " + ) + return cmd - - def average_checkpoints(self, checkpoint_path: str, nemo_root:str) -> str: + + def average_checkpoints(self, checkpoint_path: str, nemo_root: str) -> str: """ Generates the command to average all checkpoints in the given directory and returns the path to the averaged checkpoint. 
- + Args: checkpoint_path (str): Path to the directory containing checkpoints - + Returns: tuple: (command to run, path to the averaged checkpoint file) """ # Get the directory containing the checkpoints checkpoint_dir = os.path.dirname(checkpoint_path) - + # Construct the command for checkpoint averaging cmd = f"python {nemo_root}/scripts/checkpoint_averaging/legacy/checkpoint_averaging.py {checkpoint_dir}" - + # The averaged checkpoint will have the same name but with '-averaged' suffix checkpoint_name = os.path.basename(checkpoint_path) base_name = os.path.splitext(checkpoint_name)[0] averaged_checkpoint = os.path.join(checkpoint_dir, f"{base_name}-averaged.nemo") - + return cmd, averaged_checkpoint diff --git a/sdp/processors/langs/arabic.py b/sdp/processors/langs/arabic.py index 2ebe444b..a5d03bab 100644 --- a/sdp/processors/langs/arabic.py +++ b/sdp/processors/langs/arabic.py @@ -56,23 +56,24 @@ LAM_ALEF_HAMZA_ABOVE = u'\uFEF7' LAM_ALEF_HAMZA_BELOW = u'\uFEF9' LAM_ALEF_MADDA_ABOVE = u'\uFEF5' -LIGATURES=(LAM_ALEF, LAM_ALEF_HAMZA_ABOVE, LAM_ALEF_HAMZA_BELOW, LAM_ALEF_MADDA_ABOVE) +LIGATURES = (LAM_ALEF, LAM_ALEF_HAMZA_ABOVE, LAM_ALEF_HAMZA_BELOW, LAM_ALEF_MADDA_ABOVE) # Punctuation marks QUESTION_MARK = "\u061F" SAMICOLON = "\u061B" COMMA = "\u060C" -DIACRITICS = [chr(x) for x in range(0x0600, 0x06ff) if unicodedata.category(chr(x)) == "Mn"] -PUNCTUATION_MARKS = ["?", "!", ":", ";", "-", ".", ",", "؟","،", "؛"] +DIACRITICS = [chr(x) for x in range(0x0600, 0x06FF) if unicodedata.category(chr(x)) == "Mn"] +PUNCTUATION_MARKS = ["?", "!", ":", ";", "-", ".", ",", "؟", "،", "؛"] ALEFS = (ALEF, ALEF_MADDA, ALEF_HAMZA_ABOVE, ALEF_HAMZA_BELOW) + class ArabicTextPreprocessor(BaseParallelProcessor): """Class for Arabic text preprocessing. Operates on the text in the ``input_text_key``, and saves output text in the ``output_text_key``. - + Args: input_text_key (str): the text field that will be the input to the processor. output_text_key (str): the text field that will contain processed text. @@ -92,6 +93,7 @@ class ArabicTextPreprocessor(BaseParallelProcessor): normalization of ligatures: `LAM_ALEF`, `LAM_ALEF_HAMZA_ABOVE`, `LAM_ALEF_HAMZA_BELOW`, `LAM_ALEF_MADDA_ABOVE` ligatures will be replaces by two letters `LAM` and `ALEF`. letter `TEH_MARBUTA` will be replaced by `HEH`. Defaults to False. 
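A standalone sketch of the normalization steps this docstring describes (diacritics removal, alef unification, lam-alef ligature splitting); the constants mirror the module, but the helper itself is illustrative:

    import re
    import unicodedata

    ALEF = "\u0627"
    LAM = "\u0644"
    DIACRITICS = [chr(c) for c in range(0x0600, 0x06FF) if unicodedata.category(chr(c)) == "Mn"]
    ALEF_VARIANTS = "\u0622\u0623\u0625"          # madda, hamza above, hamza below
    LAM_ALEF_LIGATURES = "\uFEFB\uFEF7\uFEF9\uFEF5"

    def normalize(text: str) -> str:
        text = re.sub("[" + "".join(DIACRITICS) + "]", "", text)        # strip diacritics
        text = re.sub("[" + ALEF_VARIANTS + "]", ALEF, text)            # unify alef forms
        text = re.sub("[" + LAM_ALEF_LIGATURES + "]", LAM + ALEF, text) # split ligatures
        return text

    print(normalize("أَهْلاً"))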
""" + def __init__( self, input_text_key: str = "text", @@ -120,9 +122,7 @@ def __init__( self.apply_nfkc = apply_nfkc def process_dataset_entry(self, data_entry): - data_entry[self.output_text_key] = self.clean_data( - data_entry[self.input_text_key] - ) + data_entry[self.output_text_key] = self.clean_data(data_entry[self.input_text_key]) return [DataEntry(data=data_entry)] def _remove_diacritics(self, text): @@ -138,11 +138,11 @@ def _remove_punctuation(self, text): def _normalize_teh(self, text): text = text.replace(TEH_MARBUTA, HEH) return text - + def _normalize_ligature(self, text): LIGUATURES_PATTERN = re.compile(u"[" + u"".join(LIGATURES) + u"]", re.UNICODE) return LIGUATURES_PATTERN.sub(u'%s%s' % (LAM, ALEF), text) - + def _normalize_alef(self, text): ALEFS_PATTERN = re.compile(u"[" + u"".join(ALEFS) + u"]", re.UNICODE) return re.sub(ALEFS_PATTERN, ALEF, text) @@ -180,4 +180,4 @@ def clean_data(self, text): text = self._normalize(text) if self.apply_nfkc: text = unicodedata.normalize("NFKC", text) - return text \ No newline at end of file + return text diff --git a/sdp/processors/manage_files/convert_audio.py b/sdp/processors/manage_files/convert_audio.py index 3fbd3870..692e4de9 100644 --- a/sdp/processors/manage_files/convert_audio.py +++ b/sdp/processors/manage_files/convert_audio.py @@ -14,11 +14,11 @@ import os from typing import Optional + from sox import Transformer from sdp.logging import logger from sdp.processors.base_processor import BaseParallelProcessor, DataEntry - from sdp.utils.common import ffmpeg_convert @@ -122,7 +122,7 @@ def __init__( # Extract workspace_dir from kwargs to avoid passing it to BaseProcessor if "workspace_dir" in kwargs: workspace_dir = kwargs.pop("workspace_dir") - + super().__init__(**kwargs) self.input_audio_file_key = input_audio_file_key self.output_audio_file_key = output_audio_file_key @@ -141,13 +141,13 @@ def prepare(self): def process_dataset_entry(self, data_entry): audio_path = data_entry[self.input_audio_file_key] - + # If workspace_dir is provided, join it with audio_path to get absolute path if self.workspace_dir is not None: full_audio_path = os.path.join(self.workspace_dir, audio_path) else: full_audio_path = audio_path - + # Debug print first file path if not hasattr(self, '_debug_printed'): logger.info(f"First audio_path from manifest: {audio_path}") @@ -167,4 +167,4 @@ def process_dataset_entry(self, data_entry): transformer.build(full_audio_path, converted_file) data_entry[self.output_audio_file_key] = converted_file - return [DataEntry(data=data_entry)] \ No newline at end of file + return [DataEntry(data=data_entry)] diff --git a/sdp/processors/manage_files/extract.py b/sdp/processors/manage_files/extract.py index 4c6126f6..8d975f83 100644 --- a/sdp/processors/manage_files/extract.py +++ b/sdp/processors/manage_files/extract.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import tarfile import os +import tarfile from pathlib import Path from sdp.logging import logger -from sdp.processors.base_processor import DataEntry, BaseParallelProcessor +from sdp.processors.base_processor import BaseParallelProcessor, DataEntry class ExtractTar(BaseParallelProcessor): @@ -42,15 +42,15 @@ class ExtractTar(BaseParallelProcessor): """ def __init__( - self, - field_to_tar_filepath: str, - extraction_dir: str, - remove_source_tar: bool = False, + self, + field_to_tar_filepath: str, + extraction_dir: str, + remove_source_tar: bool = False, skip_invalid_filepaths: bool = False, filepath_prefix_field: str = None, output_filepath_field: str = 'extracted', - get_extracted_filepaths: bool = False, - **kwargs + get_extracted_filepaths: bool = False, + **kwargs, ): super().__init__(**kwargs) self.field_to_tar_filepath = field_to_tar_filepath @@ -80,9 +80,7 @@ def process_dataset_entry(self, data_entry): else '' ) output_filepath = os.path.join( - self.extraction_dir, - output_filepath_prefix, - os.path.basename(tar_filepath).split('.')[0] + self.extraction_dir, output_filepath_prefix, os.path.basename(tar_filepath).split('.')[0] ) os.makedirs(output_filepath, exist_ok=True) @@ -101,9 +99,7 @@ def process_dataset_entry(self, data_entry): extracted_filepaths = [] if output_filepath is not None and self.get_extracted_filepaths: extraction_folder_path = Path(output_filepath) - extracted_filepaths = [ - str(file) for file in extraction_folder_path.rglob("*") if file.is_file() - ] + extracted_filepaths = [str(file) for file in extraction_folder_path.rglob("*") if file.is_file()] # Optionally remove the original tar archive after extraction if self.remove_source_tar: @@ -115,4 +111,4 @@ def process_dataset_entry(self, data_entry): else: data_entry[self.output_filepath_field] = output_filepath - return [DataEntry(data=data_entry)] \ No newline at end of file + return [DataEntry(data=data_entry)] diff --git a/sdp/processors/manage_files/remove.py b/sdp/processors/manage_files/remove.py index c982ef80..2ab62135 100644 --- a/sdp/processors/manage_files/remove.py +++ b/sdp/processors/manage_files/remove.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import os import shutil -from pathlib import Path from collections import Counter +from pathlib import Path + from tqdm import tqdm -import itertools from tqdm.contrib.concurrent import process_map - from sdp.logging import logger -from sdp.processors.base_processor import DataEntry, BaseParallelProcessor +from sdp.processors.base_processor import BaseParallelProcessor, DataEntry class RemoveFiles(BaseParallelProcessor): @@ -34,42 +34,37 @@ class RemoveFiles(BaseParallelProcessor): Args: filepath_field (str): The key in the data entry that holds the path to the file or directory to remove. - + drop_filepath_field (bool): Whether to remove the filepath field from the resulting data entry. Defaults to True. recursive (bool): Whether to recursively remove files from directories. Defaults to False. - + **kwargs: Additional arguments passed to the BaseParallelProcessor. - + Returns: A manifest where each entry is the same as the input, optionally without the filepath field, and with the file or directory at the specified path removed from disk. 
- + Example entry before processing:: - + { "id": "abc123", "path_to_remove": "/tmp/some_file.wav" } - + Example entry after processing (if `drop_filepath_field=True`):: - + { "id": "abc123" } """ - def __init__(self, - filepath_field: str, - drop_filepath_field: bool = True, - recursive: bool = False, - **kwargs): - + def __init__(self, filepath_field: str, drop_filepath_field: bool = True, recursive: bool = False, **kwargs): super().__init__(**kwargs) self.filepath_field = filepath_field self.drop_filepath_field = drop_filepath_field self.recursive = recursive - + def _count_files(self, data_entry): """ Count the number of files to be removed. @@ -81,7 +76,7 @@ def _count_files(self, data_entry): else: raise IsADirectoryError(f"Directory {filepath} is not empty and recursive is False") else: - file_counter = Counter({filepath.suffix : 1}) + file_counter = Counter({filepath.suffix: 1}) return file_counter def prepare(self): @@ -132,4 +127,4 @@ def process_dataset_entry(self, data_entry): data_entry.pop(self.filepath_field) # Wrap and return the modified entry - return [DataEntry(data=data_entry)] \ No newline at end of file + return [DataEntry(data=data_entry)] diff --git a/sdp/processors/modify_manifest/common.py b/sdp/processors/modify_manifest/common.py index c94c72bb..1225ed17 100644 --- a/sdp/processors/modify_manifest/common.py +++ b/sdp/processors/modify_manifest/common.py @@ -94,7 +94,6 @@ def process(self): subprocess.run(" ".join(process_args), shell=True) - class CombineSources(BaseParallelProcessor): """Can be used to create a single field from two alternative sources. @@ -474,7 +473,7 @@ class DropSpecifiedFields(BaseProcessor): """ A processor that removes specified fields from each data entry in the manifest. - This processor reads an input manifest line by line, drops the fields listed in `fields_to_drop` + This processor reads an input manifest line by line, drops the fields listed in `fields_to_drop` from each JSON entry, and writes the cleaned entries to the output manifest. Args: @@ -502,4 +501,4 @@ def process(self): # Create a new entry by excluding the specified fields new_line = {field: entry[field] for field in entry if field not in self.fields_to_drop} # Write the cleaned entry to the output manifest - fout.write(json.dumps(new_line, ensure_ascii=False) + "\n") \ No newline at end of file + fout.write(json.dumps(new_line, ensure_ascii=False) + "\n") diff --git a/sdp/processors/modify_manifest/data_to_data.py b/sdp/processors/modify_manifest/data_to_data.py index 35b4d5b0..5a3d255a 100644 --- a/sdp/processors/modify_manifest/data_to_data.py +++ b/sdp/processors/modify_manifest/data_to_data.py @@ -13,6 +13,7 @@ # limitations under the License. 
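The DropSpecifiedFields hunk above boils down to a line-by-line JSONL rewrite; a minimal sketch, with the file names and field list as assumptions:

    import json

    fields_to_drop = ["pred_text", "wer"]  # hypothetical fields
    with open("input_manifest.json", encoding="utf-8") as fin, \
            open("output_manifest.json", "w", encoding="utf-8") as fout:
        for line in fin:
            entry = json.loads(line)
            # Keep everything except the dropped fields.
            kept = {k: v for k, v in entry.items() if k not in fields_to_drop}
            fout.write(json.dumps(kept, ensure_ascii=False) + "\n")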
import collections +import json import os import re from typing import Dict, List, Optional @@ -21,7 +22,6 @@ import torchaudio from docx import Document from tqdm import tqdm -import json from sdp.logging import logger from sdp.processors.base_processor import ( @@ -29,11 +29,11 @@ BaseProcessor, DataEntry, ) +from sdp.utils.apply_operators import evaluate_expression from sdp.utils.common import ffmpeg_convert from sdp.utils.edit_spaces import add_start_end_spaces, remove_extra_spaces from sdp.utils.get_diff import get_diff_with_subs_grouped from sdp.utils.metrics_computation import get_wer -from sdp.utils.apply_operators import evaluate_expression class GetAudioDuration(BaseParallelProcessor): @@ -460,10 +460,10 @@ def __init__( super().__init__(**kwargs) if not regex_params_list and not regex_params_yaml: raise ValueError(f'One of `regex_params_list` or `regex_params_yaml` should be provided.') - + self.regex_params_list = regex_params_list if regex_params_yaml: - with open(regex_params_yaml, 'r') as regex_params_file: + with open(regex_params_yaml, 'r') as regex_params_file: self.regex_params_list = yaml.safe_load(regex_params_file) self.text_key = text_key @@ -555,6 +555,7 @@ def __init__( def prepare(self): from nemo_text_processing.text_normalization.normalize import Normalizer + try: self.normalizer = Normalizer(input_case=self.input_case, lang=self.input_language) except NotImplementedError as e: @@ -603,7 +604,10 @@ def __init__( self.verbose = verbose def prepare(self): - from nemo_text_processing.inverse_text_normalization.inverse_normalize import InverseNormalizer + from nemo_text_processing.inverse_text_normalization.inverse_normalize import ( + InverseNormalizer, + ) + try: self.inverse_normalizer = InverseNormalizer(input_case=self.input_case, lang=self.input_language) except NotImplementedError as e: @@ -624,7 +628,7 @@ class CopyManifestData(BaseParallelProcessor): Args: copy_path (str): The destination directory where files will be copied. - source_filepath (str): The key in the manifest that contains the path to + source_filepath (str): The key in the manifest that contains the path to the file to be copied. Default: "audio_path". Returns: @@ -640,6 +644,7 @@ class CopyManifestData(BaseParallelProcessor): copy_path: ${workspace_dir}/consolidated_data source_filepath: "audio_filepath" """ + def __init__( self, copy_path: str, @@ -808,15 +813,16 @@ class GetWER(BaseParallelProcessor): """This processor calculates Word Error Rate (WER) between predicted text and ground truth text. It computes the WER for each entry in the manifest and adds the result as a new field. - + Args: text_key (str): Key for the ground truth text field in the manifest. Default: "text". pred_text_key (str): Key for the predicted text field in the manifest. Default: "pred_text". - + Returns: - The same data as in the input manifest with an additional 'wer' field containing + The same data as in the input manifest with an additional 'wer' field containing the calculated Word Error Rate between the specified text fields. """ + def __init__( self, text_key: str = "text", @@ -860,6 +866,7 @@ class MakeSentence(BaseParallelProcessor): end_symbol: "." make_uppercase: true """ + def __init__( self, text_key: str = "text", @@ -899,7 +906,14 @@ class ASRFileCheck(BaseProcessor): A manifest with corrupted audio files removed. 
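A hedged sketch of the corruption check ASRFileCheck performs: attempt to decode each file with torchaudio and quarantine failures. The directory name and simplified error handling are illustrative:

    import os
    import shutil
    import torchaudio

    def check_and_quarantine(paths, corrupted_dir="corrupted_audio"):
        os.makedirs(corrupted_dir, exist_ok=True)
        good = []
        for path in paths:
            try:
                torchaudio.load(path)  # raises on missing or undecodable files
                good.append(path)
            except (FileNotFoundError, RuntimeError) as err:
                print(f"skipping {path}: {err}")
                if os.path.exists(path):
                    shutil.move(path, os.path.join(corrupted_dir, os.path.basename(path)))
        return good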
""" - def __init__(self, audio_filepath_key: str = "audio_filepath", corrupted_audio_dir: str = None, workspace_dir: str = None, **kwargs): + + def __init__( + self, + audio_filepath_key: str = "audio_filepath", + corrupted_audio_dir: str = None, + workspace_dir: str = None, + **kwargs, + ): """ Constructs the necessary attributes for the ASRFileCheck class. @@ -915,10 +929,12 @@ def __init__(self, audio_filepath_key: str = "audio_filepath", corrupted_audio_d """ super().__init__(**kwargs) self.audio_filepath_key = audio_filepath_key - + if corrupted_audio_dir is None: - raise ValueError("corrupted_audio_dir parameter is required. Please specify a directory to move corrupted files.") - + raise ValueError( + "corrupted_audio_dir parameter is required. Please specify a directory to move corrupted files." + ) + self.corrupted_audio_dir = corrupted_audio_dir self.workspace_dir = workspace_dir self.failed_files = [] @@ -929,17 +945,17 @@ def process(self): This method reads through the manifest file, attempts to load each audio file using torchaudio, and moves corrupted files. A new manifest file is created with only the valid entries. - + Specific errors handled: - FileNotFoundError: File doesn't exist - RuntimeError: File format issues or codec problems - Other exceptions: General issues with file loading """ from sdp.logging import logger - + # Debug print to show workspace_dir logger.info(f"ASRFileCheck workspace_dir: {self.workspace_dir}") - + with open(self.input_manifest_file, 'r') as f: lines = f.readlines() @@ -953,22 +969,22 @@ def process(self): line = lines[idx] entry = json.loads(line) audio_path = entry[self.audio_filepath_key] - + # Debug print first file path if idx == 0: logger.info(f"First audio_path from manifest: {audio_path}") - + # If workspace_dir is provided, join it with audio_path to get absolute path if self.workspace_dir is not None: full_audio_path = os.path.join(self.workspace_dir, audio_path) else: full_audio_path = audio_path - + # Debug print first full path if idx == 0: logger.info(f"First full_audio_path: {full_audio_path}") logger.info(f"Path exists: {os.path.exists(full_audio_path)}") - + try: # Attempt to load the audio file to check if it is corrupted torchaudio.load(full_audio_path) @@ -979,7 +995,7 @@ def process(self): except RuntimeError as e: logger.warning(f"Audio format error in {audio_path}: {e}") self.failed_files.append(audio_path) - + # Move the corrupted audio file if os.path.exists(full_audio_path): dest_path = os.path.join(self.corrupted_audio_dir, os.path.basename(audio_path)) @@ -988,7 +1004,7 @@ def process(self): except Exception as e: logger.warning(f"Unknown error loading {audio_path}: {e}") self.failed_files.append(audio_path) - + # Move the corrupted audio file if os.path.exists(full_audio_path): dest_path = os.path.join(self.corrupted_audio_dir, os.path.basename(audio_path)) @@ -1023,23 +1039,23 @@ class ListToEntries(BaseParallelProcessor): Raises: TypeError: If the specified list field is not of type list. ValueError: If the list items are not dictionaries and `output_field` is not provided. - + Returns: - A manifest where each entry corresponds to one item in the original list from the input entry. - This effectively transforms a single input entry containing a list of items into multiple standalone + A manifest where each entry corresponds to one item in the original list from the input entry. 
+ This effectively transforms a single input entry containing a list of items into multiple standalone entries, each suitable for further dataset processing. .. admonition:: Example 1 (list of dicts) - + .. code-block:: yaml - + - _target_: sdp.processors.ListToEntries input_manifest_file: ${workspace_dir}/input_manifest.json output_manifest_file: ${workspace_dir}/output_manifest.json field_with_list: "segments" - + Input:: - + { "audio_filepath": "sample.wav", "segments": [ @@ -1064,19 +1080,19 @@ class ListToEntries(BaseParallelProcessor): "text": "World" } ] - + .. admonition:: Example 2 (list of primitives) - + .. code-block:: yaml - + - _target_: sdp.processors.ListToEntries input_manifest_file: ${workspace_dir}/input_manifest.json output_manifest_file: ${workspace_dir}/output_manifest.json field_with_list: "text_chunks" output_field: "text" - + Input:: - + { "audio_filepath": "sample.wav", "text_chunks": [ @@ -1100,10 +1116,7 @@ class ListToEntries(BaseParallelProcessor): """ - def __init__(self, - field_with_list: str, - output_field: str = None, - **kwargs): + def __init__(self, field_with_list: str, output_field: str = None, **kwargs): super().__init__(**kwargs) self.field_with_list = field_with_list self.output_field = output_field @@ -1114,14 +1127,16 @@ def process_dataset_entry(self, data_entry): # Check that the target field is actually a list if not isinstance(data_entry[self.field_with_list], list): raise TypeError(f'Values of {self.field_with_list} field should be list type only: {data_entry}') - + # Remove the list field from the entry and get the list of items items_list = data_entry.pop(self.field_with_list) # If items are not dicts, output_field must be specified to store the item if not isinstance(items_list[0], dict) and not self.output_field: - raise ValueError(f'Type of items in items list `{self.field_with_list}` is not dict ({type(items_list[0])}). In this case `output_field` should be provided.') - + raise ValueError( + f'Type of items in items list `{self.field_with_list}` is not dict ({type(items_list[0])}). In this case `output_field` should be provided.' + ) + # Expand the list into multiple entries for item in items_list: _entry = data_entry.copy() @@ -1129,7 +1144,7 @@ def process_dataset_entry(self, data_entry): # If item is a dict, merge its keys; otherwise, store it in `output_field` if isinstance(item, dict): _entry.update(item) - else: + else: _entry[self.output_field] = item _entry = DataEntry(_entry) @@ -1202,6 +1217,7 @@ class LambdaExpression(BaseParallelProcessor): str: A line-delimited JSON manifest, where each line is a processed entry. The result may contain fewer entries than the input if ``filter=True``. """ + def __init__( self, new_field: str, @@ -1223,12 +1239,12 @@ def process_dataset_entry(self, data_entry) -> List[DataEntry]: If `filter` is True, the entry is only retained if the expression evaluates to True. Otherwise, the result is stored in `new_field`. 
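(Reviewer note: a toy model of the filter semantics described above. The real processor delegates to evaluate_expression; eval over a restricted namespace stands in for it here, and the entry contents are made up.)

def apply_expression(entry: dict, expression: str, new_field: str, filter: bool):
    # stand-in for evaluate_expression: entry fields are visible as `entry[...]`
    value = eval(expression, {"__builtins__": {}}, {"entry": entry})
    if filter:
        # filter mode: keep the entry only if the expression is exactly True
        return [entry] if value is True else []
    entry[new_field] = value
    return [entry]

# e.g. drop short utterances:
apply_expression({"duration": 0.4}, "entry['duration'] >= 0.5", "keep", filter=True)  # -> []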
""" - value = evaluate_expression(self.expression, data_entry, self.lambda_param_name) + value = evaluate_expression(self.expression, data_entry, self.lambda_param_name) if self.filter: if value is not True: return [] - data_entry[self.new_field] = value + data_entry[self.new_field] = value return [DataEntry(data=data_entry)] def finalize(self, metrics): - super().finalize(metrics) \ No newline at end of file + super().finalize(metrics) diff --git a/sdp/processors/modify_manifest/data_to_dropbool.py b/sdp/processors/modify_manifest/data_to_dropbool.py index eeeebd1e..67b93fbe 100644 --- a/sdp/processors/modify_manifest/data_to_dropbool.py +++ b/sdp/processors/modify_manifest/data_to_dropbool.py @@ -75,7 +75,7 @@ def __init__( 'Operator must be one from the list: "lt" (less than), "le" (less than or equal to), "eq" (equal to), "ne" (not equal to), "ge" (greater than or equal to), "gt" (greater than)' ) - def process_dataset_entry(self, data_entry): + def process_dataset_entry(self, data_entry): input_value = data_entry[self.input_value_key] target = self.target_value if self.operator(input_value, target): diff --git a/sdp/processors/modify_manifest/make_letters_uppercase_after_period.py b/sdp/processors/modify_manifest/make_letters_uppercase_after_period.py index 10ccd57e..25da47c6 100644 --- a/sdp/processors/modify_manifest/make_letters_uppercase_after_period.py +++ b/sdp/processors/modify_manifest/make_letters_uppercase_after_period.py @@ -35,7 +35,10 @@ class MakeLettersUppercaseAfterPeriod(BaseParallelProcessor): """ def __init__( - self, punctuation=".!?", text_key: str = "text", **kwargs, + self, + punctuation=".!?", + text_key: str = "text", + **kwargs, ): super().__init__(**kwargs) self.punctuation = punctuation diff --git a/sdp/processors/nemo/asr_inference.py b/sdp/processors/nemo/asr_inference.py index 4359f320..ae8356ae 100644 --- a/sdp/processors/nemo/asr_inference.py +++ b/sdp/processors/nemo/asr_inference.py @@ -44,7 +44,7 @@ class ASRInference(BaseProcessor): def __init__( self, - pretrained_model: Optional[str]=None, + pretrained_model: Optional[str] = None, batch_size: int = 32, **kwargs, ): @@ -75,4 +75,4 @@ def process(self): f"batch_size={self.batch_size} ", shell=True, check=True, - ) \ No newline at end of file + ) diff --git a/sdp/processors/nemo/estimate_bandwidth.py b/sdp/processors/nemo/estimate_bandwidth.py index 38b261e7..0909dbd1 100644 --- a/sdp/processors/nemo/estimate_bandwidth.py +++ b/sdp/processors/nemo/estimate_bandwidth.py @@ -1,6 +1,7 @@ +from pathlib import Path + import librosa import numpy as np -from pathlib import Path from sdp.processors.base_processor import BaseParallelProcessor, DataEntry diff --git a/sdp/processors/nemo/speech_to_text_with_vad.py b/sdp/processors/nemo/speech_to_text_with_vad.py index 6fdd183d..54158306 100644 --- a/sdp/processors/nemo/speech_to_text_with_vad.py +++ b/sdp/processors/nemo/speech_to_text_with_vad.py @@ -57,25 +57,23 @@ import contextlib import json import os - import time -from dataclasses import dataclass, is_dataclass, field +from dataclasses import dataclass, field, is_dataclass from pathlib import Path from typing import Callable, Optional import torch import torch.amp import yaml -from omegaconf import DictConfig, OmegaConf -from torch.profiler import ProfilerActivity, profile, record_function -from tqdm import tqdm - from nemo.collections.asr.data import feature_to_text_dataset from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.models import ASRModel, 
EncDecClassificationModel from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig -from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest +from nemo.collections.asr.parts.utils.manifest_utils import ( + read_manifest, + write_manifest, +) from nemo.collections.asr.parts.utils.vad_utils import ( generate_overlap_vad_seq, generate_vad_segment_table, @@ -85,6 +83,9 @@ ) from nemo.core.config import hydra_runner from nemo.utils import logging +from omegaconf import DictConfig, OmegaConf +from torch.profiler import ProfilerActivity, profile, record_function +from tqdm import tqdm @dataclass @@ -99,9 +100,9 @@ class InferenceConfig: use_rttm: bool = True # whether to use RTTM rttm_mode: str = "mask" # how to use RTTM files, choices=[`mask`, `drop`] feat_mask_val: Optional[float] = None # value used to mask features based on RTTM, set None to use defaults - normalize: Optional[str] = ( - "post_norm" # whether and where to normalize audio feature, choices=[None, `pre_norm`, `post_norm`] - ) + normalize: Optional[ + str + ] = "post_norm" # whether and where to normalize audio feature, choices=[None, `pre_norm`, `post_norm`] normalize_type: str = "per_feature" # how to determine mean and std used for normalization normalize_audio_db: Optional[float] = None # set to normalize RMS DB of audio before extracting audio features @@ -136,7 +137,6 @@ class InferenceConfig: @hydra_runner(config_name="InferenceConfig", schema=InferenceConfig) def main(cfg): - if is_dataclass(cfg): cfg = OmegaConf.structured(cfg) @@ -171,7 +171,6 @@ def record_fn(*args, **kwargs): with profile_fn( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True ) as prof: - input_manifest_file = extract_audio_features(input_manifest_file, cfg, record_fn) if cfg.vad_model is not None: @@ -190,7 +189,6 @@ def record_fn(*args, **kwargs): def prepare_inference_manifest(cfg: DictConfig) -> str: - if cfg.audio_dir is not None and cfg.manifest_filepath is None: manifest_data = [] for audio_file in Path(cfg.audio_dir).glob(f"**/*.{cfg.audio_type}"): diff --git a/sdp/processors/nemo/transcribe_speech.py b/sdp/processors/nemo/transcribe_speech.py index bb04047b..7504f80b 100644 --- a/sdp/processors/nemo/transcribe_speech.py +++ b/sdp/processors/nemo/transcribe_speech.py @@ -22,12 +22,17 @@ import pytorch_lightning as pl import torch -from omegaconf import OmegaConf, open_dict - -from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecMultiTaskModel +from nemo.collections.asr.models import ( + EncDecCTCModel, + EncDecHybridRNNTCTCModel, + EncDecMultiTaskModel, +) from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig -from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig +from nemo.collections.asr.parts.submodules.multitask_decoding import ( + MultiTaskDecoding, + MultiTaskDecodingConfig, +) from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis @@ -40,6 +45,7 @@ ) from nemo.core.config import hydra_runner from nemo.utils import logging +from omegaconf import OmegaConf, open_dict """ Transcribe audio file on a 
single CPU/GPU. Useful for transcription of moderate amounts of audio data. @@ -106,7 +112,6 @@ @dataclass class ModelChangeConfig: - # Sub-config for changes specific to the Conformer Encoder conformer: ConformerChangeConfig = ConformerChangeConfig() @@ -414,4 +419,4 @@ def autocast(dtype=None): if __name__ == '__main__': - main() # noqa pylint: disable=no-value-for-parameter \ No newline at end of file + main() # noqa pylint: disable=no-value-for-parameter diff --git a/sdp/processors/toloka/accept_if.py b/sdp/processors/toloka/accept_if.py index 8472f601..bd4bb1c9 100644 --- a/sdp/processors/toloka/accept_if.py +++ b/sdp/processors/toloka/accept_if.py @@ -23,11 +23,12 @@ try: import toloka.client import toloka.client.project.template_builder + TOLOKA_AVAILABLE = True except ImportError: TOLOKA_AVAILABLE = False toloka = None - + from tqdm import tqdm @@ -49,7 +50,7 @@ class AcceptIfWERLess(BaseParallelProcessor): Returns: A manifest with accepted assignments from Toloka based on the WER threshold. - + Example: .. code-block:: yaml @@ -60,7 +61,7 @@ class AcceptIfWERLess(BaseParallelProcessor): input_pool_file: ${workspace_dir}/taskpool.json threshold: 50 """ - + def __init__( self, input_data_file: str, @@ -156,4 +157,3 @@ def process(self): accepted += 1 logger.info(f"Number of accepted task suits: {accepted} of {len(big_dict)}") - diff --git a/sdp/processors/toloka/create_pool.py b/sdp/processors/toloka/create_pool.py index 9948cef4..0f7442b3 100644 --- a/sdp/processors/toloka/create_pool.py +++ b/sdp/processors/toloka/create_pool.py @@ -22,11 +22,11 @@ try: import toloka.client import toloka.client.project.template_builder + TOLOKA_AVAILABLE = True except ImportError: TOLOKA_AVAILABLE = False toloka = None - class CreateTolokaPool(BaseParallelProcessor): @@ -42,6 +42,7 @@ class CreateTolokaPool(BaseParallelProcessor): Returns: A newly created pool on the Toloka platform, configured and ready for task assignment. """ + def __init__( self, lang: str = 'HY', @@ -59,11 +60,11 @@ def __init__( self.API_KEY = os.getenv('TOLOKA_API_KEY') if not self.API_KEY: raise ValueError("TOLOKA_API_KEY environment variable is not set") - + self.platform = os.getenv('TOLOKA_PLATFORM') if not self.platform: raise ValueError("TOLOKA_PLATFORM environment variable is not set") - + # Project ID will be read from the input manifest file in process_dataset_entry self.project_id = None self.lang = lang @@ -86,7 +87,7 @@ def process_dataset_entry(self, data_entry): list A list containing a DataEntry object with the new pool ID if successful, or an empty list if failed. """ - + if self.toloka_available != True: logger.warning("Toloka is currently not supported. CreatePool processor functionality will be limited.") diff --git a/sdp/processors/toloka/create_project.py b/sdp/processors/toloka/create_project.py index bf8ece19..28a3ba58 100644 --- a/sdp/processors/toloka/create_project.py +++ b/sdp/processors/toloka/create_project.py @@ -21,11 +21,11 @@ try: import toloka.client import toloka.client.project.template_builder + TOLOKA_AVAILABLE = True except ImportError: TOLOKA_AVAILABLE = False toloka = None - class CreateTolokaProject(BaseParallelProcessor): @@ -43,7 +43,7 @@ class CreateTolokaProject(BaseParallelProcessor): Returns: A project created on the Toloka platform, configured and ready for task and pool setup. 
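(Reviewer note: the optional-dependency and credentials pattern repeated across these Toloka processors, shown once in isolation. It assumes only that toloka-kit may be absent and that the two environment variables are required, as the hunks below establish.)

import os

try:
    import toloka.client

    TOLOKA_AVAILABLE = True
except ImportError:
    TOLOKA_AVAILABLE = False
    toloka = None


def make_toloka_client():
    api_key = os.getenv("TOLOKA_API_KEY")
    platform = os.getenv("TOLOKA_PLATFORM")
    if not api_key:
        raise ValueError("TOLOKA_API_KEY environment variable is not set")
    if not platform:
        raise ValueError("TOLOKA_PLATFORM environment variable is not set")
    if not TOLOKA_AVAILABLE:
        raise ImportError("toloka-kit is not installed")
    return toloka.client.TolokaClient(api_key, platform)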
""" - + def __init__( self, project_name: str, @@ -55,11 +55,11 @@ def __init__( self.API_KEY = os.getenv('TOLOKA_API_KEY') if not self.API_KEY: raise ValueError("TOLOKA_API_KEY environment variable is not set") - + self.platform = os.getenv('TOLOKA_PLATFORM') if not self.platform: raise ValueError("TOLOKA_PLATFORM environment variable is not set") - + self.project_name = project_name self.project_description = project_description self.project_instructions = project_instructions @@ -77,7 +77,9 @@ def process(self): """ logger.info("Processing Toloka project creation...") if self.toloka_availabe != True: - logger.warning("Toloka is currently not supported. CreateTolokaProject processor functionality will be limited.") + logger.warning( + "Toloka is currently not supported. CreateTolokaProject processor functionality will be limited." + ) toloka_client = toloka.client.TolokaClient(self.API_KEY, self.platform) @@ -128,4 +130,3 @@ def process(self): fout.write(json.dumps(data) + "\n") logger.info("Project created successfully: Project ID - {}".format(created_project.id)) - diff --git a/sdp/processors/toloka/create_sentence_set.py b/sdp/processors/toloka/create_sentence_set.py index 8a86afb6..9735d097 100644 --- a/sdp/processors/toloka/create_sentence_set.py +++ b/sdp/processors/toloka/create_sentence_set.py @@ -32,6 +32,7 @@ class CreateSentenceSet(BaseParallelProcessor): Returns: A list of `DataEntry` objects, each containing a single extracted sentence. """ + def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/sdp/processors/toloka/create_task_set.py b/sdp/processors/toloka/create_task_set.py index 3957091f..638d5eab 100644 --- a/sdp/processors/toloka/create_task_set.py +++ b/sdp/processors/toloka/create_task_set.py @@ -19,18 +19,16 @@ from sdp.logging import logger from sdp.processors.base_processor import BaseParallelProcessor, DataEntry - try: import toloka.client import toloka.client.project.template_builder + TOLOKA_AVAILABLE = True except ImportError: TOLOKA_AVAILABLE = False toloka = None - - class CreateTolokaTaskSet(BaseParallelProcessor): """Creates a set of tasks in a Toloka pool based on user-provided configurations and input data. @@ -45,6 +43,7 @@ class CreateTolokaTaskSet(BaseParallelProcessor): Returns: A set of tasks created and uploaded to the specified Toloka pool. """ + def __init__( self, input_data_file: str, @@ -58,19 +57,16 @@ def __init__( self.limit = limit self.pool_id = None self.toloka_available = TOLOKA_AVAILABLE - - - # Get API key and platform from environment variables self.API_KEY = os.getenv('TOLOKA_API_KEY') if not self.API_KEY: raise ValueError("TOLOKA_API_KEY environment variable is not set") - + self.platform = os.getenv('TOLOKA_PLATFORM') if not self.platform: raise ValueError("TOLOKA_PLATFORM environment variable is not set") - + self.toloka_client = None def prepare(self): @@ -133,7 +129,7 @@ def process(self): Creates Toloka tasks based on manifest data and adds them to the specified pool. This method reads the input manifest, creates task objects for Toloka, and submits - them to the specified pool. It also writes the manifest data to an output file after + them to the specified pool. It also writes the manifest data to an output file after tasks have been created. 
Raises: diff --git a/sdp/processors/toloka/download_responses.py b/sdp/processors/toloka/download_responses.py index aa2563cf..a58e2e9a 100644 --- a/sdp/processors/toloka/download_responses.py +++ b/sdp/processors/toloka/download_responses.py @@ -20,12 +20,11 @@ try: import toloka.client + TOLOKA_AVAILABLE = True except ImportError: TOLOKA_AVAILABLE = False toloka = None - - class GetTolokaResults(BaseParallelProcessor): @@ -47,6 +46,7 @@ class GetTolokaResults(BaseParallelProcessor): Returns: A set of task results from Toloka, stored in the specified output directory. """ + def __init__( self, input_data_file: str, @@ -120,7 +120,9 @@ def prepare(self): This method loads necessary configurations and initializes the Toloka client to interact with Toloka's API. """ if self.toloka_available != True: - logger.warning("Toloka is currently not supported. DownloadResponses processor functionality will be limited.") + logger.warning( + "Toloka is currently not supported. DownloadResponses processor functionality will be limited." + ) if not self.API_KEY or not self.platform or not self.pool_id: try: @@ -246,4 +248,3 @@ def process_dataset_entry(self, data_entry): } return [DataEntry(data=task_info)] - diff --git a/sdp/processors/toloka/reject_if.py b/sdp/processors/toloka/reject_if.py index 182c3e86..5948da8f 100644 --- a/sdp/processors/toloka/reject_if.py +++ b/sdp/processors/toloka/reject_if.py @@ -21,11 +21,12 @@ try: import toloka.client import toloka.client.project.template_builder + TOLOKA_AVAILABLE = True except ImportError: TOLOKA_AVAILABLE = False toloka = None - + from docx import Document from tqdm import tqdm @@ -34,7 +35,7 @@ class RejectIfBanned(BaseParallelProcessor): """Rejects Toloka assignments if the user is banned. - This class connects to Toloka, checks the user’s ban status, and rejects any assignments + This class connects to Toloka, checks the user’s ban status, and rejects any assignments from users who are identified as banned. Args: @@ -48,6 +49,7 @@ class RejectIfBanned(BaseParallelProcessor): Returns: A list of rejected assignments for users who are banned. """ + def __init__( self, input_data_file: str, @@ -160,4 +162,3 @@ def process(self): print("REJECTION LIST -------------------------", reject_list) for assignment_id in tqdm(reject_list, desc="Rejecting assignments"): self.toloka_client.reject_assignment(assignment_id=assignment_id, public_comment='Bad quality of audio.') - diff --git a/sdp/processors/tts/__init__.py b/sdp/processors/tts/__init__.py index 863522ea..eac18d52 100644 --- a/sdp/processors/tts/__init__.py +++ b/sdp/processors/tts/__init__.py @@ -1,2 +1 @@ - -from sdp.processors.tts.pyannote import PyAnnoteDiarizationAndOverlapDetection \ No newline at end of file +from sdp.processors.tts.pyannote import PyAnnoteDiarizationAndOverlapDetection diff --git a/sdp/processors/tts/merge_alignment_diarization.py b/sdp/processors/tts/merge_alignment_diarization.py index c23adbda..155db273 100644 --- a/sdp/processors/tts/merge_alignment_diarization.py +++ b/sdp/processors/tts/merge_alignment_diarization.py @@ -15,6 +15,7 @@ from sdp.processors.base_processor import BaseProcessor from sdp.utils.common import load_manifest, save_manifest + class MergeAlignmentDiarization(BaseProcessor): """This processor merges alignment and diarization information from a manifest file. 
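(Reviewer note: the merge rule implemented in the hunks below, distilled. Each aligned word is attached to the diarization segment it overlaps most; the processor does this with a linear sweep that only compares the current and next segment, while this sketch uses a simpler O(words x segments) scan.)

def overlap(a, b):
    # length of the intersection of intervals a=(start, end), b=(start, end); 0 if disjoint
    return max(0.0, min(a[1], b[1]) - max(a[0], b[0]))


def merge(alignment, segments):
    # assign each aligned word to the segment it overlaps most
    for seg in segments:
        seg["words"] = []
    for w in alignment:
        best = max(segments, key=lambda s: overlap((w["start"], w["end"]), (s["start"], s["end"])))
        if overlap((w["start"], w["end"]), (best["start"], best["end"])) > 0:
            best["words"].append(w)
    return segments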
@@ -35,8 +36,8 @@ class MergeAlignmentDiarization(BaseProcessor): input_manifest_file: ${workspace_dir}/manifest.json output_manifest_file: ${workspace_dir}/manifest_merged.json """ - def __init__(self, - **kwargs): + + def __init__(self, **kwargs): super().__init__(**kwargs) def process(self): @@ -44,8 +45,8 @@ def process(self): # Manifest here needs to contain both paths to alignment files and 'segments' # from pyannote. We identify all the words that belong in each pyannote segment - # and join them together. - + # and join them together. + for metadata in manifest: alignment = metadata['alignment'] segments = metadata['segments'] @@ -74,7 +75,9 @@ def process(self): # Check overlap with the next segment, if it exists if i < len(segments) - 1: next_segment = segments[i + 1] - next_overlap = max(0, min(word_end, next_segment['end']) - max(word_start, next_segment['start'])) + next_overlap = max( + 0, min(word_end, next_segment['end']) - max(word_start, next_segment['start']) + ) else: next_overlap = 0 @@ -87,7 +90,7 @@ def process(self): else: # If no overlap with current or next segment, increment to avoid infinite loop last_word_idx += 1 - + # If we are at the last word, break if last_word_idx == len(alignment): break @@ -96,4 +99,3 @@ def process(self): segment['words'] = words_in_segment save_manifest(manifest, self.output_manifest_file) - diff --git a/sdp/processors/tts/metrics.py b/sdp/processors/tts/metrics.py index f7fe681f..d2177755 100644 --- a/sdp/processors/tts/metrics.py +++ b/sdp/processors/tts/metrics.py @@ -12,20 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import librosa import math + +import librosa import numpy as np +import torch +import torchaudio +import torchaudio.functional as F +from torchaudio.pipelines import SQUIM_OBJECTIVE from tqdm import tqdm from sdp.logging import logger from sdp.processors.base_processor import BaseProcessor from sdp.utils.common import load_manifest, save_manifest -import torch -import torchaudio -import torchaudio.functional as F -from torchaudio.pipelines import SQUIM_OBJECTIVE - class TorchSquimObjectiveQualityMetricsProcessor(BaseProcessor): """This processor calculates Squim quality metrics for audio files. @@ -35,13 +35,13 @@ class TorchSquimObjectiveQualityMetricsProcessor(BaseProcessor): PESQ (Perceptual Evaluation of Speech Quality) A measure of overall quality for speech (originally designed to detect codec distortions but highly correlated to all kinds of distortion. - + STOI (Short-Time Objective Intelligibility) - A measure of speech intelligibility, basically measures speech envelope integrity. + A measure of speech intelligibility, basically measures speech envelope integrity. A STOI value of 1.0 means 100% of the speech being evaluated is intelligible on average. SI-SDR (Scale-Invariant Signal-to-Distortion Ratio) - A measure of how strong the speech signal is vs. all the distortion present in the audio, in decibels. + A measure of how strong the speech signal is vs. all the distortion present in the audio, in decibels. 0 dB means the energies of speech and distortion are the same. A value between 15-20 dB is what is considered "clean enough" speech in general. 
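(Reviewer note: the three metrics described above come straight from torchaudio's SQUIM objective pipeline; a minimal sketch, assuming mono input that is resampled to 16 kHz as this processor also does.)

import torch
import torchaudio.functional as F
from torchaudio.pipelines import SQUIM_OBJECTIVE

model = SQUIM_OBJECTIVE.get_model()


def squim_metrics(waveform: torch.Tensor, sample_rate: int):
    # waveform: (1, num_samples). Returns STOI, PESQ, SI-SDR estimates.
    if sample_rate != 16000:
        waveform = F.resample(waveform, sample_rate, 16000)
    with torch.no_grad():
        stoi, pesq, si_sdr = model(waveform)
    return stoi.item(), pesq.item(), si_sdr.item()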
Args: @@ -58,13 +58,14 @@ class TorchSquimObjectiveQualityMetricsProcessor(BaseProcessor): input_manifest_file: ${workspace_dir}/manifest.json output_manifest_file: ${workspace_dir}/manifest_squim.json """ + def __init__(self, device: str = "cuda", **kwargs): super().__init__(**kwargs) if not torch.cuda.is_available(): - device="cpu" + device = "cpu" logger.warning("CUDA is not available, using CPU") - + if device == "cuda": self.model = SQUIM_OBJECTIVE.get_model().cuda() else: @@ -79,14 +80,14 @@ def process(self): info = torchaudio.info(metadata['resampled_audio_filepath']) sr = info.sample_rate - try: + try: audio, _ = librosa.load(path=metadata['resampled_audio_filepath'], sr=sr) except Exception as ex: logger.info(f"Failed to load {metadata['resampled_audio_filepath']}, exception={ex}") continue for segment in metadata["segments"]: - if ("text" in segment and segment["text"].strip() == "") or (segment["speaker"]=="no-speaker"): + if ("text" in segment and segment["text"].strip() == "") or (segment["speaker"] == "no-speaker"): continue start = segment["start"] end = segment["end"] @@ -95,9 +96,9 @@ def process(self): end = math.floor(end * sr) num_samples = end - start - y = audio[start: end] + y = audio[start:end] y = torch.from_numpy(y) - y = torch.unsqueeze(y, dim=0) # needed otherwise throws input size error + y = torch.unsqueeze(y, dim=0) # needed otherwise throws input size error if sr != 16000: y = F.resample(y, sr, 16000) @@ -122,10 +123,11 @@ def process(self): segment['metrics'] = metrics except Exception as e: torch.cuda.empty_cache() - logger.info('Failed to extract Squim metrics {} with frame_offset={} and num_frames={}'.format( - metadata['resampled_audio_filepath'], - start, - num_samples)) + logger.info( + 'Failed to extract Squim metrics {} with frame_offset={} and num_frames={}'.format( + metadata['resampled_audio_filepath'], start, num_samples + ) + ) continue results.append(metadata) @@ -156,13 +158,14 @@ class BandwidthEstimationProcessor(BaseProcessor): input_manifest_file: ${workspace_dir}/manifest.json output_manifest_file: ${workspace_dir}/manifest_with_bandwidth.json """ + def __init__( self, n_fft: int = 512, stride_seconds: float = 0.01, top_db: float = 100.0, frequency_threshold: float = -50.0, - **kwargs + **kwargs, ): super().__init__(**kwargs) self.n_fft = n_fft @@ -172,14 +175,14 @@ def __init__( def _estimate_bandwidth(self, audio, sample_rate): """Estimates the bandwidth of an audio signal. - + This method calculates the power spectrogram of the audio signal and determines the bandwidth based on a frequency threshold. 
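(Reviewer note: the estimate described above fits in a few lines — dB-scale the time-averaged power spectrogram and report the highest frequency bin clearing the threshold. Parameter values mirror the defaults in this diff; referencing the -50 dB threshold to the spectrogram peak is an assumption of this sketch, not something the hunk states.)

import librosa
import numpy as np


def estimate_bandwidth(audio, sample_rate, n_fft=512, stride_seconds=0.01,
                       top_db=100.0, frequency_threshold=-50.0):
    hop_length = int(sample_rate * stride_seconds)
    spec = np.abs(librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)) ** 2
    # mean power per frequency bin over time, in dB relative to the peak (assumed reference)
    power_db = librosa.power_to_db(spec.mean(axis=1), ref=np.max, top_db=top_db)
    freqs = librosa.fft_frequencies(sr=sample_rate, n_fft=n_fft)
    above = np.nonzero(power_db >= frequency_threshold)[0]
    return int(freqs[above[-1]]) if above.size else 0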
- + Args: audio (np.ndarray): The audio signal to estimate the bandwidth of sample_rate (int): The sample rate of the audio signal - + Returns: int: The estimated bandwidth of the audio signal """ @@ -208,19 +211,19 @@ def process(self): for metadata in tqdm(manifest): audio_filepath = metadata['audio_filepath'] - try: + try: audio, sample_rate = librosa.load(path=audio_filepath, sr=None) except Exception as ex: logger.info(f"Failed to load {audio_filepath}, exception={ex}") continue for segment in metadata['segments']: - if ("text" in segment and segment["text"].strip() == "") or (segment["speaker"]=="no-speaker"): + if ("text" in segment and segment["text"].strip() == "") or (segment["speaker"] == "no-speaker"): continue start = segment['start'] end = segment['end'] - audio_segment = audio[int(start*sample_rate): int(end*sample_rate)] + audio_segment = audio[int(start * sample_rate) : int(end * sample_rate)] bandwidth = self._estimate_bandwidth(audio=audio_segment, sample_rate=sample_rate) @@ -235,4 +238,3 @@ def process(self): results.append(metadata) save_manifest(results, self.output_manifest_file) - diff --git a/sdp/processors/tts/nemo_asr_align.py b/sdp/processors/tts/nemo_asr_align.py index 9a71c476..dcc322bd 100644 --- a/sdp/processors/tts/nemo_asr_align.py +++ b/sdp/processors/tts/nemo_asr_align.py @@ -12,14 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import nemo.collections.asr as nemo_asr import omegaconf import torch import torchaudio -import nemo.collections.asr as nemo_asr + from sdp.logging import logger from sdp.processors.base_processor import BaseProcessor from sdp.utils.common import load_manifest, save_manifest + class NeMoASRAligner(BaseProcessor): """This processor aligns text and audio using NeMo ASR models. 
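(Reviewer note: a distilled view of the timestamp arithmetic in the hunks below. FastConformer emits one token per 80 ms frame, and for RNNT/TDT decoders the offsets are shifted back by one frame (0.08 s); the list-of-dicts shape follows NeMo's word timestamps as used in this patch.)

def offsets_to_seconds(word_timestamps, time_stride=0.08, ctc=False):
    # convert NeMo start/end frame offsets into seconds
    words = []
    for stamp in word_timestamps:
        if ctc:
            start = stamp["start_offset"] * time_stride
            end = stamp["end_offset"] * time_stride
        else:  # RNNT/TDT: timestamps lag by one frame
            start = max(0, stamp["start_offset"] * time_stride - 0.08)
            end = max(0, stamp["end_offset"] * time_stride - 0.08)
        words.append({"word": stamp["word"], "start": start, "end": end})
    return words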
@@ -53,51 +55,52 @@ class NeMoASRAligner(BaseProcessor): output_manifest_file: ${workspace_dir}/manifest_aligned.json parakeet: True """ - def __init__(self, - model_name="nvidia/parakeet-tdt_ctc-1.1b", - model_path=None, - min_len: float = 0.1, - max_len: float = 40, - parakeet: bool = True, - ctc: bool = False, - batch_size: int = 32, - num_workers: int = 10, - split_batch_size: int = 5000, - timestamp_type: str = "word", - infer_segment_only: bool = False, - device: str = "cuda", - **kwargs): + + def __init__( + self, + model_name="nvidia/parakeet-tdt_ctc-1.1b", + model_path=None, + min_len: float = 0.1, + max_len: float = 40, + parakeet: bool = True, + ctc: bool = False, + batch_size: int = 32, + num_workers: int = 10, + split_batch_size: int = 5000, + timestamp_type: str = "word", + infer_segment_only: bool = False, + device: str = "cuda", + **kwargs, + ): super().__init__(**kwargs) if model_path is not None: self.asr_model = nemo_asr.models.ASRModel.restore_from(restore_path=model_path) else: self.asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name) - + if not torch.cuda.is_available(): device = "cpu" logger.warning("CUDA is not available, using CPU") - + self.asr_model.to(device) # Configuring attention to work with longer files - self.asr_model.change_attention_model( - self_attention_model="rel_pos_local_attn", att_context_size=[128, 128] - ) + self.asr_model.change_attention_model(self_attention_model="rel_pos_local_attn", att_context_size=[128, 128]) self.asr_model.change_subsampling_conv_chunking_factor(1) self.min_len = min_len self.max_len = max_len - self.parakeet = parakeet # if model type is parakeet or not, determines time stride - self.ctc = ctc # if decoder type is ctc or not, determines timestamp substraction + self.parakeet = parakeet # if model type is parakeet or not, determines time stride + self.ctc = ctc # if decoder type is ctc or not, determines timestamp substraction self.timestamp_type = timestamp_type self.infer_segment_only = infer_segment_only - cfg = self.asr_model.cfg.decoding + cfg = self.asr_model.cfg.decoding with omegaconf.open_dict(cfg): - cfg['compute_timestamps']=True - cfg['preserve_alignments']=True + cfg['compute_timestamps'] = True + cfg['preserve_alignments'] = True if ctc: cfg.strategy = "greedy_batch" else: - cfg['rnnt_timestamp_type'] = self.timestamp_type + cfg['rnnt_timestamp_type'] = self.timestamp_type self.asr_model.change_decoding_strategy(decoding_cfg=cfg) # set batch size @@ -119,7 +122,7 @@ def get_alignments_text(self, hypotheses): - list: List of dictionaries with word alignments (word, start, end) - str: The transcribed text """ - timestamp_dict = hypotheses.timestep # extract timesteps from hypothesis of first (and only) audio file + timestamp_dict = hypotheses.timestep # extract timesteps from hypothesis of first (and only) audio file # For a FastConformer model, you can display the word timestamps as follows: # 80ms is duration of a timestep at output of the Conformer @@ -134,8 +137,8 @@ def get_alignments_text(self, hypotheses): for stamp in word_timestamps: if self.ctc: start = stamp['start_offset'] * time_stride - end = stamp['end_offset'] * time_stride - else: # if rnnt or tdt decoder + end = stamp['end_offset'] * time_stride + else: # if rnnt or tdt decoder start = max(0, stamp['start_offset'] * time_stride - 0.08) end = max(0, stamp['end_offset'] * time_stride - 0.08) @@ -145,8 +148,7 @@ def get_alignments_text(self, hypotheses): text = hypotheses.text text = text.replace("⁇", "") return 
alignments, text - - + def _prepare_metadata_batch(self, metadata_batch): """Prepare audio data and segment mapping for a batch of metadata files. @@ -166,7 +168,7 @@ def _prepare_metadata_batch(self, metadata_batch): for segment_idx, segment in enumerate(metadata['segments']): duration = segment['end'] - segment['start'] - if duration >= self.min_len and segment['speaker']!='no-speaker': + if duration >= self.min_len and segment['speaker'] != 'no-speaker': start = int(segment['start'] * sr) end = int(segment['end'] * sr) audio_segment = audio[:, start:end].squeeze(0) @@ -176,7 +178,6 @@ def _prepare_metadata_batch(self, metadata_batch): return all_segments, segment_indices - def process(self): """Process the input manifest file to generate word alignments and transcriptions. @@ -191,49 +192,49 @@ def process(self): Results are saved in JSONL format with alignments and transcriptions added to the original metadata. """ manifest = load_manifest(self.input_manifest_file) - + results = [] if not self.infer_segment_only: transcribe_manifest = [] for data in manifest: - if (('split_filepaths' in data and data['split_filepaths'] is None) or ('split_filepaths' not in data)) and data['duration'] > self.min_len: + if ( + ('split_filepaths' in data and data['split_filepaths'] is None) or ('split_filepaths' not in data) + ) and data['duration'] > self.min_len: transcribe_manifest.append(data) else: data['text'] = '' data['alignment'] = [] results.append(data) - files = [x['resampled_audio_filepath'] for x in transcribe_manifest] for i in range(0, len(files), self.split_batch_size): - batch = files[i:i + self.split_batch_size] + batch = files[i : i + self.split_batch_size] with torch.no_grad(): hypotheses_list = self.asr_model.transcribe(batch, override_config=self.override_cfg) # if hypotheses form a tuple (from RNNT), extract just "best" hypotheses if type(hypotheses_list) == tuple and len(hypotheses_list) == 2: hypotheses_list = hypotheses_list[0] - - metadatas = transcribe_manifest[i:i + self.split_batch_size] + + metadatas = transcribe_manifest[i : i + self.split_batch_size] for idx, metadata in enumerate(metadatas): - hypotheses = hypotheses_list[idx] + hypotheses = hypotheses_list[idx] alignments, text = self.get_alignments_text(hypotheses) metadata['text'] = text - metadata['alignment']= alignments + metadata['alignment'] = alignments results.append(metadata) else: for i in range(0, len(manifest), self.split_batch_size): - metadata_batch = manifest[i:i + self.split_batch_size] + metadata_batch = manifest[i : i + self.split_batch_size] all_segments, segment_indices = self._prepare_metadata_batch(metadata_batch) - + try: with torch.no_grad(): hypotheses_list = self.asr_model.transcribe(all_segments, override_config=self.override_cfg) except Exception as e: - files_list = [ item['resampled_audio_filepath'] for item in metadata_batch ] + files_list = [item['resampled_audio_filepath'] for item in metadata_batch] raise ValueError(f"Exception occurred for audio filepath list: {files_list}, Error is : {str(e)}") - if type(hypotheses_list) == tuple and len(hypotheses_list) == 2: hypotheses_list = hypotheses_list[0] @@ -244,8 +245,8 @@ def process(self): for word in alignments: word['start'] = round(word['start'] + segment['start'], 3) word['end'] = round(word['end'] + segment['start'], 3) - - segment['words']= alignments + + segment['words'] = alignments results.extend(metadata_batch) diff --git a/sdp/processors/tts/prepare_tts_segments.py b/sdp/processors/tts/prepare_tts_segments.py index 
624e2124..a2067ae4 100644 --- a/sdp/processors/tts/prepare_tts_segments.py +++ b/sdp/processors/tts/prepare_tts_segments.py @@ -17,6 +17,7 @@ from sdp.processors.base_processor import BaseParallelProcessor, DataEntry from sdp.utils.common import load_manifest + class PrepareTTSSegmentsProcessor(BaseParallelProcessor): """This processor merges adjacent segments from the same speaker and splits segments to have a complete utterance. @@ -44,23 +45,26 @@ class PrepareTTSSegmentsProcessor(BaseParallelProcessor): min_duration: 5 max_duration: 20 """ - def __init__(self, - min_duration: float = 5, - max_duration: float = 20, - max_pause: float = 2, - terminal_punct_marks: str = ".!?。??!。", - punctuation_split_only: bool = False, - **kwargs): + + def __init__( + self, + min_duration: float = 5, + max_duration: float = 20, + max_pause: float = 2, + terminal_punct_marks: str = ".!?。??!。", + punctuation_split_only: bool = False, + **kwargs + ): super().__init__(**kwargs) self.min_duration = min_duration self.max_duration = max_duration self.max_pause = max_pause self.terminal_punct_marks = terminal_punct_marks self.punctuation_split_only = punctuation_split_only - + def read_manifest(self): - ''' Reads metadata from JSONL file in the input manifest - and converts it to data entries ''' + '''Reads metadata from JSONL file in the input manifest + and converts it to data entries''' dataset_entries = load_manifest(self.input_manifest_file, encoding="utf8") @@ -72,7 +76,11 @@ def get_words_list_from_all_segments(self, segments): """ words = [] for segment in segments: - if ("text" in segment and segment["text"].strip() == "") or (segment["speaker"]=="no-speaker") or (not "text" in segment): + if ( + ("text" in segment and segment["text"].strip() == "") + or (segment["speaker"] == "no-speaker") + or (not "text" in segment) + ): continue if 'words' in segment: @@ -93,19 +101,21 @@ def get_words_list_from_all_segments(self, segments): logger.info('Found no words in segment') return words - + def is_valid_segment(self, segment): """ This method checks if the segment is valid """ - if len(segment["words"]) ==1 and segment["words"][0]["end"] - segment["words"][0]["start"] > self.max_duration: + if ( + len(segment["words"]) == 1 + and segment["words"][0]["end"] - segment["words"][0]["start"] > self.max_duration + ): return False sentence = " ".join([word["word"] for word in segment["words"]]) if sentence: return True return False - - + def split_segment_by_duration(self, segment): """ This method splits the segment by duration, pauses, and bandwidth changes @@ -118,7 +128,7 @@ def split_segment_by_duration(self, segment): "words": [], } segments = [] - + for word in words: if not current_segment["words"]: current_segment = { @@ -128,7 +138,7 @@ def split_segment_by_duration(self, segment): "words": [word], } continue - + # break the current segment if the duration is greater than the max duration and start a new segment if (word["end"] - current_segment["start"]) > self.max_duration: if self.is_valid_segment(current_segment): @@ -140,9 +150,11 @@ def split_segment_by_duration(self, segment): "words": [word], } continue - + # break the current segment if the pause is greater than the max pause and start a new segment - if (word["start"] - current_segment["end"] > self.max_pause) and (current_segment["end"] - current_segment["start"] >= self.min_duration): + if (word["start"] - current_segment["end"] > self.max_pause) and ( + current_segment["end"] - current_segment["start"] >= self.min_duration + ): if 
self.is_valid_segment(current_segment): segments.append(current_segment) current_segment = { @@ -152,9 +164,11 @@ def split_segment_by_duration(self, segment): "words": [word], } continue - + # break the current segment if the bandwidth is different and start a new segment - if (current_segment['words'] and word['bandwidth']!=current_segment['words'][-1]['bandwidth'] ) and (current_segment["end"] - current_segment["start"] >= self.min_duration): + if (current_segment['words'] and word['bandwidth'] != current_segment['words'][-1]['bandwidth']) and ( + current_segment["end"] - current_segment["start"] >= self.min_duration + ): if self.is_valid_segment(current_segment): segments.append(current_segment) current_segment = { @@ -167,14 +181,14 @@ def split_segment_by_duration(self, segment): current_segment["words"].append(word) current_segment["end"] = word["end"] - + # add the last segment if it is valid if current_segment["words"]: if self.is_valid_segment(current_segment): segments.append(current_segment) - + return segments - + def split_segment_by_punctuation(self, segment): """ This method splits the given single speaker segment by punctuation marks, if no punctuation marks are found then it splits the segment by duration. @@ -185,9 +199,7 @@ def split_segment_by_punctuation(self, segment): words = segment["words"] # get the punctuation split points split_points = [ - i - for i, word in enumerate(words) - if word["word"] and word["word"][-1] in self.terminal_punct_marks + i for i, word in enumerate(words) if word["word"] and word["word"][-1] in self.terminal_punct_marks ] segments = [] # if no punctuation marks, split the segment by duration @@ -207,7 +219,11 @@ def split_segment_by_punctuation(self, segment): if current_duration < self.min_duration: # merge with the next split points until the maximum duration is reached next_end = current_end + 1 - while next_end < len(split_points) and words[split_points[next_end]]["end"] - words[split_points[current_start]]["start"] <= self.max_duration: + while ( + next_end < len(split_points) + and words[split_points[next_end]]["end"] - words[split_points[current_start]]["start"] + <= self.max_duration + ): next_end += 1 if next_end > current_end + 1: @@ -221,17 +237,17 @@ def split_segment_by_punctuation(self, segment): new_split_points.append(split_points[current_end]) current_start = current_end + 1 current_end = current_end + 1 - + # now split the segment at the new split points # if the duration of the segment is greater than the max duration, split the segment by duration start = 0 for end in new_split_points: duration = words[end]["end"] - words[start]["start"] sub_segment = { - "speaker": segment["speaker"], - "start": words[start]["start"], - "end": words[end]["end"], - "words": words[start : end + 1], + "speaker": segment["speaker"], + "start": words[start]["start"], + "end": words[end]["end"], + "words": words[start : end + 1], } if duration <= self.max_duration: if self.is_valid_segment(sub_segment): @@ -239,7 +255,7 @@ def split_segment_by_punctuation(self, segment): else: segments.extend(self.split_segment_by_duration(sub_segment)) start = end + 1 - + # remaining clause in a new segment if start < len(words): remaining_segment = { @@ -253,7 +269,6 @@ def split_segment_by_punctuation(self, segment): return segments def add_new_segments_to_metadata(self, metadata, new_segments): - segments = [] for new_segment in new_segments: @@ -262,7 +277,9 @@ def add_new_segments_to_metadata(self, metadata, new_segments): "start": 
new_segment["start"], "end": new_segment["end"], "text": " ".join(word["word"] for word in new_segment["words"]), - "words": [{"word": word["word"], "start": word["start"], "end": word["end"]} for word in new_segment["words"]], + "words": [ + {"word": word["word"], "start": word["start"], "end": word["end"]} for word in new_segment["words"] + ], "pesq_squim": [word["pesq_squim"] for word in new_segment["words"]], "stoi_squim": [word["stoi_squim"] for word in new_segment["words"]], "sisdr_squim": [word["sisdr_squim"] for word in new_segment["words"]], @@ -310,10 +327,9 @@ def process_dataset_entry(self, metadata: DataEntry): else: current_segment["words"].append(word) current_segment["end"] = word["end"] - + if current_segment["words"]: speaker_segments.append(current_segment) - # split the segments at the punctuation marks, pauses, and bandwidth changes for speaker_segment in speaker_segments: @@ -323,10 +339,8 @@ def process_dataset_entry(self, metadata: DataEntry): # add the new segments to the metadata self.add_new_segments_to_metadata(metadata, new_segments) - + else: logger.info('Found no segments in metadata for audio file: ', metadata['audio_filepath']) - - return [DataEntry(data=metadata)] - + return [DataEntry(data=metadata)] diff --git a/sdp/processors/tts/pyannote.py b/sdp/processors/tts/pyannote.py index 96a35ced..d3ba07a0 100644 --- a/sdp/processors/tts/pyannote.py +++ b/sdp/processors/tts/pyannote.py @@ -12,21 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random -import os import logging +import os +import random from time import time + +import torch +import torchaudio from pyannote.audio import Pipeline from pyannote.audio.pipelines.utils.hook import ProgressHook from whisperx.audio import SAMPLE_RATE from whisperx.vad import load_vad_model, merge_chunks -import torch -import torchaudio from sdp.logging import logger from sdp.processors.base_processor import BaseProcessor from sdp.utils.common import load_manifest, save_manifest + def has_overlap(turn, overlaps): """Check if a given turn overlaps with any segment in the overlaps list. @@ -52,6 +54,7 @@ def has_overlap(turn, overlaps): break return turn_overlaps + class PyAnnoteDiarizationAndOverlapDetection(BaseProcessor): """This processor performs speaker diarization and overlap detection using PyAnnote. 
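(Reviewer note: for orientation, the minimal pyannote flow this class wraps — model name and token handling as in the hunks below, iteration API from pyannote-audio 3.x; the token string is a placeholder.)

import torch
from pyannote.audio import Pipeline

pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token="hf_..."  # placeholder token
)
pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# accepts a path or {'waveform': tensor, 'sample_rate': sr}, as used in process() below
diarization = pipeline("audio.wav")
for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"{speaker}: {turn.start:.2f}-{turn.end:.2f}")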
@@ -79,29 +82,27 @@ class PyAnnoteDiarizationAndOverlapDetection(BaseProcessor): output_manifest_file: ${workspace_dir}/manifest_diarized.json hf_token: ${hf_token} """ - - def __init__(self, - hf_token: str, - segmentation_batch_size: int = 128, - embedding_batch_size: int = 128, - min_length: float = 0.5, - max_length: float = 40, - device: str = "cuda", - **kwargs - ): + def __init__( + self, + hf_token: str, + segmentation_batch_size: int = 128, + embedding_batch_size: int = 128, + min_length: float = 0.5, + max_length: float = 40, + device: str = "cuda", + **kwargs, + ): super().__init__(**kwargs) - - self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", - use_auth_token=hf_token) + self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=hf_token) self.pipeline.segmentation_batch_size = segmentation_batch_size self.pipeline.embedding_batch_size = embedding_batch_size if not torch.cuda.is_available(): device = "cpu" logging.warning("CUDA is not available, using CPU") - + self.pipeline.to(torch.device(device)) self.min_length = min_length @@ -111,12 +112,9 @@ def __init__(self, self.vad_offset = 0.363 default_vad_options = {"vad_onset": self.vad_onset, "vad_offset": self.vad_offset} - self.vad_model = load_vad_model( - torch.device(device), use_auth_token=None, **default_vad_options - ) + self.vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options) random.seed(42) - - + def get_vad_segments(self, audio, merge_max_length=3): """Get voice activity detection segments for the given audio. @@ -128,16 +126,16 @@ def get_vad_segments(self, audio, merge_max_length=3): list: List of VAD segments with start and end times """ vad_segments = self.vad_model( - { - "waveform": audio, - "sample_rate": SAMPLE_RATE, - } - ) - + { + "waveform": audio, + "sample_rate": SAMPLE_RATE, + } + ) + vad_segments = merge_chunks(vad_segments, merge_max_length, onset=self.vad_onset) - + return vad_segments - + def add_vad_segments(self, audio, fs, start, end, segments, speaker_id): """Add VAD segments for a given audio region to the segments list. 
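(Reviewer note: the splitting strategy used here, in isolation — run whisperx VAD over the over-long turn, merge raw activity into short chunks via merge_chunks, then greedily group chunks up to a randomized target length. The grouping loop below is a simplified model of the one in this hunk, not the hunk itself.)

import random


def regroup(vad_segments, min_length=0.5, max_length=40.0):
    # greedily merge short VAD chunks into segments of random target length
    segments, i = [], 0
    while i < len(vad_segments):
        start = vad_segments[i]["start"]
        end = vad_segments[i]["end"]
        i += 1
        target = random.uniform(min_length, max_length)
        while i < len(vad_segments) and vad_segments[i]["end"] - start < target:
            end = vad_segments[i]["end"]
            i += 1
        segments.append({"start": start, "end": end})
    return segments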
@@ -155,7 +153,7 @@ def add_vad_segments(self, audio, fs, start, end, segments, speaker_id): """ segment_duration = end - start if segment_duration > self.max_length: - audio_seg = audio[: , int(start * fs): int(end * fs)] + audio_seg = audio[:, int(start * fs) : int(end * fs)] vad_segments = self.get_vad_segments(audio_seg) i = 0 n = len(vad_segments) @@ -172,14 +170,14 @@ def add_vad_segments(self, audio, fs, start, end, segments, speaker_id): segment_data_entry['start'] = start + start_seg segment_data_entry['end'] = start + end_seg segments.append(segment_data_entry) - i += 1 + i += 1 continue # Merge segments until the random duration is reached while i < n and (vad_segments[i]['end'] - start_seg) < random_duration: end_seg = vad_segments[i]['end'] # Extend the end time i += 1 - + segment_data_entry = {} segment_data_entry['speaker'] = speaker_id segment_data_entry['start'] = start + start_seg @@ -216,7 +214,7 @@ def process(self): for metadata in manifest: file_path = metadata['resampled_audio_filepath'] logger.info(file_path) - + s, fs = torchaudio.load(file_path) with ProgressHook() as hook: diarization = self.pipeline({'waveform': s, 'sample_rate': fs}, hook=hook) @@ -225,7 +223,7 @@ def process(self): # Due to a bug in PyAnnote-Audio, diarization might have timestamps longer than # the audio, so we crop it to the audio length # https://github.com/pyannote/pyannote-audio/issues/1611 - diarization.crop(0, len(s)/fs) + diarization.crop(0, len(s) / fs) # write in RTTM format logger.info("Writing {} turns to RTTM file".format(len(diarization._tracks))) @@ -244,7 +242,7 @@ def process(self): speaker_id = metadata['speaker_id'] + '_' + speaker else: raise ValueError('No speaker identifier in sample {}'.format(metadata['resampled_audio_filepath'])) - + if has_overlap(speech_turn, overlaps): segment_data_entry = {} segment_data_entry['speaker'] = speaker_id @@ -271,7 +269,7 @@ def process(self): # If there is any remaining audio after the last speaker segment if last_end_time < audio_duration: non_speaker_segments.append((last_end_time, audio_duration)) - + for start, end in non_speaker_segments: speaker_id = "no-speaker" current_start = start @@ -283,7 +281,7 @@ def process(self): segment_data_entry['end'] = current_end segments.append(segment_data_entry) current_start = current_end - + # Sort all segments by start time segments.sort(key=lambda x: x['start']) metadata['segments'] = segments @@ -292,4 +290,3 @@ def process(self): logger.info(f'Completed diarization in {(time()-start_time)/3600} hrs') save_manifest(results, self.output_manifest_file) - diff --git a/sdp/processors/tts/split.py b/sdp/processors/tts/split.py index 237c4e1e..dce6a209 100644 --- a/sdp/processors/tts/split.py +++ b/sdp/processors/tts/split.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from sdp.processors.base_processor import BaseProcessor, DataEntry import json -import os -import torchaudio import math +import os from copy import deepcopy + +import torchaudio + +from sdp.processors.base_processor import BaseProcessor, DataEntry from sdp.utils.common import load_manifest, save_manifest + class SplitLongAudio(BaseProcessor): """This processor splits long audio files into smaller segments. 
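(Reviewer note: the hunks below rework SplitLongAudio mostly for formatting; the core save loop is easy to see in miniature. Split points in seconds are assumed to come from pause detection, which is outside this hunk.)

import math

import torchaudio


def save_splits(audio_path, splits_sec, min_len=1.0):
    # write one WAV per split, skipping pieces shorter than min_len seconds
    audio, sr = torchaudio.load(audio_path)
    paths = []
    split_start = 0
    for k, split in enumerate(splits_sec):
        split_end = math.ceil(split * sr)
        if split_end - split_start > min_len * sr:
            out = audio_path[:-4] + f".{k + 1}_of_{len(splits_sec) + 1}.wav"
            torchaudio.save(out, audio[:, split_start:split_end], sr)
            paths.append(out)
        split_start = split_end
    return paths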
@@ -43,12 +46,8 @@ class SplitLongAudio(BaseProcessor): output_manifest_file: ${workspace_dir}/manifest_split.json suggested_max_len: 3600 """ - def __init__(self, - suggested_max_len: float = 3600, - min_pause_len: float = 1.0, - min_len: float = 1.0, - **kwargs - ): + + def __init__(self, suggested_max_len: float = 3600, min_pause_len: float = 1.0, min_len: float = 1.0, **kwargs): super().__init__(**kwargs) self.suggested_max_len = suggested_max_len self.min_pause_len = min_pause_len @@ -107,24 +106,29 @@ def process(self): actual_splits = [] split_durations = [] for k, split in enumerate(splits): - split_filepath = os.path.join(path, filename[:-4] + '.{}_of_{}.wav'.format(k+1, 1+len(splits))) - split_end = math.ceil(split*sr) - if split_end-split_start > self.min_len * sr: + split_filepath = os.path.join(path, filename[:-4] + '.{}_of_{}.wav'.format(k + 1, 1 + len(splits))) + split_end = math.ceil(split * sr) + if split_end - split_start > self.min_len * sr: torchaudio.save(split_filepath, audio[:, split_start:split_end], sr) split_filepaths.append(split_filepath) - actual_splits.append(split_start/sr) - split_durations.append((split_end-split_start)/sr) + actual_splits.append(split_start / sr) + split_durations.append((split_end - split_start) / sr) split_start = split_end # Write last split - split_filepath = os.path.join(path, filename[:-4] + '.{}_of_{}.wav'.format(1+len(splits), 1+len(splits))) - last_frame = len(audio[0])-1 + split_filepath = os.path.join( + path, filename[:-4] + '.{}_of_{}.wav'.format(1 + len(splits), 1 + len(splits)) + ) + last_frame = len(audio[0]) - 1 # skip audios that are too short - if last_frame-split_start > self.min_len * sr and last_frame-split_start < (self.suggested_max_len + 1)*sr: + if ( + last_frame - split_start > self.min_len * sr + and last_frame - split_start < (self.suggested_max_len + 1) * sr + ): torchaudio.save(split_filepath, audio[:, split_start:], sr) split_filepaths.append(split_filepath) - split_durations.append((last_frame-split_start)/sr) - actual_splits.append(split_start/sr) + split_durations.append((last_frame - split_start) / sr) + actual_splits.append(split_start / sr) # Add split_filepaths to results without split_filepaths field and resampled_audio_filepath replaced # with the corresponding splits @@ -146,7 +150,7 @@ def process(self): class JoinSplitAudioMetadata(BaseProcessor): """A processor for joining metadata of previously split audio files. - This processor combines the metadata (transcripts and alignments) of audio files + This processor combines the metadata (transcripts and alignments) of audio files that were previously split by the SplitLongAudio processor. It adjusts timestamps and concatenates transcripts to recreate the original audio's metadata. @@ -156,9 +160,8 @@ class JoinSplitAudioMetadata(BaseProcessor): Returns: The same data as in the input manifest, but with split audio files joined together. """ - def __init__(self, - **kwargs - ): + + def __init__(self, **kwargs): super().__init__(**kwargs) def process(self): @@ -209,7 +212,5 @@ def process(self): # Remove 'split_filepaths' field from meta entry to turn it into a real entry del meta_entry['split_filepaths'] fp_w.write(f"{json.dumps(meta_entry)}\n") - - fp_w.close() - + fp_w.close() diff --git a/sdp/processors/tts/text.py b/sdp/processors/tts/text.py index 37dbb862..b0aab9ee 100644 --- a/sdp/processors/tts/text.py +++ b/sdp/processors/tts/text.py @@ -13,10 +13,19 @@ # limitations under the License. 
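(Reviewer note: JoinSplitAudioMetadata above is the inverse operation — concatenate per-split transcripts and shift each split's word alignments by that split's start offset. A schematic version, with field names as used throughout these processors; the exact join logic lives in the class, not here.)

def join_split_metadata(split_entries, split_offsets_sec):
    # split_entries[i] has 'text' and 'alignment'; offsets are each split's start time
    text_parts, alignment = [], []
    for entry, offset in zip(split_entries, split_offsets_sec):
        text_parts.append(entry["text"])
        for word in entry["alignment"]:
            alignment.append({
                "word": word["word"],
                "start": word["start"] + offset,
                "end": word["end"] + offset,
            })
    return {"text": " ".join(text_parts), "alignment": alignment}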
import json -from sdp.processors.base_processor import BaseProcessor, BaseParallelProcessor, DataEntry -from sdp.utils.common import load_manifest, save_manifest -from nemo_text_processing.inverse_text_normalization.inverse_normalize import InverseNormalizer + from nemo.collections.nlp.models import PunctuationCapitalizationModel +from nemo_text_processing.inverse_text_normalization.inverse_normalize import ( + InverseNormalizer, +) + +from sdp.processors.base_processor import ( + BaseParallelProcessor, + BaseProcessor, + DataEntry, +) +from sdp.utils.common import load_manifest, save_manifest + class InverseTextNormalizationProcessor(BaseParallelProcessor): """This processor performs inverse text normalization on text data. @@ -39,15 +48,14 @@ class InverseTextNormalizationProcessor(BaseParallelProcessor): output_manifest_file: ${workspace_dir}/manifest_itn.json language: "en" """ - def __init__(self, - language="en", - **kwargs): + + def __init__(self, language="en", **kwargs): super().__init__(**kwargs) self.normalizer = InverseNormalizer(lang=language) - + def read_manifest(self): - ''' Reads metadata from JSONL file in the input manifest - and converts it to data entries ''' + '''Reads metadata from JSONL file in the input manifest + and converts it to data entries''' dataset_entries = load_manifest(self.input_manifest_file, encoding="utf8") @@ -84,21 +92,17 @@ class PunctuationAndCapitalizationOnSegmentsProcessor(BaseProcessor): input_manifest_file: ${workspace_dir}/manifest.json output_manifest_file: ${workspace_dir}/manifest_pnc.json """ - def __init__(self, - model_name="punctuation_en_bert", - model_path=None, - batch_size=64, - **kwargs): + def __init__(self, model_name="punctuation_en_bert", model_path=None, batch_size=64, **kwargs): super().__init__(**kwargs) if model_path is not None: self.pnc_model = PunctuationCapitalizationModel.restore_from(model_path) else: self.pnc_model = PunctuationCapitalizationModel.from_pretrained(model_name) - - self.batch_size= batch_size + + self.batch_size = batch_size self.pnc_model.cuda() - + def process(self): manifest = load_manifest(self.input_manifest_file) @@ -109,7 +113,7 @@ def process(self): if "text" in segment: text = segment["text"] all_text.append(text) - + text_PNC = self.pnc_model.add_punctuation_capitalization(all_text, batch_size=self.batch_size) i = 0 @@ -117,11 +121,12 @@ def process(self): for segment in metadata["segments"]: if "text" in segment: segment["text"] = text_PNC[i] - i+=1 + i += 1 results.append(metadata) save_manifest(results, self.output_manifest_file) + class PunctuationAndCapitalizationProcessor(BaseProcessor): """This processor performs punctuation and capitalization on text data. 
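(Reviewer note: stripped of the manifest plumbing, the processors in this file reduce to standard NeMo calls; sample outputs are indicative only and depend on the model.)

from nemo.collections.nlp.models import PunctuationCapitalizationModel
from nemo_text_processing.inverse_text_normalization.inverse_normalize import InverseNormalizer

itn = InverseNormalizer(lang="en")
itn.inverse_normalize("twenty three dollars", verbose=False)  # -> "$23"

pnc = PunctuationCapitalizationModel.from_pretrained("punctuation_en_bert")
pnc.add_punctuation_capitalization(["hello how are you"], batch_size=64)
# -> ["Hello, how are you?"] (exact output depends on the model)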
@@ -143,45 +148,44 @@ class PunctuationAndCapitalizationProcessor(BaseProcessor): input_manifest_file: ${workspace_dir}/manifest.json output_manifest_file: ${workspace_dir}/manifest_pnc.json """ - def __init__(self, - model_name="punctuation_en_bert", - model_path=None, - batch_size=64, - **kwargs): + def __init__(self, model_name="punctuation_en_bert", model_path=None, batch_size=64, **kwargs): super().__init__(**kwargs) if model_path is not None: self.pnc_model = PunctuationCapitalizationModel.restore_from(model_path) else: self.pnc_model = PunctuationCapitalizationModel.from_pretrained(model_name) - - self.batch_size= batch_size + + self.batch_size = batch_size self.pnc_model.cuda() - + def process(self): manifest = load_manifest(self.input_manifest_file) all_text = [] - + for metadata in manifest: - is_segmented_entry = ('split_filepaths' in metadata and metadata['split_filepaths'] is None) or ('split_filepaths' not in metadata) - if is_segmented_entry and ('text' in metadata and metadata['text'] != ''): + is_segmented_entry = ('split_filepaths' in metadata and metadata['split_filepaths'] is None) or ( + 'split_filepaths' not in metadata + ) + if is_segmented_entry and ('text' in metadata and metadata['text'] != ''): text = ' '.join([x['word'] for x in metadata['alignment']]).strip() all_text.append(text) - + text_PNC = self.pnc_model.add_punctuation_capitalization(all_text, batch_size=self.batch_size) i = 0 with open(self.output_manifest_file, 'w') as f: for metadata in manifest: - is_segmented_entry = ('split_filepaths' in metadata and metadata['split_filepaths'] is None) or ('split_filepaths' not in metadata) - if is_segmented_entry and ('text' in metadata and metadata['text'] != ''): + is_segmented_entry = ('split_filepaths' in metadata and metadata['split_filepaths'] is None) or ( + 'split_filepaths' not in metadata + ) + if is_segmented_entry and ('text' in metadata and metadata['text'] != ''): pnc_words = text_PNC[i].split() pnc_words_idx = 0 for word in metadata['alignment']: if word['word'] != '': word['word'] = pnc_words[pnc_words_idx] pnc_words_idx += 1 - i+=1 + i += 1 f.write(json.dumps(metadata) + "\n") - diff --git a/sdp/run_processors.py b/sdp/run_processors.py index 6ddf27f4..5370b927 100644 --- a/sdp/run_processors.py +++ b/sdp/run_processors.py @@ -12,19 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import os import tempfile import uuid from typing import List, Optional -import psutil -import json import hydra +import psutil from omegaconf import OmegaConf, open_dict from sdp.logging import logger - from sdp.utils.import_manager import ImportManager # registering new resolvers to simplify config files @@ -46,16 +45,18 @@ logger.addHandler(handler) logger.propagate = False + def update_processor_imports(config_path: str, init_file: str = None): """ Update processor imports based on config file. - + Args: config_path: Path to the YAML config file init_file: Optional path to __init__.py file to update """ try: import yaml + manager = ImportManager() manager.sync_with_config(config_path, init_file) logger.info(f"Successfully updated imports for config: {config_path}") @@ -120,13 +121,14 @@ def run_processors(cfg): if cfg.get("use_import_manager", False): try: import yaml + yaml_path = cfg.get("config_path") if not yaml_path: raise ValueError("No configuration path provided in 'config_path'. 
Please specify the path.") if not os.path.exists(yaml_path): raise FileNotFoundError(f"Configuration file not found: {yaml_path}") - + logger.info(f"Managing imports for config: {yaml_path}") manager = ImportManager() manager.sync_with_config(yaml_path) @@ -144,11 +146,12 @@ def run_processors(cfg): # Detecting dask try: from dask.distributed import Client + dask_available = True except ImportError: logger.warning("Dask not installed; using multiprocessing for all processors") dask_available = False - + # look for global directions in cfg for dask usage global_use_dask = bool(cfg.get("use_dask", True)) and dask_available @@ -156,7 +159,7 @@ def run_processors(cfg): if processors_to_run == "all": processors_to_run = ":" selected_cfgs = select_subset(cfg.processors, processors_to_run) - + # filtering out any processors that have should_run=False processors_cfgs = [] for processor_cfg in selected_cfgs: @@ -169,9 +172,7 @@ def run_processors(cfg): "Specified to run the following processors: %s ", [proc_cfg["_target_"] for proc_cfg in processors_cfgs], ) - - - + processors = [] # Create a temporary directory to hold intermediate files if needed. with tempfile.TemporaryDirectory() as tmp_dir: @@ -187,7 +188,7 @@ def run_processors(cfg): with open_dict(processors_cfgs[0]): processors_cfgs[0]["input_manifest_file"] = cfg.processors[idx - 1]["output_manifest_file"] break - + for idx, processor_cfg in enumerate(processors_cfgs): logger.info('=> Building processor "%s"', processor_cfg["_target_"]) @@ -205,9 +206,9 @@ def run_processors(cfg): if idx != len(processors_cfgs) - 1 and "input_manifest_file" not in processors_cfgs[idx + 1]: with open_dict(processors_cfgs[idx + 1]): processors_cfgs[idx + 1]["input_manifest_file"] = processor_cfg["output_manifest_file"] - - #check if we have processor level directions of using dask - flag=processor_cfg.get("use_dask", None) + + # check if we have processor level directions of using dask + flag = processor_cfg.get("use_dask", None) # if no processor-specific flag, fallback to global; otherwise use provided value if flag is None: @@ -224,7 +225,6 @@ def run_processors(cfg): # Start Dask client if any processor requires it dask_client = None if any(p.use_dask for p in processors): - try: num_cpus = psutil.cpu_count(logical=False) or 4 logger.info(f"Starting Dask client with {num_cpus} workers") @@ -249,4 +249,5 @@ def run_processors(cfg): dask_client.close(timeout="60s") logger.info("Dask client shutdown complete") -#tmp_dir is removed here after all processing finishes. !!! + +# tmp_dir is removed here after all processing finishes. !!! diff --git a/sdp/utils/__init__.py b/sdp/utils/__init__.py index 2223b231..15a876ae 100644 --- a/sdp/utils/__init__.py +++ b/sdp/utils/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. -from sdp.utils.bootstrap_estimates import BootstrapProcessor \ No newline at end of file +from sdp.utils.bootstrap_estimates import BootstrapProcessor diff --git a/sdp/utils/apply_operators.py b/sdp/utils/apply_operators.py index 80db65b8..9bb779f1 100644 --- a/sdp/utils/apply_operators.py +++ b/sdp/utils/apply_operators.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import operator import ast +import operator import re from typing import Any, Dict @@ -168,4 +168,4 @@ def _eval(node): expression = re.sub(rf'{re.escape(var_prefix)}(\w+)', r'\1', expression) tree = ast.parse(expression, mode='eval') - return _eval(tree.body) \ No newline at end of file + return _eval(tree.body) diff --git a/sdp/utils/bootstrap_estimates.py b/sdp/utils/bootstrap_estimates.py index efdebe83..6a476068 100644 --- a/sdp/utils/bootstrap_estimates.py +++ b/sdp/utils/bootstrap_estimates.py @@ -14,17 +14,21 @@ import json from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + import numpy as np from tqdm import tqdm + from sdp.processors.base_processor import BaseProcessor -from typing import List, Dict, Union, Optional, Tuple + from . import metrics_computation as metrics + class BootstrapProcessor(BaseProcessor): """This processor evaluates ASR performance metrics using bootstrapped confidence intervals. - It calculates metrics such as Word Error Rate (WER), Character Error Rate (CER), Word Match - Rate (WMR), character rate, and word rate. When `calculate_pairwise` is set to `True`, it also + It calculates metrics such as Word Error Rate (WER), Character Error Rate (CER), Word Match + Rate (WMR), character rate, and word rate. When `calculate_pairwise` is set to `True`, it also computes the Probability of Improvement (POI) between different ASR models. This implementation leverages bootstrapping to provide robust confidence intervals for each metric, @@ -64,10 +68,10 @@ def __init__( self, bootstrap_manifest_files: List[str], raw_data_dir: str, - output_file: str, + output_file: str, num_bootstraps: int = 1000, bootstrap_sample_ratio: float = 1.0, - calculate_pairwise: bool = True, + calculate_pairwise: bool = True, metric_type: str = 'wer', text_key: str = 'text', pred_text_key: str = 'pred_text', @@ -86,23 +90,24 @@ def __init__( self.calculate_pairwise = calculate_pairwise self.metric_type = metric_type.lower() self.text_key = text_key - self.pred_text_key = pred_text_key + self.pred_text_key = pred_text_key self.ci_lower = ci_lower - self.ci_upper = ci_upper - self.random_state = random_state - + self.ci_upper = ci_upper + self.random_state = random_state if self.random_state is not None: np.random.seed(self.random_state) if self.metric_type not in ['wer', 'cer', 'wmr', 'charrate', 'wordrate']: - raise ValueError(f"Invalid metric_type '{self.metric_type}'! Must be one of ['wer', 'cer', 'wmr', 'charrate', 'wordrate']") + raise ValueError( + f"Invalid metric_type '{self.metric_type}'! Must be one of ['wer', 'cer', 'wmr', 'charrate', 'wordrate']" + ) def read_manifest(self, manifest_path: Path) -> List[Dict[str, Union[str, float]]]: manifest_data = [] with manifest_path.open('r', encoding='utf-8') as f: for line in f: - data = json.loads(line.strip()) + data = json.loads(line.strip()) manifest_data.append(data) return manifest_data @@ -125,7 +130,9 @@ def calculate_metric(self, text: str, pred_text: str, duration: Optional[float] else: raise ValueError(f"Unsupported metric_type: {self.metric_type}") - def bootstrap_metric(self, hypotheses: List[str], references: List[str], durations: Optional[List[float]] = None) -> np.ndarray: + def bootstrap_metric( + self, hypotheses: List[str], references: List[str], durations: Optional[List[float]] = None + ) -> np.ndarray: """ Bootstraps metric computation (WER, CER, etc.) to calculate confidence intervals. 
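The hunk below only reflows bootstrap_metric, but for readers new to the technique: this is a standard percentile bootstrap — resample utterances with replacement, average the per-utterance metric, repeat num_bootstraps times, and read the confidence interval off the empirical distribution of means. A condensed sketch of that idea (illustrative only, not the class API; percentile bounds match the class defaults of 2.5/97.5):

import numpy as np

def percentile_bootstrap(per_utt_metric, num_bootstraps=1000, ci=(2.5, 97.5)):
    per_utt_metric = np.asarray(per_utt_metric)
    n = len(per_utt_metric)
    # each replicate: draw n utterances with replacement, take the mean metric
    means = np.array([
        per_utt_metric[np.random.randint(0, n, size=n)].mean()
        for _ in range(num_bootstraps)
    ])
    return means.mean(), np.percentile(means, ci[0]), np.percentile(means, ci[1])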
@@ -147,15 +154,25 @@ def bootstrap_metric(self, hypotheses: List[str], references: List[str], duratio sampled_references = [references[i] for i in indices] if durations: sampled_durations = [durations[i] for i in indices] - metric = [self.calculate_metric(sampled_references[i], sampled_hypotheses[i], sampled_durations[i]) - for i in range(sample_size)] + metric = [ + self.calculate_metric(sampled_references[i], sampled_hypotheses[i], sampled_durations[i]) + for i in range(sample_size) + ] else: - metric = [self.calculate_metric(sampled_references[i], sampled_hypotheses[i]) for i in range(sample_size)] + metric = [ + self.calculate_metric(sampled_references[i], sampled_hypotheses[i]) for i in range(sample_size) + ] metric_bootstrap.append(np.mean(metric)) return np.array(metric_bootstrap) - def bootstrap_wer_difference(self, predictions1: List[str], predictions2: List[str], references: List[str], durations: Optional[List[float]] = None) -> Tuple[np.ndarray, float]: + def bootstrap_wer_difference( + self, + predictions1: List[str], + predictions2: List[str], + references: List[str], + durations: Optional[List[float]] = None, + ) -> Tuple[np.ndarray, float]: """ Calculates the bootstrapped difference in metrics between two sets of predictions and the probability of improvement. @@ -182,8 +199,14 @@ def bootstrap_wer_difference(self, predictions1: List[str], predictions2: List[s if durations: sampled_durations = [durations[i] for i in indices] - metric1 = [self.calculate_metric(sampled_refs[i], sampled_pred1[i], sampled_durations[i]) for i in range(sample_size)] - metric2 = [self.calculate_metric(sampled_refs[i], sampled_pred2[i], sampled_durations[i]) for i in range(sample_size)] + metric1 = [ + self.calculate_metric(sampled_refs[i], sampled_pred1[i], sampled_durations[i]) + for i in range(sample_size) + ] + metric2 = [ + self.calculate_metric(sampled_refs[i], sampled_pred2[i], sampled_durations[i]) + for i in range(sample_size) + ] else: metric1 = [self.calculate_metric(sampled_refs[i], sampled_pred1[i]) for i in range(sample_size)] metric2 = [self.calculate_metric(sampled_refs[i], sampled_pred2[i]) for i in range(sample_size)] @@ -200,7 +223,7 @@ def prepare(self): def process(self): """ - Main processing function that loads data, performs metric bootstrapping and optionally + Main processing function that loads data, performs metric bootstrapping and optionally pairwise comparison, and saves the results to a JSON file. 
""" self.prepare() @@ -240,7 +263,7 @@ def process(self): results["individual_results"][bootstrap_manifest_files[idx].name] = { f"mean_{self.metric_type}": mean_metric, "ci_lower": ci_lower_value, - "ci_upper": ci_upper_value + "ci_upper": ci_upper_value, } # Pairwise comparison between models (only if calculate_pairwise is True) @@ -250,24 +273,29 @@ def process(self): for i in range(num_files): for j in range(i + 1, num_files): if durations: - delta_metric_bootstrap, poi = self.bootstrap_wer_difference(predicted_texts[i], predicted_texts[j], ground_truth, durations[i]) + delta_metric_bootstrap, poi = self.bootstrap_wer_difference( + predicted_texts[i], predicted_texts[j], ground_truth, durations[i] + ) else: - delta_metric_bootstrap, poi = self.bootstrap_wer_difference(predicted_texts[i], predicted_texts[j], ground_truth) + delta_metric_bootstrap, poi = self.bootstrap_wer_difference( + predicted_texts[i], predicted_texts[j], ground_truth + ) mean_delta_metric = np.mean(delta_metric_bootstrap) ci_lower_value = np.percentile(delta_metric_bootstrap, self.ci_lower) ci_upper_value = np.percentile(delta_metric_bootstrap, self.ci_upper) - results["pairwise_comparisons"].append({ - "file_1": bootstrap_manifest_files[i].name, - "file_2": bootstrap_manifest_files[j].name, - f"delta_{self.metric_type}_mean": mean_delta_metric, - "ci_lower": ci_lower_value, - "ci_upper": ci_upper_value, - "poi": poi - }) + results["pairwise_comparisons"].append( + { + "file_1": bootstrap_manifest_files[i].name, + "file_2": bootstrap_manifest_files[j].name, + f"delta_{self.metric_type}_mean": mean_delta_metric, + "ci_lower": ci_lower_value, + "ci_upper": ci_upper_value, + "poi": poi, + } + ) output_path = Path(self.output_file) with output_path.open('w') as out_file: json.dump(results, out_file, indent=4) - diff --git a/sdp/utils/common.py b/sdp/utils/common.py index aa695b0e..e57db973 100644 --- a/sdp/utils/common.py +++ b/sdp/utils/common.py @@ -19,7 +19,7 @@ import urllib import zipfile from pathlib import Path -from typing import Dict, List, Union, Any, Optional +from typing import Any, Dict, List, Optional, Union import wget @@ -35,6 +35,7 @@ def load_manifest(manifest: Union[Path, str], encoding: Optional[str] = None) -> result.append(data) return result + def save_manifest(manifest: List[Dict[str, Any]], manifest_file: Union[Path, str]): with open(manifest_file, 'w') as f: for item in manifest: diff --git a/sdp/utils/import_manager.py b/sdp/utils/import_manager.py index 5458892e..1c7f2bf2 100644 --- a/sdp/utils/import_manager.py +++ b/sdp/utils/import_manager.py @@ -4,26 +4,29 @@ import os from pathlib import Path from typing import Dict, Optional, Set + import yaml from sdp.logging import logger + class ImportManager: """ The ImportManager class is a utility designed to manage dynamic imports for a specific Python package based on a provided YAML configuration. This class simplifies the process of selectively importing only the necessary components, - enabling the creation of a custom __init__.py file with imports for required processors. - By doing so, it ensures that users only need to install the libraries they actually use, + enabling the creation of a custom __init__.py file with imports for required processors. + By doing so, it ensures that users only need to install the libraries they actually use, reducing unnecessary dependencies. - + To eable the ImportManager, set the `use_import_manager` key to `True` in the YAML config file. 
(Or provide it as an argument to main.py) use_import_manager: True """ + def __init__(self, base_package: str = "sdp"): self.base_package = base_package self.package_path = self._find_package_path() - + def _find_package_path(self) -> Path: try: package = importlib.import_module(self.base_package) @@ -40,24 +43,19 @@ def _get_processor_import(self, target: str) -> Optional[str]: module_path, class_name = target.rsplit('.', 1) return f"from {module_path} import {class_name}" except ValueError as e: - # Raised if the target does not contain a '.' + # Raised if the target does not contain a '.' logger.warning(f"Invalid target format for import: '{target}'. Expected '.'. Error: {e}") except AttributeError as e: - # Raised if the target module or class does not exist + # Raised if the target module or class does not exist logger.warning(f"Invalid target type for import: {type(target)}. Error: {e}") except Exception as e: logger.warning(f"Could not process import for {target}: {e}") return None - - - - - def get_required_imports(self, yaml_config: str) -> Set[str]: with open(yaml_config, 'r') as f: config = yaml.safe_load(f) - + required_imports = set() if 'processors' in config: for processor in config['processors']: @@ -66,7 +64,7 @@ def get_required_imports(self, yaml_config: str) -> Set[str]: if import_stmt: required_imports.add(import_stmt) logger.debug(f"Found required processor: {processor['_target_']}") - + return required_imports def sync_with_config(self, yaml_config: str, init_file: Optional[str] = None) -> None: @@ -88,29 +86,31 @@ def sync_with_config(self, yaml_config: str, init_file: Optional[str] = None) -> # Parse YAML config and get required imports required_imports = self.get_required_imports(yaml_config) - + # Mention that this file is auto-generated new_content = [] if "let's import all supported processors" in current_content: # Keep the header comment if it exists - new_content.append("# This was automaticly generated, to disable: set use_import_manager: False in yaml config\n") - + new_content.append( + "# This was automatically generated, to disable: set use_import_manager: False in yaml config\n" + ) + # Add imports for import_stmt in sorted(required_imports): new_content.append(import_stmt) - + # Write the new content init_file.parent.mkdir(parents=True, exist_ok=True) with open(init_file, 'w') as f: f.write('\n'.join(new_content)) - + logger.info(f"Successfully updated {init_file} with required imports") def setup_import_hooks(): """Set up import hooks for automatic import management.""" original_yaml_load = yaml.safe_load - + def yaml_load_hook(stream): result = original_yaml_load(stream) if isinstance(result, dict) and 'processors' in result: @@ -119,20 +119,20 @@ def yaml_load_hook(stream): if frame.f_code.co_name != 'yaml_load_hook': break frame = frame.f_back - + if frame: caller_file = frame.f_code.co_filename if isinstance(stream, str): yaml_path = stream else: yaml_path = os.path.abspath(caller_file) - + manager = ImportManager() try: manager.sync_with_config(yaml_path) except Exception as e: logger.warning(f"Failed to sync imports: {e}") - + return result - - yaml.safe_load = yaml_load_hook \ No newline at end of file + + yaml.safe_load = yaml_load_hook diff --git a/sdp/utils/ipl_utils.py b/sdp/utils/ipl_utils.py index 07d50c5d..c5e9badc 100644 --- a/sdp/utils/ipl_utils.py +++ b/sdp/utils/ipl_utils.py @@ -18,6 +18,7 @@ from omegaconf import OmegaConf + def separate_multiple_transcriptions(inference_config: dict) -> Tuple[List[str], Optional[List[str]]]: """
Separates and returns the manifest and tarred audio file paths from the configuration. @@ -29,7 +30,7 @@ def separate_multiple_transcriptions(inference_config: dict) -> Tuple[List[str], - A list of manifest file paths. - An Optional list of tarred audio file paths, or None if not applicable. """ - + if hasattr(inference_config.predict_ds, "is_tarred") and inference_config.predict_ds.is_tarred: tarred_audio_filepaths = inference_config.predict_ds.tarred_audio_filepaths manifest_filepaths = inference_config.predict_ds.manifest_filepath @@ -120,7 +121,6 @@ def create_transcribed_manifests( # Open and read the original predictions_all.json file with open(transcripted_name, 'w', encoding='utf-8') as f: with open(prediction_name, 'r', encoding='utf-8') as pred_f: - for line in pred_f.readlines(): data_entry = json.loads(line) if 'text' in data_entry: @@ -187,13 +187,12 @@ def write_sampled_shard_transcriptions(manifest_filepaths: List[str]) -> List[Li json.dump(data_entry, f, ensure_ascii=False) f.write("\n") - shard_manifest_filepath = os.path.join( - prediction_filepath, f"transcribed_manifest__OP_0..{max_shard_id}_CL_.json" - ) + shard_manifest_filepath = os.path.join(prediction_filepath, f"transcribed_manifest__OP_0..{max_shard_id}_CL_.json") all_manifest_filepaths.append([shard_manifest_filepath]) return all_manifest_filepaths + def write_sampled_transcriptions(manifest_filepaths: List[str]) -> List[str]: """ Updates transcriptions by merging predicted data with transcribed manifest data. @@ -217,7 +216,7 @@ def write_sampled_transcriptions(manifest_filepaths: List[str]) -> List[str]: for line in f: data_entry = json.loads(line) path = data_entry['audio_filepath'] - + predicted_data[path] = data_entry full_path = os.path.join(prediction_filepath, f"transcribed_manifest.json") all_data_entries = [] @@ -227,7 +226,6 @@ def write_sampled_transcriptions(manifest_filepaths: List[str]) -> List[str]: count += 1 data_entry = json.loads(line) all_data_entries.append(data_entry) - output_filename = os.path.join(prediction_filepath, f"transcribed_manifest.json") with open(output_filename, 'w') as f: @@ -303,7 +301,6 @@ def update_training_sets( print(f"final_cache_manifests {final_cache_manifests}") merged_config.model.train_ds.manifest_filepath += final_cache_manifests - return merged_config diff --git a/sdp/utils/nemo_run_utils.py b/sdp/utils/nemo_run_utils.py index 5cbd8575..dcf87df0 100644 --- a/sdp/utils/nemo_run_utils.py +++ b/sdp/utils/nemo_run_utils.py @@ -12,25 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy +import logging import os from functools import lru_cache + from nemo_run.core.tunnel import LocalTunnel, SSHTunnel from omegaconf import DictConfig, OmegaConf + +from sdp.utils import ipl_utils from sdp.utils.skills_utils import ( - get_mounts_from_config, - check_if_mounted, add_task, + check_if_mounted, + get_mounts_from_config, run_exp, ) -import logging -import copy -from sdp.utils import ipl_utils + + @lru_cache(maxsize=2) def get_tunnel(**ssh_tunnel): return SSHTunnel(**ssh_tunnel) - def add_mount_path(mount_source: str, mount_dest: str, cluster_config): """ Add a mount path to the cluster config. 
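A note on the @lru_cache(maxsize=2) decorator over get_tunnel above: functools.lru_cache builds its cache key from the call's keyword arguments, so repeated calls with an identical ssh_tunnel config reuse one SSHTunnel instead of opening a new connection (keyword order matters for the key, and all values must be hashable). A toy illustration with a hypothetical Tunnel class:

from functools import lru_cache

class Tunnel:
    def __init__(self, **cfg):
        print(f"opening tunnel: {cfg}")  # side effect runs once per distinct config

@lru_cache(maxsize=2)
def get_tunnel(**ssh_tunnel):
    return Tunnel(**ssh_tunnel)

t1 = get_tunnel(host="login01", user="me")
t2 = get_tunnel(host="login01", user="me")  # cache hit: no new tunnel is opened
assert t1 is t2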
@@ -190,6 +193,7 @@ def create_remote_config(config: dict, config_name: str, config_directory: str, raise ValueError(f"Unsupported executor: {cluster_config.get('executor')}") return config_filepath + def create_remote_inference_config(cluster_config, config_directory: str, inference_config, checkpoint_path): """ Utility to create and write remote inference configuration files for a cluster setup. diff --git a/sdp/utils/skills_utils.py b/sdp/utils/skills_utils.py index ac536043..915de56e 100644 --- a/sdp/utils/skills_utils.py +++ b/sdp/utils/skills_utils.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -#This file is maintained in sync with `nemo_skills/pipeline/utils.py` +# +# This file is maintained in sync with `nemo_skills/pipeline/utils.py` # and is intended to be copied as-is to ensure consistency across projects. import logging @@ -31,6 +31,7 @@ import nemo_run as run import yaml from huggingface_hub import get_token + try: from invoke import StreamWatcher except ImportError: @@ -75,7 +76,6 @@ def __post_init__(self): } - def register_external_repo(metadata: RepoMetadata): """Register an external repo to be packaged with the code in the experiment. @@ -906,14 +906,14 @@ def get_executor( env_vars["SLURM_MASTER_NODE"] = "$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n1)" else: # master node will be within the same group - env_vars["SLURM_MASTER_NODE"] = ( - f"$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_{het_group} | head -n1)" - ) + env_vars[ + "SLURM_MASTER_NODE" + ] = f"$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_{het_group} | head -n1)" # in addition defining master nodes for all groups to allow communication for group in range(total_het_groups): - env_vars[f"SLURM_MASTER_NODE_HET_GROUP_{group}"] = ( - f"$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_{group} | head -n1)" - ) + env_vars[ + f"SLURM_MASTER_NODE_HET_GROUP_{group}" + ] = f"$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_{group} | head -n1)" partition = partition or cluster_config.get("partition") if 'timeouts' not in cluster_config: diff --git a/tests/prepare_test_data/prepare_coraa_data.py b/tests/prepare_test_data/prepare_coraa_data.py index 4e472e6a..e23a4ae9 100644 --- a/tests/prepare_test_data/prepare_coraa_data.py +++ b/tests/prepare_test_data/prepare_coraa_data.py @@ -14,21 +14,25 @@ import argparse +import csv +import glob import os +import random import shutil +import subprocess # For external commands (e.g., for rar) import tempfile -from pathlib import Path import zipfile -import subprocess # For external commands (e.g., for rar) -import random -import csv -import glob +from pathlib import Path + def create_zip_archive(source_dir, output_path): with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf: for root, dirs, files in os.walk(source_dir): for file in files: - zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.join(source_dir, '..'))) + zipf.write( + os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.join(source_dir, '..')) + ) + def create_rar_archive(source_dir, output_path): parent_dir = os.path.dirname(source_dir) @@ -36,6 +40,7 @@ def create_rar_archive(source_dir, output_path): command = ['rar', 'a', '-r', '-v20m', output_path, target_folder_name] subprocess.run(command, check=True, cwd=parent_dir) + def 
sample_and_copy_entries(transcript_path, tmpdir_path, num_entries, extracted_data_path, output_metadata_path): with open(transcript_path, "rt", encoding="utf8") as fin: reader = csv.reader(fin) @@ -53,6 +58,7 @@ def sample_and_copy_entries(transcript_path, tmpdir_path, num_entries, extracted shutil.copy(src_path, tgt_path) writer.writerow(row) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Preparing Coraa test data") parser.add_argument("--extracted_data_path", required=True, help="Path to the downloaded and extracted data.") @@ -67,7 +73,9 @@ def sample_and_copy_entries(transcript_path, tmpdir_path, num_entries, extracted for split in ["train", "dev", "test"]: transcript_path = Path(args.extracted_data_path) / f"metadata_{split}_final.csv" output_metadata_path = Path(args.test_data_folder) / f"metadata_{split}_final.csv" - sample_and_copy_entries(transcript_path, tmpdir_path, args.num_entries, args.extracted_data_path, output_metadata_path) + sample_and_copy_entries( + transcript_path, tmpdir_path, args.num_entries, args.extracted_data_path, output_metadata_path + ) archive_path = os.path.join(args.test_data_folder, split) source_dir = os.path.join(tmpdir_path, split) if split in ['dev', 'test']: @@ -78,4 +86,4 @@ def sample_and_copy_entries(transcript_path, tmpdir_path, num_entries, extracted create_rar_archive(source_dir, archive_path) pattern = os.path.join(args.test_data_folder, 'train*.rar') for file_path in glob.glob(pattern): - shutil.move(file_path,train_folder) + shutil.move(file_path, train_folder) diff --git a/tests/prepare_test_data/prepare_fleurs_data.py b/tests/prepare_test_data/prepare_fleurs_data.py index 3059e22a..77ba49a8 100644 --- a/tests/prepare_test_data/prepare_fleurs_data.py +++ b/tests/prepare_test_data/prepare_fleurs_data.py @@ -15,8 +15,8 @@ """Will take the downloaded .tsv file and audios directory and create a version with only X entries.""" import argparse -import os import csv +import os import shutil import tarfile import tempfile @@ -25,7 +25,9 @@ if __name__ == "__main__": parser = argparse.ArgumentParser("Preparing Fleurs test data") parser.add_argument("--extracted_tsv_path", required=True, help="Path to the downloaded .tsv file.") - parser.add_argument("--extracted_audios_dir", required=True, help="Path to the downloaded and extracted audios directory.") + parser.add_argument( + "--extracted_audios_dir", required=True, help="Path to the downloaded and extracted audios directory." 
+ ) parser.add_argument( "--archive_file_stem", required=True, @@ -36,24 +38,25 @@ args = parser.parse_args() os.makedirs(args.test_data_folder, exist_ok=True) - + with tempfile.TemporaryDirectory() as tmpdir: tmpdir_path = Path(tmpdir) - with open(args.extracted_tsv_path, "rt", encoding="utf8") as fin, \ - open(os.path.join(args.test_data_folder, args.archive_file_stem + '.tsv'), "wt", encoding="utf8") as fout: - csv_reader = csv.reader(fin, delimiter='\t') # creating CSV reader object - csv_writer = csv.writer(fout, delimiter='\t') # creating CSV reader object - - for idx, row in enumerate(csv_reader): - if idx == args.num_entries: - break - - src_audio_path = os.path.join(args.extracted_audios_dir, row[1]) - dst_audio_path = os.path.join(tmpdir, row[1]) - shutil.copy(src_audio_path, dst_audio_path) - - csv_writer.writerow(row) - + with open(args.extracted_tsv_path, "rt", encoding="utf8") as fin, open( + os.path.join(args.test_data_folder, args.archive_file_stem + '.tsv'), "wt", encoding="utf8" + ) as fout: + csv_reader = csv.reader(fin, delimiter='\t') # creating CSV reader object + csv_writer = csv.writer(fout, delimiter='\t') # creating CSV writer object + + for idx, row in enumerate(csv_reader): + if idx == args.num_entries: + break + + src_audio_path = os.path.join(args.extracted_audios_dir, row[1]) + dst_audio_path = os.path.join(tmpdir, row[1]) + shutil.copy(src_audio_path, dst_audio_path) + + csv_writer.writerow(row) + with tarfile.open(os.path.join(args.test_data_folder, f"{args.archive_file_stem}.tar.gz"), "w:gz") as tar: # has to be the same as what's before .tar.gz tar.add(tmpdir, arcname=args.archive_file_stem) diff --git a/tests/prepare_test_data/prepare_hifitts2_data.py b/tests/prepare_test_data/prepare_hifitts2_data.py index 9a83fad1..8bd60440 100644 --- a/tests/prepare_test_data/prepare_hifitts2_data.py +++ b/tests/prepare_test_data/prepare_hifitts2_data.py @@ -17,8 +17,8 @@ import argparse import json import os -from pathlib import Path import shutil +from pathlib import Path if __name__ == "__main__": parser = argparse.ArgumentParser("Preparing HiFiTTS-2 test data") @@ -30,13 +30,25 @@ ) parser.add_argument("--test_data_folder", required=True, type=Path, help="Where to place the prepared data") parser.add_argument( - "--manifest_filename", default="manifest_22khz.json", type=str, required=False, help="Name of manifest manifest." + "--manifest_filename", + default="manifest_22khz.json", + type=str, + required=False, + help="Name of audio manifest.", ) parser.add_argument( - "--chapters_filename", default="chapters_22khz.json", type=str, required=False, help="Name of chapter manifest." + "--chapters_filename", + default="chapters_22khz.json", + type=str, + required=False, + help="Name of chapter manifest.", ) parser.add_argument( - "--error_filename", default="errors_22khz.json", type=str, required=False, help="Name of chapter error manifest."
+ "--error_filename", + default="errors_22khz.json", + type=str, + required=False, + help="Name of chapter error manifest.", ) parser.add_argument("--num_entries", default=20, type=int, help="How many entries to keep from each manifest") @@ -69,4 +81,4 @@ input_path = input_audio_dir / audio_filepath output_path = output_audio_dir / audio_filepath output_path.parent.mkdir(exist_ok=True, parents=True) - shutil.copy(src=input_path, dst=output_path) \ No newline at end of file + shutil.copy(src=input_path, dst=output_path) diff --git a/tests/prepare_test_data/prepare_huggingface_data.py b/tests/prepare_test_data/prepare_huggingface_data.py index d74747c6..845cbc6e 100644 --- a/tests/prepare_test_data/prepare_huggingface_data.py +++ b/tests/prepare_test_data/prepare_huggingface_data.py @@ -15,16 +15,18 @@ """Will take the downloaded tar file and create a version with only X entries.""" import argparse +import itertools import os import tempfile -import itertools from pathlib import Path if __name__ == "__main__": - from datasets import load_dataset, Dataset, load_from_disk - + from datasets import Dataset, load_dataset, load_from_disk + parser = argparse.ArgumentParser("Preparing TarteelAI's EveryAyah test data") - parser.add_argument("--dataset_name", required=True, help="Hugging Face dataset name. E.g., 'tarteel-ai/everyayah'") + parser.add_argument( + "--dataset_name", required=True, help="Hugging Face dataset name. E.g., 'tarteel-ai/everyayah'" + ) parser.add_argument( "--archive_file_stem", required=True, @@ -35,12 +37,12 @@ parser.add_argument("--test_data_folder", required=True, help="Where to place the prepared data") args = parser.parse_args() - + os.makedirs(args.test_data_folder, exist_ok=True) with tempfile.TemporaryDirectory() as tmpdir: tmpdir_path = Path(tmpdir) - + dataset = load_dataset(args.dataset_name, split="train", streaming=True) sampled_dataset = list(itertools.islice(dataset, args.num_entries)) sampled_dataset = Dataset.from_list(sampled_dataset) - sampled_dataset.save_to_disk(os.path.join(args.test_data_folder, f"{args.archive_file_stem}.hf")) \ No newline at end of file + sampled_dataset.save_to_disk(os.path.join(args.test_data_folder, f"{args.archive_file_stem}.hf")) diff --git a/tests/prepare_test_data/prepare_masc_data.py b/tests/prepare_test_data/prepare_masc_data.py index 1eee512e..929bdf15 100644 --- a/tests/prepare_test_data/prepare_masc_data.py +++ b/tests/prepare_test_data/prepare_masc_data.py @@ -15,11 +15,11 @@ """Will take the downloaded tar file and create a version with only X entries.""" import argparse +import csv import os import shutil import tarfile import tempfile -import csv from pathlib import Path if __name__ == "__main__": @@ -34,42 +34,40 @@ parser.add_argument("--test_data_folder", required=True, help="Where to place the prepared data") args = parser.parse_args() - + # Define a dictionary to map splits to filenames - filename_map = { - "train": "clean_train.csv", - "dev": "clean_dev_meta.csv", - "test": "clean_test_meta.csv" - } - + filename_map = {"train": "clean_train.csv", "dev": "clean_dev_meta.csv", "test": "clean_test_meta.csv"} + with tempfile.TemporaryDirectory() as tmpdir: tmpdir_path = Path(tmpdir) os.makedirs(tmpdir_path / "audios") os.makedirs(tmpdir_path / "subtitles") os.makedirs(tmpdir_path / "subsets") - + for split in ["train", "dev", "test"]: transcript_path = Path(args.extracted_data_path) / "subsets" / filename_map[split] - with open(transcript_path, "rt", encoding="utf8") as fin, open(tmpdir_path / "subsets" / 
filename_map[split], "wt", encoding="utf8") as fout: - csv_reader = csv.reader(fin) # creating CSV reader object - csv_writer = csv.writer(fout) # creating CSV reader object - - csv_writer.writerow(next(csv_reader)) # writing colomns line + with open(transcript_path, "rt", encoding="utf8") as fin, open( + tmpdir_path / "subsets" / filename_map[split], "wt", encoding="utf8" + ) as fout: + csv_reader = csv.reader(fin) # creating CSV reader object + csv_writer = csv.writer(fout) # creating CSV reader object + + csv_writer.writerow(next(csv_reader)) # writing colomns line for idx, row in enumerate(csv_reader): if idx == args.num_entries: break utt_id = row[0] - + # copying audio file src_audio_path = os.path.join(args.extracted_data_path, "audios", f"{utt_id}.wav") tgt_audio_path = os.path.join(tmpdir_path, "audios", f"{utt_id}.wav") shutil.copy(src_audio_path, tgt_audio_path) - + # copying transcription file src_transcript_path = os.path.join(args.extracted_data_path, "subtitles", f"{utt_id}.ar.vtt") tgt_transcript_path = os.path.join(tmpdir_path, "subtitles", f"{utt_id}.ar.vtt") shutil.copy(src_transcript_path, tgt_transcript_path) - + csv_writer.writerow(row) os.makedirs(args.test_data_folder, exist_ok=True) diff --git a/tests/prepare_test_data/prepare_mediaspeech_data.py b/tests/prepare_test_data/prepare_mediaspeech_data.py index a62b6ba2..0e455205 100644 --- a/tests/prepare_test_data/prepare_mediaspeech_data.py +++ b/tests/prepare_test_data/prepare_mediaspeech_data.py @@ -15,9 +15,9 @@ """Will take the downloaded tar file and create a version with only X entries.""" import argparse +import glob import os import shutil -import glob import tarfile import tempfile from pathlib import Path @@ -34,24 +34,24 @@ parser.add_argument("--test_data_folder", required=True, help="Where to place the prepared data") args = parser.parse_args() - + os.makedirs(args.test_data_folder, exist_ok=True) with tempfile.TemporaryDirectory() as tmpdir: tmpdir_path = Path(tmpdir) - + audio_filepaths = glob.glob(f"{args.extracted_data_path}/*.flac") for idx, src_audio_filepath in enumerate(audio_filepaths): if idx == args.num_entries: break - + sample_id = os.path.basename(src_audio_filepath).split(".")[0] src_text_filepath = os.path.join(args.extracted_data_path, f"{sample_id}.txt") dst_text_filepath = os.path.join(tmpdir, f"{sample_id}.txt") dst_audio_filepath = os.path.join(tmpdir, f"{sample_id}.flac") - + shutil.copy(src_text_filepath, dst_text_filepath) shutil.copy(src_audio_filepath, dst_audio_filepath) - + with tarfile.open(os.path.join(args.test_data_folder, f"{args.archive_file_stem}.tar.gz"), "w:gz") as tar: # has to be the same as what's before .tar.gz - tar.add(tmpdir, arcname=args.archive_file_stem) \ No newline at end of file + tar.add(tmpdir, arcname=args.archive_file_stem) diff --git a/tests/prepare_test_data/prepare_mtedx_data.py b/tests/prepare_test_data/prepare_mtedx_data.py index 02279a66..109240bf 100644 --- a/tests/prepare_test_data/prepare_mtedx_data.py +++ b/tests/prepare_test_data/prepare_mtedx_data.py @@ -36,19 +36,17 @@ data_path = os.path.join(tmpdir_path, "data") os.makedirs(data_path, exist_ok=True) for split in ["train", "valid", "test"]: - vtt_path_dest= os.path.join(data_path, split, "vtt") - flac_path_dest= os.path.join(data_path, split, "wav") - os.makedirs(vtt_path_dest, exist_ok=True) - os.makedirs(flac_path_dest, exist_ok=True) - for idx, vtt_file in enumerate(os.listdir(os.path.join( - args.extracted_data_path, "data", split, "vtt"))): + vtt_path_dest = os.path.join(data_path, 
split, "vtt") + flac_path_dest = os.path.join(data_path, split, "wav") + os.makedirs(vtt_path_dest, exist_ok=True) + os.makedirs(flac_path_dest, exist_ok=True) + for idx, vtt_file in enumerate(os.listdir(os.path.join(args.extracted_data_path, "data", split, "vtt"))): if idx == args.num_entries: break flac_file = vtt_file.split(".")[0] + ".flac" - vtt_file_src = os.path.join(args.extracted_data_path,"data", split, "vtt", vtt_file) + vtt_file_src = os.path.join(args.extracted_data_path, "data", split, "vtt", vtt_file) flac_file_src = os.path.join(args.extracted_data_path, "data", split, "wav", flac_file) shutil.copy(vtt_file_src, vtt_path_dest) shutil.copy(flac_file_src, flac_path_dest) with tarfile.open(os.path.join(args.test_data_folder, f"mtedx_{args.language_id}.tgz"), "w:gz") as tar: tar.add(tmpdir, arcname=f"mtedx_{args.language_id}") - diff --git a/tests/prepare_test_data/prepare_ytc_data.py b/tests/prepare_test_data/prepare_ytc_data.py index d20ecba2..6045ae6f 100644 --- a/tests/prepare_test_data/prepare_ytc_data.py +++ b/tests/prepare_test_data/prepare_ytc_data.py @@ -15,11 +15,11 @@ """Will take the downloaded tar file and create a version with only X entries.""" import argparse +import json import os import shutil import tarfile import tempfile -import json from pathlib import Path if __name__ == "__main__": @@ -43,21 +43,21 @@ for idx, audio_file in enumerate(Path(args.extracted_data_path).glob("audios/*")): if idx == args.num_entries: break - + # Copy audio file to temp directory maintaining relative path rel_path = audio_file.relative_to(Path(args.extracted_data_path)) target_path = tmpdir_path / split / "audio" / rel_path target_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(audio_file, target_path) stem = audio_file.stem - + # Write manifest entry manifest_entry = { "audio_filepath": str(target_path.relative_to(tmpdir_path / split)), - "audio_item_id": stem + "audio_item_id": stem, } fout.write(f"{json.dumps(manifest_entry)}\n") - + os.makedirs(args.test_data_folder, exist_ok=True) with tarfile.open(os.path.join(args.test_data_folder, f"ytc_{args.language}.tar.gz"), "w:gz") as tar: # has to be the same as what's before .tar.gz diff --git a/tests/test_bootstrap_estimate.py b/tests/test_bootstrap_estimate.py index 18ac7bf7..6213e52c 100644 --- a/tests/test_bootstrap_estimate.py +++ b/tests/test_bootstrap_estimate.py @@ -1,57 +1,116 @@ import json import tempfile from pathlib import Path + from sdp.utils import BootstrapProcessor + def _write_manifest(manifest_path: Path, entries): with manifest_path.open("w") as f: for entry in entries: f.write(json.dumps(entry) + "\n") + def test_bootstrap_processor(): manifest1_data = [ - {"audio_filepath": "path1.wav", "duration": 3.744, "text": "Նա նաև լավ էր գրում մանկական ոտանավորներ։", - "pred_text": "Նա նաև լավ էր գրում մանկական ոտանավորներ։", "wer": 0.142857, "tokens": 7, - "ins_rate": 0.0, "del_rate": 0.0, "sub_rate": 0.142857}, - {"audio_filepath": "path2.wav", "duration": 5.76, "text": "Ամենամեծ ջանքերը պահանջեց աղյուսների և կղմինդրների արտադրությունը։", - "pred_text": "Ամենամեծ ջանքերը պահանջեց աղյուսների և կաղնիտների արտադրությունը։", "wer": 0.285714, "tokens": 7, - "ins_rate": 0.0, "del_rate": 0.0, "sub_rate": 0.285714}, - {"audio_filepath": "path3.wav", "duration": 6.984, "text": "Եթե մոտակայքում չկան մեղվափեթակներ, ապա բերքատվությունը կնվազի մոտ երեք անգամ։", - "pred_text": "Եթե մոտակայքում չկան մեղվափեթակներ, ապա բերքատվությունը կնվագի մոտ երեք անգամ։", "wer": 0.1, "tokens": 10, - "ins_rate": 0.0, "del_rate": 0.0, 
"sub_rate": 0.1}, - {"audio_filepath": "path4.wav", "duration": 4.104, "text": "Դպրոցը հիմնականում պահվել է եկեղեցու եկամուտներով։", - "pred_text": "Դպրոցը հիմնականում պահվել է եկեղեցու եկամուտներով։", "wer": 0.0, "tokens": 6, - "ins_rate": 0.0, "del_rate": 0.0, "sub_rate": 0.0} + { + "audio_filepath": "path1.wav", + "duration": 3.744, + "text": "Նա նաև լավ էր գրում մանկական ոտանավորներ։", + "pred_text": "Նա նաև լավ էր գրում մանկական ոտանավորներ։", + "wer": 0.142857, + "tokens": 7, + "ins_rate": 0.0, + "del_rate": 0.0, + "sub_rate": 0.142857, + }, + { + "audio_filepath": "path2.wav", + "duration": 5.76, + "text": "Ամենամեծ ջանքերը պահանջեց աղյուսների և կղմինդրների արտադրությունը։", + "pred_text": "Ամենամեծ ջանքերը պահանջեց աղյուսների և կաղնիտների արտադրությունը։", + "wer": 0.285714, + "tokens": 7, + "ins_rate": 0.0, + "del_rate": 0.0, + "sub_rate": 0.285714, + }, + { + "audio_filepath": "path3.wav", + "duration": 6.984, + "text": "Եթե մոտակայքում չկան մեղվափեթակներ, ապա բերքատվությունը կնվազի մոտ երեք անգամ։", + "pred_text": "Եթե մոտակայքում չկան մեղվափեթակներ, ապա բերքատվությունը կնվագի մոտ երեք անգամ։", + "wer": 0.1, + "tokens": 10, + "ins_rate": 0.0, + "del_rate": 0.0, + "sub_rate": 0.1, + }, + { + "audio_filepath": "path4.wav", + "duration": 4.104, + "text": "Դպրոցը հիմնականում պահվել է եկեղեցու եկամուտներով։", + "pred_text": "Դպրոցը հիմնականում պահվել է եկեղեցու եկամուտներով։", + "wer": 0.0, + "tokens": 6, + "ins_rate": 0.0, + "del_rate": 0.0, + "sub_rate": 0.0, + }, ] - + manifest2_data = [ - {"audio_filepath": "path1.wav", "duration": 3.744, "text": "Նա նաև լավ էր գրում մանկական ոտանավորներ։", - "pred_text": "Նա նաև լավ էր գրում մանկական ոտանավորներ։", "wer": 0.142857, "tokens": 7, - "ins_rate": 0.0, "del_rate": 0.0, "sub_rate": 0.142857}, - {"audio_filepath": "path2.wav", "duration": 5.76, "text": "Ամենամեծ ջանքերը պահանջեց աղյուսների և կղմինդրների արտադրությունը։", - "pred_text": "Ամենամեծ ջանքերը պահանջեց աղյուսների և կղմիտների արտադրությունը։", "wer": 0.285714, "tokens": 7, - "ins_rate": 0.0, "del_rate": 0.0, "sub_rate": 0.285714}, - {"audio_filepath": "path3.wav", "duration": 6.984, "text": "Եթե մոտակայքում չկան մեղվափեթակներ, ապա բերքատվությունը կնվազի մոտ երեք անգամ։", - "pred_text": "Եթե մոտակայքում չկան մեղվափետներ, ապա բերքատվությունը կնվացի մոտ երեք անգամ։", "wer": 0.2, "tokens": 10, - "ins_rate": 0.0, "del_rate": 0.0, "sub_rate": 0.2}, - {"audio_filepath": "path4.wav", "duration": 4.104, "text": "Դպրոցը հիմնականում պահվել է եկեղեցու եկամուտներով։", - "pred_text": "Դպրոցը հիմնականում պահվել է եկեղեցու եկամուտներով։", "wer": 0.0, "tokens": 6, - "ins_rate": 0.0, "del_rate": 0.0, "sub_rate": 0.0} + { + "audio_filepath": "path1.wav", + "duration": 3.744, + "text": "Նա նաև լավ էր գրում մանկական ոտանավորներ։", + "pred_text": "Նա նաև լավ էր գրում մանկական ոտանավորներ։", + "wer": 0.142857, + "tokens": 7, + "ins_rate": 0.0, + "del_rate": 0.0, + "sub_rate": 0.142857, + }, + { + "audio_filepath": "path2.wav", + "duration": 5.76, + "text": "Ամենամեծ ջանքերը պահանջեց աղյուսների և կղմինդրների արտադրությունը։", + "pred_text": "Ամենամեծ ջանքերը պահանջեց աղյուսների և կղմիտների արտադրությունը։", + "wer": 0.285714, + "tokens": 7, + "ins_rate": 0.0, + "del_rate": 0.0, + "sub_rate": 0.285714, + }, + { + "audio_filepath": "path3.wav", + "duration": 6.984, + "text": "Եթե մոտակայքում չկան մեղվափեթակներ, ապա բերքատվությունը կնվազի մոտ երեք անգամ։", + "pred_text": "Եթե մոտակայքում չկան մեղվափետներ, ապա բերքատվությունը կնվացի մոտ երեք անգամ։", + "wer": 0.2, + "tokens": 10, + "ins_rate": 0.0, + 
"del_rate": 0.0, + "sub_rate": 0.2, + }, + { + "audio_filepath": "path4.wav", + "duration": 4.104, + "text": "Դպրոցը հիմնականում պահվել է եկեղեցու եկամուտներով։", + "pred_text": "Դպրոցը հիմնականում պահվել է եկեղեցու եկամուտներով։", + "wer": 0.0, + "tokens": 6, + "ins_rate": 0.0, + "del_rate": 0.0, + "sub_rate": 0.0, + }, ] # Expected output for comparison expected_output = { "individual_results": { - "manifest1.json": { - "mean_wer": 5.358, - "ci_lower": 0.5625, - "ci_upper": 10.992625 - }, - "manifest2.json": { - "mean_wer": 9.0725, - "ci_lower": 5.0, - "ci_upper": 15.234875 - } + "manifest1.json": {"mean_wer": 5.358, "ci_lower": 0.5625, "ci_upper": 10.992625}, + "manifest2.json": {"mean_wer": 9.0725, "ci_lower": 5.0, "ci_upper": 15.234875}, }, "pairwise_comparisons": [ { @@ -60,9 +119,9 @@ def test_bootstrap_processor(): "delta_wer_mean": -1.75, "ci_lower": -5.0, "ci_upper": 0.0, - "poi": 0.0 + "poi": 0.0, } - ] + ], } with tempfile.TemporaryDirectory() as tmpdir: @@ -89,7 +148,7 @@ def test_bootstrap_processor(): ci_lower=2.5, ci_upper=97.5, random_state=42, - output_manifest_file=None # A placeholder to skip BaseProcessor failing + output_manifest_file=None, # A placeholder to skip BaseProcessor failing ) processor.process() @@ -97,5 +156,5 @@ def test_bootstrap_processor(): # Load and compare the processor output with open(output_path, "r") as f: output = json.load(f) - + assert output == expected_output, f"Expected {expected_output}, but got {output}" diff --git a/tests/test_cfg_end_to_end_tests.py b/tests/test_cfg_end_to_end_tests.py index 05fdfcb3..1bb6427f 100644 --- a/tests/test_cfg_end_to_end_tests.py +++ b/tests/test_cfg_end_to_end_tests.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import json +import logging import os import shutil import tarfile -import logging +from dataclasses import dataclass, field from functools import partial from pathlib import Path from typing import Callable, List, Tuple from unittest import mock -from dataclasses import dataclass, field import pytest from omegaconf import OmegaConf @@ -31,9 +31,11 @@ DATASET_CONFIGS_ROOT = Path(__file__).parents[1] / "dataset_configs" + @dataclass class TestCase: """Class for keeping track of test cases.""" + config_path: str data_check_fn: Callable reference_manifest_filename: str = "test_data_reference.json" @@ -41,6 +43,7 @@ class TestCase: fields_to_ignore: List[str] = field(default_factory=list) processors_to_run: str = "" + def data_check_fn_generic(raw_data_dir: str, file_name: str, **kwargs) -> None: if callable(file_name): file_name = file_name(**kwargs) @@ -48,13 +51,19 @@ def data_check_fn_generic(raw_data_dir: str, file_name: str, **kwargs) -> None: if not expected_file.exists(): raise ValueError(f"No such file {str(expected_file)}") + data_check_fn_mls = partial(data_check_fn_generic, file_name=lambda language, **kwargs: f"mls_{language}.tar.gz") -data_check_fn_mcv = partial(data_check_fn_generic, file_name=lambda archive_file_stem, **kwargs: f"{archive_file_stem}.tar.gz") -data_check_fn_mtedx = partial(data_check_fn_generic, file_name=lambda language_id, **kwargs: f"mtedx_{language_id}.tgz") +data_check_fn_mcv = partial( + data_check_fn_generic, file_name=lambda archive_file_stem, **kwargs: f"{archive_file_stem}.tar.gz" +) +data_check_fn_mtedx = partial( + data_check_fn_generic, file_name=lambda language_id, **kwargs: f"mtedx_{language_id}.tgz" +) data_check_fn_coraa = partial(data_check_fn_generic, file_name="train_dividido/train.part1.rar") data_check_fn_librispeech = partial(data_check_fn_generic, file_name="dev-clean.tar.gz") data_check_fn_fleurs = partial(data_check_fn_generic, file_name="dev.tar.gz") + def data_check_fn_voxpopuli(raw_data_dir: str) -> None: """Raises error if do not find expected data. @@ -67,7 +76,8 @@ def data_check_fn_voxpopuli(raw_data_dir: str) -> None: raise ValueError(f"No such file {str(expected_file)}") with tarfile.open(expected_file, 'r:gz') as tar: tar.extractall(path=raw_data_dir) - + + def data_check_fn_slr140(raw_data_dir: str) -> None: """Raises error if do not find expected data. Will also extract the archive as initial processor expects extracted data. @@ -78,7 +88,8 @@ def data_check_fn_slr140(raw_data_dir: str) -> None: if not expected_file.exists(): raise ValueError(f"No such file {str(expected_file)}") - extract_tar_with_strip_components(expected_file, tgt_dir, strip_components=1) + extract_tar_with_strip_components(expected_file, tgt_dir, strip_components=1) + def data_check_fn_uzbekvoice(raw_data_dir: str) -> None: expected_files = [Path(raw_data_dir) / "clips.zip", Path(raw_data_dir) / "uzbekvoice-dataset.zip"] @@ -88,9 +99,10 @@ def data_check_fn_uzbekvoice(raw_data_dir: str) -> None: else: raise ValueError(f"No such file {str(expected_file)} at {str(raw_data_dir)}") + def data_check_fn_unlabeled(raw_data_dir: str) -> None: """Checks for data and sets it up for unlabeled processing. - + Args: raw_data_dir: Directory where data should be language: Language code (e.g. 
'portuguese') @@ -104,24 +116,27 @@ def data_check_fn_unlabeled(raw_data_dir: str) -> None: with tarfile.open(expected_file, 'r:gz') as tar: tar.extractall(path=raw_data_dir) + def data_check_fn_armenian_toloka_pipeline_start(raw_data_dir: str) -> None: """Checks for the Armenian Toloka test data. - + For testing Toloka pipelines, we need a sample docx file to process. """ expected_dir = Path(raw_data_dir) / "pipeline_start" / "arm_docs" if not expected_dir.exists() or not any(expected_dir.glob("*.docx")): raise ValueError(f"No docx files found in {str(expected_dir)}") + def data_check_fn_armenian_toloka_pipeline_get_final_res(raw_data_dir: str) -> None: """Checks for the Armenian Toloka test data. - + For testing Toloka pipelines, we need a sample docx file to process. """ expected_dir = Path(raw_data_dir) / "pipeline_get_final_res" if not expected_dir.exists(): raise ValueError(f"Directory not found: {str(expected_dir)}") + # using Mock so coraal_processor will only try to use the files listed. # To reduce the amount of storage required by the test data, the S3 bucket contains # modified versions of LES_audio_part01_2021.07.tar.gz and @@ -135,146 +150,131 @@ def data_check_fn_armenian_toloka_pipeline_get_final_res(raw_data_dir: str) -> N ] ) + def get_test_cases() -> List[Tuple[str, Callable]]: return [ TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/spanish/mls/config.yaml", - data_check_fn=partial(data_check_fn_mls, language="spanish"), - ), - TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/spanish_pc/mcv12/config.yaml", - data_check_fn=partial(data_check_fn_mcv, archive_file_stem="cv-corpus-12.0-2022-12-07-es") - ), - TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/italian/voxpopuli/config.yaml", - data_check_fn=data_check_fn_voxpopuli - ), + config_path=f"{DATASET_CONFIGS_ROOT}/spanish/mls/config.yaml", + data_check_fn=partial(data_check_fn_mls, language="spanish"), + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/italian/mls/config.yaml", - data_check_fn=partial(data_check_fn_mls, language="italian") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/spanish_pc/mcv12/config.yaml", + data_check_fn=partial(data_check_fn_mcv, archive_file_stem="cv-corpus-12.0-2022-12-07-es"), + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/mls/config.yaml", - data_check_fn=partial(data_check_fn_mls, language="portuguese") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/italian/voxpopuli/config.yaml", data_check_fn=data_check_fn_voxpopuli + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/mcv/config.yaml", - data_check_fn=partial(data_check_fn_mcv, archive_file_stem="cv-corpus-15.0-2023-09-08-pt") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/italian/mls/config.yaml", + data_check_fn=partial(data_check_fn_mls, language="italian"), + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/mtedx/config.yaml", - data_check_fn=partial(data_check_fn_mtedx, language_id="pt") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/mls/config.yaml", + data_check_fn=partial(data_check_fn_mls, language="portuguese"), + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/coraa/config.yaml", - data_check_fn=data_check_fn_coraa - ), + config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/mcv/config.yaml", + data_check_fn=partial(data_check_fn_mcv, archive_file_stem="cv-corpus-15.0-2023-09-08-pt"), + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/english/slr83/config.yaml", - data_check_fn=lambda raw_data_dir: True - ), + 
config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/mtedx/config.yaml", + data_check_fn=partial(data_check_fn_mtedx, language_id="pt"), + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/english/coraal/config.yaml", - data_check_fn=lambda raw_data_dir: True - ), + config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/coraa/config.yaml", data_check_fn=data_check_fn_coraa + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/english/librispeech/config.yaml", - data_check_fn=data_check_fn_librispeech - ), + config_path=f"{DATASET_CONFIGS_ROOT}/english/slr83/config.yaml", data_check_fn=lambda raw_data_dir: True + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/armenian/fleurs/config.yaml", - data_check_fn=data_check_fn_fleurs - ), + config_path=f"{DATASET_CONFIGS_ROOT}/english/coraal/config.yaml", data_check_fn=lambda raw_data_dir: True + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/armenian/text_mcv/config.yaml", - data_check_fn=lambda raw_data_dir: True - ), + config_path=f"{DATASET_CONFIGS_ROOT}/english/librispeech/config.yaml", + data_check_fn=data_check_fn_librispeech, + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/armenian/audio_books/config.yaml", - data_check_fn=lambda raw_data_dir: True, - fields_to_ignore=['text'], - ), + config_path=f"{DATASET_CONFIGS_ROOT}/armenian/fleurs/config.yaml", data_check_fn=data_check_fn_fleurs + ), TestCase( - f"{DATASET_CONFIGS_ROOT}/kazakh/mcv/config.yaml", - partial(data_check_fn_mcv, archive_file_stem="mcv_kk") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/armenian/text_mcv/config.yaml", + data_check_fn=lambda raw_data_dir: True, + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/kazakh/slr140/config.yaml", - data_check_fn=data_check_fn_slr140 - ), + config_path=f"{DATASET_CONFIGS_ROOT}/armenian/audio_books/config.yaml", + data_check_fn=lambda raw_data_dir: True, + fields_to_ignore=['text'], + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/kazakh/slr102/config.yaml", - data_check_fn=partial(data_check_fn_generic, file_name="slr102_kk.tar.gz") - ), + f"{DATASET_CONFIGS_ROOT}/kazakh/mcv/config.yaml", partial(data_check_fn_mcv, archive_file_stem="mcv_kk") + ), + TestCase(config_path=f"{DATASET_CONFIGS_ROOT}/kazakh/slr140/config.yaml", data_check_fn=data_check_fn_slr140), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/kazakh/ksc2/config.yaml", - data_check_fn=partial(data_check_fn_generic, file_name="ksc2_kk.tar.gz") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/kazakh/slr102/config.yaml", + data_check_fn=partial(data_check_fn_generic, file_name="slr102_kk.tar.gz"), + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/uzbek/mcv/config.yaml", - data_check_fn=partial(data_check_fn_mcv, archive_file_stem="mcv_uz") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/kazakh/ksc2/config.yaml", + data_check_fn=partial(data_check_fn_generic, file_name="ksc2_kk.tar.gz"), + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/uzbek/uzbekvoice/config.yaml", - data_check_fn=data_check_fn_uzbekvoice - ), + config_path=f"{DATASET_CONFIGS_ROOT}/uzbek/mcv/config.yaml", + data_check_fn=partial(data_check_fn_mcv, archive_file_stem="mcv_uz"), + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/uzbek/fleurs/config.yaml", - data_check_fn=data_check_fn_fleurs - ), + config_path=f"{DATASET_CONFIGS_ROOT}/uzbek/uzbekvoice/config.yaml", data_check_fn=data_check_fn_uzbekvoice + ), + TestCase(config_path=f"{DATASET_CONFIGS_ROOT}/uzbek/fleurs/config.yaml", data_check_fn=data_check_fn_fleurs), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/arabic/masc/config.yaml", - 
data_check_fn=partial(data_check_fn_generic, file_name="masc.tar.gz") - ), - TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/arabic/masc/config_filter_noisy_train.yaml", - data_check_fn=partial(data_check_fn_generic, file_name="masc.tar.gz"), - reference_manifest_filename="test_data_reference_filter.json" - ), + config_path=f"{DATASET_CONFIGS_ROOT}/arabic/masc/config.yaml", + data_check_fn=partial(data_check_fn_generic, file_name="masc.tar.gz"), + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/arabic/mcv/config.yaml", - data_check_fn=partial(data_check_fn_mcv, archive_file_stem="mcv.ar") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/arabic/masc/config_filter_noisy_train.yaml", + data_check_fn=partial(data_check_fn_generic, file_name="masc.tar.gz"), + reference_manifest_filename="test_data_reference_filter.json", + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/arabic/fleurs/config.yaml", - data_check_fn=data_check_fn_fleurs - ), + config_path=f"{DATASET_CONFIGS_ROOT}/arabic/mcv/config.yaml", + data_check_fn=partial(data_check_fn_mcv, archive_file_stem="mcv.ar"), + ), + TestCase(config_path=f"{DATASET_CONFIGS_ROOT}/arabic/fleurs/config.yaml", data_check_fn=data_check_fn_fleurs), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/arabic/mediaspeech/config.yaml", - data_check_fn=partial(data_check_fn_generic, file_name="AR.tar.gz") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/arabic/mediaspeech/config.yaml", + data_check_fn=partial(data_check_fn_generic, file_name="AR.tar.gz"), + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/arabic/everyayah/config.yaml", - data_check_fn=partial(data_check_fn_generic, file_name="everyayah.hf") + config_path=f"{DATASET_CONFIGS_ROOT}/arabic/everyayah/config.yaml", + data_check_fn=partial(data_check_fn_generic, file_name="everyayah.hf"), ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/armenian/toloka/pipeline_start.yaml", - data_check_fn=data_check_fn_armenian_toloka_pipeline_start, - fields_to_ignore=['source_filepath'], - processors_to_run="2:14", - reference_manifest_filename="pipeline_start/test_data_reference.json" + config_path=f"{DATASET_CONFIGS_ROOT}/armenian/toloka/pipeline_start.yaml", + data_check_fn=data_check_fn_armenian_toloka_pipeline_start, + fields_to_ignore=['source_filepath'], + processors_to_run="2:14", + reference_manifest_filename="pipeline_start/test_data_reference.json", ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/armenian/toloka/pipeline_get_final_res.yaml", - data_check_fn=data_check_fn_armenian_toloka_pipeline_get_final_res, - reference_manifest_filename="pipeline_get_final_res/test_data_reference.json", - fields_to_ignore=['audio_filepath', 'duration'], - processors_to_run="1:6" + config_path=f"{DATASET_CONFIGS_ROOT}/armenian/toloka/pipeline_get_final_res.yaml", + data_check_fn=data_check_fn_armenian_toloka_pipeline_get_final_res, + reference_manifest_filename="pipeline_get_final_res/test_data_reference.json", + fields_to_ignore=['audio_filepath', 'duration'], + processors_to_run="1:6", ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/unlabeled/config.yaml", + config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/unlabeled/config.yaml", data_check_fn=partial(data_check_fn_unlabeled), fields_to_ignore=['duration'], - ), + ), TestCase( config_path=f"{DATASET_CONFIGS_ROOT}/english/hifitts2/config_22khz.yaml", data_check_fn=partial(data_check_fn_generic, file_name="manifest_22khz.json"), - processors_to_run="1:2" + processors_to_run="1:2", ), TestCase( 
config_path=f"{DATASET_CONFIGS_ROOT}/english/hifitts2/config_44khz.yaml", data_check_fn=partial(data_check_fn_generic, file_name="manifest_44khz.json"), - processors_to_run="1:2" + processors_to_run="1:2", ), TestCase( config_path=f"{DATASET_CONFIGS_ROOT}/english/hifitts2/config_bandwidth.yaml", @@ -283,10 +283,9 @@ def get_test_cases() -> List[Tuple[str, Callable]]: ), ] + def get_test_names(): - config_names = [ - Path(t.config_path).parent.relative_to(DATASET_CONFIGS_ROOT).as_posix() for t in get_test_cases() - ] + config_names = [Path(t.config_path).parent.relative_to(DATASET_CONFIGS_ROOT).as_posix() for t in get_test_cases()] return config_names @@ -299,14 +298,15 @@ def check_e2e_test_data() -> bool: """ return bool(os.getenv("TEST_DATA_ROOT") or (os.getenv("AWS_SECRET_KEY") and os.getenv("AWS_ACCESS_KEY"))) + def get_e2e_test_data_path(rel_path_from_root: str) -> str: """Returns path to e2e test data (downloading from AWS if necessary). In case of downloading from AWS, will create "test_data" folder in the current folder and set TEST_DATA_ROOT automatically (used by the sdp code to locate test data). """ - test_data_root = os.getenv("TEST_DATA_ROOT") - if test_data_root: # assume it's present locally + test_data_root = os.getenv("TEST_DATA_ROOT") + if test_data_root: # assume it's present locally return test_data_root import boto3 @@ -317,12 +317,12 @@ def get_e2e_test_data_path(rel_path_from_root: str) -> str: aws_secret_access_key=os.getenv("AWS_SECRET_KEY"), ) bucket = s3_resource.Bucket("sdp-test-data") - + logging.info(f"Downloading test data for {rel_path_from_root} from s3") for obj in bucket.objects.all(): if obj.key.endswith("/"): # do not try to "download_file" on objects which are actually directories continue - if rel_path_from_root in obj.key: + if rel_path_from_root in obj.key: if not os.path.exists(os.path.dirname(obj.key)): os.makedirs(os.path.dirname(obj.key)) bucket.download_file(obj.key, obj.key) @@ -330,17 +330,22 @@ def get_e2e_test_data_path(rel_path_from_root: str) -> str: return os.path.abspath("test_data") + @pytest.fixture(scope="module", params=get_test_cases(), ids=get_test_names()) def setup_data(request): if not check_e2e_test_data(): - pytest.fail("Either TEST_DATA_ROOT needs to be defined or both AWS_SECRET_KEY " - "and AWS_ACCESS_KEY to run e2e config tests") - - config_path, data_check_fn, reference_manifest_filename, fields_to_ignore, processors_to_run = (request.param.config_path, - request.param.data_check_fn, - request.param.reference_manifest_filename, - request.param.fields_to_ignore, - request.param.processors_to_run) + pytest.fail( + "Either TEST_DATA_ROOT needs to be defined or both AWS_SECRET_KEY " + "and AWS_ACCESS_KEY to run e2e config tests" + ) + + config_path, data_check_fn, reference_manifest_filename, fields_to_ignore, processors_to_run = ( + request.param.config_path, + request.param.data_check_fn, + request.param.reference_manifest_filename, + request.param.fields_to_ignore, + request.param.processors_to_run, + ) rel_path_from_root = Path(config_path).parent.relative_to(DATASET_CONFIGS_ROOT) test_data_root = get_e2e_test_data_path(str(rel_path_from_root)) @@ -363,6 +368,7 @@ def test_data_availability(setup_data): if not reference_manifest.exists(): pytest.fail(f"Reference manifest not found: {reference_manifest}") + @pytest.mark.dependency(depends=['test_data_availability']) def test_configs(setup_data, tmp_path): # we expect DATASET_CONFIGS_ROOT and TEST_DATA_ROOT @@ -379,7 +385,6 @@ def test_configs(setup_data, tmp_path): 
     cfg.final_manifest = str(tmp_path / "final_manifest.json")
     cfg.data_split = cfg.get("data_split", "train")
     cfg.processors[0].raw_data_dir = data_dir.as_posix()
-
     if "already_downloaded" in cfg["processors"][0]:
         cfg["processors"][0]["already_downloaded"] = True

@@ -410,9 +415,10 @@ def test_configs(setup_data, tmp_path):
     # reference file (ignoring the file paths and additional fields explicitly specified to ignore)
     fields_to_ignore += ['audio_filepath']
-
-    with open(reference_manifest, "rt", encoding="utf8") as reference_fin, \
-        open(cfg.final_manifest, "rt", encoding="utf8") as generated_fin:
+
+    with open(reference_manifest, "rt", encoding="utf8") as reference_fin, open(
+        cfg.final_manifest, "rt", encoding="utf8"
+    ) as generated_fin:
         reference_lines = sorted(reference_fin.readlines())
         generated_lines = sorted(generated_fin.readlines())
         assert len(reference_lines) == len(generated_lines)
@@ -428,6 +434,7 @@ def test_configs(setup_data, tmp_path):
     if os.getenv("CLEAN_UP_TMP_PATH", "0") != "0":
         shutil.rmtree(tmp_path)

+
 # Additional unit tests to increase coverage
 def test_check_e2e_test_data():
     os.environ.clear()
@@ -439,6 +446,7 @@ def test_check_e2e_test_data():
     os.environ["AWS_ACCESS_KEY"] = "access"
     assert check_e2e_test_data()

+
 @pytest.mark.slow
 def test_get_e2e_test_data_path(tmp_path):
     os.environ["TEST_DATA_ROOT"] = str(tmp_path)
@@ -458,5 +466,6 @@ def test_get_e2e_test_data_path(tmp_path):
     assert result == os.path.abspath("test_data")
     assert mock_bucket.download_file.call_count == 2

+
 if __name__ == "__main__":
     pytest.main([__file__, "-v", "--durations=0"])
diff --git a/tests/test_data_to_data.py b/tests/test_data_to_data.py
index a18e40e8..2d04c1ce 100644
--- a/tests/test_data_to_data.py
+++ b/tests/test_data_to_data.py
@@ -16,11 +16,11 @@

 from sdp.processors.modify_manifest.data_to_data import (
     InsIfASRInsertion,
+    LambdaExpression,
+    ListToEntries,
     SubIfASRSubstitution,
     SubMakeLowercase,
     SubRegex,
-    ListToEntries,
-    LambdaExpression,
 )

 test_params_list = []
@@ -98,15 +98,25 @@
         (
             ListToEntries,
             {"field_with_list": "segments"},
-            {"audio_filepath": "a.wav", "segments": [{"start": 0.0, "end": 1.0, "text": "Hello"}, {"start": 1.1, "end": 2.0, "text": "World"}], "duration": 2.5},
-            [{"audio_filepath": "a.wav", "duration": 2.5, "start": 0.0, "end": 1.0, "text": "Hello"}, {"audio_filepath": "a.wav", "duration": 2.5, "start": 1.1, "end": 2.0, "text": "World"}]
+            {
+                "audio_filepath": "a.wav",
+                "segments": [{"start": 0.0, "end": 1.0, "text": "Hello"}, {"start": 1.1, "end": 2.0, "text": "World"}],
+                "duration": 2.5,
+            },
+            [
+                {"audio_filepath": "a.wav", "duration": 2.5, "start": 0.0, "end": 1.0, "text": "Hello"},
+                {"audio_filepath": "a.wav", "duration": 2.5, "start": 1.1, "end": 2.0, "text": "World"},
+            ],
         ),
         # Test: list of primitive values (strings), requires output_field
         (
             ListToEntries,
             {"field_with_list": "text_chunks", "output_field": "text"},
             {"audio_filepath": "b.wav", "text_chunks": ["Привет", "Мир"], "lang": "ru"},
-            [{"audio_filepath": "b.wav", "lang": "ru", "text": "Привет"}, {"audio_filepath": "b.wav", "lang": "ru", "text": "Мир"}]
+            [
+                {"audio_filepath": "b.wav", "lang": "ru", "text": "Привет"},
+                {"audio_filepath": "b.wav", "lang": "ru", "text": "Мир"},
+            ],
         ),
     ]
 )
@@ -120,7 +130,6 @@
         {"duration": 3.5},
         [{"duration": 3.5, "duration_x2": 7.0}],
     ),
-
     # Ternary expression
     (
         LambdaExpression,
@@ -128,7 +137,6 @@
         {"duration": 12.0},
         [{"duration": 12.0, "label": "long"}],
     ),
-
     # Filtering: entry should be dropped (condition is False)
     (
         LambdaExpression,
@@ -136,7 +144,6 @@
         {"duration": 5.0},
         [],
     ),
-
     # Filtering: entry should be kept (condition is True)
     (
         LambdaExpression,
@@ -144,7 +151,6 @@
         {"duration": 12.0},
         [{"duration": 12.0, "valid": True}],
     ),
-
     # Using built-in function len()
     (
         LambdaExpression,
@@ -152,7 +158,6 @@
         {"text": "hello world"},
         [{"text": "hello world", "num_chars": 11}],
    ),
-
     # Using built-in max() with sub-expressions
     (
         LambdaExpression,
@@ -160,7 +165,6 @@
         {"a": 4, "b": 3},
         [{"a": 4, "b": 3, "score": 6}],
     ),
-
     # Expression using variable prefix (e.g., entry.a + entry.b)
     (
         LambdaExpression,
@@ -172,7 +176,6 @@
         {"a": 1, "b": 2},
         [{"a": 1, "b": 2, "sum": 3}],
     ),
-
     # Logical expression using `and`
     (
         LambdaExpression,
@@ -183,7 +186,6 @@
         {"a": 1, "b": 4},
         [{"a": 1, "b": 4, "check": True}],
     ),
-
     # Boolean expression without filtering (entry is always returned)
     (
         LambdaExpression,
@@ -203,4 +205,4 @@ def test_data_to_data(test_class, class_kwargs, test_input, expected_output):
     processor = test_class(**class_kwargs, output_manifest_file=None)
     result = [entry.data for entry in processor.process_dataset_entry(test_input)]
-    assert result == expected_output
\ No newline at end of file
+    assert result == expected_output
diff --git a/tests/test_import_manager.py b/tests/test_import_manager.py
index bc79d3ea..1f126e53 100644
--- a/tests/test_import_manager.py
+++ b/tests/test_import_manager.py
@@ -1,13 +1,15 @@
-import tempfile
-import os
 import json
+import os
+import tempfile
 from pathlib import Path
-from sdp.utils.import_manager import ImportManager
+from typing import Dict, List, Optional, Union
+
 import pytest
-from typing import Dict, List, Union, Optional
+
+from sdp.utils.import_manager import ImportManager

 # Example YAML content with processors
-#Content is right, additional {} is needed because of the format function
+# Content is right, additional {} is needed because of the format function
 TEST_YAML_CONTENT = """
 use_import_manager: True
 processors_to_run: ":" # Run all processors
@@ -27,21 +29,23 @@
 # Example manifest content
 EXAMPLE_MANIFEST = [
     {"id": 1, "text": "hello", "duration": 10, "audio_filepath": "path1"},
-    {"id": 2, "text": "world", "duration": 12, "audio_filepath": "path2"}
+    {"id": 2, "text": "world", "duration": 12, "audio_filepath": "path2"},
 ]

+
 def _write_manifest(file_path, content: List[Dict]):
     """json lines to a file."""
     with open(file_path, "w") as f:
         for entry in content:
             f.write(json.dumps(entry) + "\n")

+
 def test_import_manager_with_workspace():
     """
     Test ImportManager's functionality with a workspace directory and example manifests.
""" with tempfile.TemporaryDirectory() as tmp_workspace: - #workspace_dir = Path + # workspace_dir = Path workspace_dir = Path(tmp_workspace) # Step 1: example manifest files @@ -78,5 +82,3 @@ def test_import_manager_with_workspace(): assert test1_path.exists(), "test1.json should exist" assert not test2_path.exists(), "test2.json should not be overwritten yet" assert not test3_path.exists(), "test3.json should not be overwritten yet" - - \ No newline at end of file diff --git a/tests/test_lhotse.py b/tests/test_lhotse.py index bdb78348..5611585e 100644 --- a/tests/test_lhotse.py +++ b/tests/test_lhotse.py @@ -2,9 +2,8 @@ import pytest import soundfile - -from lhotse.testing.dummies import DummyManifest from lhotse import CutSet +from lhotse.testing.dummies import DummyManifest from sdp.processors.datasets.lhotse import LhotseImport @@ -35,9 +34,7 @@ def drop_custom(c): def test_lhotse_import(tmp_path, cuts_path): out_path = tmp_path / "nemo_manifest.json" - processor = LhotseImport( - input_manifest_file=cuts_path, output_manifest_file=out_path - ) + processor = LhotseImport(input_manifest_file=cuts_path, output_manifest_file=out_path) processor.process() EXPECTED_KEYS = { diff --git a/tests/test_manifest_chunking.py b/tests/test_manifest_chunking.py index ae1aa394..4bdbf64b 100644 --- a/tests/test_manifest_chunking.py +++ b/tests/test_manifest_chunking.py @@ -22,96 +22,91 @@ import pytest -from sdp.processors import DropNonAlphabet -from sdp.processors import SubMakeLowercase +from sdp.processors import DropNonAlphabet, SubMakeLowercase -def test_submakelowercase_with_chunking(tmp_path): - input_lines = [ - {"text": "ABC"}, - {"text": "DEF"}, - {"text": "GHI"}, - {"text": "JKL"}, - {"text": "MNO"}, - {"text": "PQR"}, - {"text": "STU"}, - {"text": "VWX"}, - {"text": "YZ"}, - ] - - expected_output_lines = [ - {"text": "abc"}, - {"text": "def"}, - {"text": "ghi"}, - {"text": "jkl"}, - {"text": "mno"}, - {"text": "pqr"}, - {"text": "stu"}, - {"text": "vwx"}, - {"text": "yz"}, - ] - - - # save input lines to manifest: - input_manifest_file = tmp_path / "input_manifest.json" - with open(input_manifest_file, "w") as f: - for line in input_lines: - f.write(json.dumps(line) + "\n") - - # run make_lowercase processor: - output_manifest_file = tmp_path / "output_manifest_make_lowercase.json" - processor = SubMakeLowercase( - input_manifest_file=input_manifest_file, - output_manifest_file=output_manifest_file, - in_memory_chunksize=2 - ) - - processor.process() - - # check that output manifest matches expected lines: - with open(output_manifest_file, "r") as f: - output_lines = [json.loads(line) for line in f] - - assert output_lines == expected_output_lines +def test_submakelowercase_with_chunking(tmp_path): + input_lines = [ + {"text": "ABC"}, + {"text": "DEF"}, + {"text": "GHI"}, + {"text": "JKL"}, + {"text": "MNO"}, + {"text": "PQR"}, + {"text": "STU"}, + {"text": "VWX"}, + {"text": "YZ"}, + ] + + expected_output_lines = [ + {"text": "abc"}, + {"text": "def"}, + {"text": "ghi"}, + {"text": "jkl"}, + {"text": "mno"}, + {"text": "pqr"}, + {"text": "stu"}, + {"text": "vwx"}, + {"text": "yz"}, + ] + + # save input lines to manifest: + input_manifest_file = tmp_path / "input_manifest.json" + with open(input_manifest_file, "w") as f: + for line in input_lines: + f.write(json.dumps(line) + "\n") + + # run make_lowercase processor: + output_manifest_file = tmp_path / "output_manifest_make_lowercase.json" + processor = SubMakeLowercase( + input_manifest_file=input_manifest_file, 
output_manifest_file=output_manifest_file, in_memory_chunksize=2 + ) + + processor.process() + + # check that output manifest matches expected lines: + with open(output_manifest_file, "r") as f: + output_lines = [json.loads(line) for line in f] + + assert output_lines == expected_output_lines def test_dropnonalphabet_with_chunking(tmp_path): - - input_lines = [ - {"text": "ABC"}, - {"text": "DEF"}, - {"text": "GHI"}, - {"text": "JKL"}, - {"text": "MNO"}, - {"text": "PQR"}, - {"text": "STU"}, - {"text": "VWX"}, - {"text": "YZ"}, - ] - - expected_output_lines = [ - {"text": "ABC"}, - ] - - # save input lines to manifest: - input_manifest_file = tmp_path / "input_manifest.json" - with open(input_manifest_file, "w") as f: - for line in input_lines: - f.write(json.dumps(line) + "\n") - - # run make_lowercase processor: - output_manifest_file = tmp_path / "output_manifest_make_lowercase.json" - processor = DropNonAlphabet( - input_manifest_file=input_manifest_file, - output_manifest_file=output_manifest_file, - in_memory_chunksize=2, - alphabet="ABC" - ) - - processor.process() - - # check that output manifest matches expected lines: - with open(output_manifest_file, "r") as f: - output_lines = [json.loads(line) for line in f] - - assert output_lines == expected_output_lines + input_lines = [ + {"text": "ABC"}, + {"text": "DEF"}, + {"text": "GHI"}, + {"text": "JKL"}, + {"text": "MNO"}, + {"text": "PQR"}, + {"text": "STU"}, + {"text": "VWX"}, + {"text": "YZ"}, + ] + + expected_output_lines = [ + {"text": "ABC"}, + ] + + # save input lines to manifest: + input_manifest_file = tmp_path / "input_manifest.json" + with open(input_manifest_file, "w") as f: + for line in input_lines: + f.write(json.dumps(line) + "\n") + + # run make_lowercase processor: + output_manifest_file = tmp_path / "output_manifest_make_lowercase.json" + processor = DropNonAlphabet( + input_manifest_file=input_manifest_file, + output_manifest_file=output_manifest_file, + in_memory_chunksize=2, + alphabet="ABC", + ) + + processor.process() + + # check that output manifest matches expected lines: + with open(output_manifest_file, "r") as f: + output_lines = [json.loads(line) for line in f] + + assert output_lines == expected_output_lines diff --git a/tests/test_tts_sdp_end_to_end.py b/tests/test_tts_sdp_end_to_end.py index c3517e95..c593bf77 100644 --- a/tests/test_tts_sdp_end_to_end.py +++ b/tests/test_tts_sdp_end_to_end.py @@ -1,37 +1,38 @@ -import pytest -import boto3 import json import os import tarfile from pathlib import Path + +import boto3 +import pytest from omegaconf import OmegaConf + from sdp.run_processors import run_processors from sdp.utils.common import load_manifest DATASET_CONFIGS_ROOT = Path(__file__).parents[1] / "dataset_configs" + @pytest.fixture def get_tts_ytc_data(tmpdir: str): # Download the data from S3 s3 = boto3.client( - 's3', - aws_access_key_id=os.getenv("AWS_ACCESS_KEY"), - aws_secret_access_key=os.getenv("AWS_SECRET_KEY") + 's3', aws_access_key_id=os.getenv("AWS_ACCESS_KEY"), aws_secret_access_key=os.getenv("AWS_SECRET_KEY") ) s3.download_file( - "sdp-test-data", - "test_data/tts/ytc/test_data_reference.json", - tmpdir/"test_data_reference.json", + "sdp-test-data", + "test_data/tts/ytc/test_data_reference.json", + tmpdir / "test_data_reference.json", ) s3.download_file( - "sdp-test-data", - "test_data/tts/ytc/ytc.en.tar.gz", - tmpdir/"ytc.en.tar.gz", + "sdp-test-data", + "test_data/tts/ytc/ytc.en.tar.gz", + tmpdir / "ytc.en.tar.gz", ) # Extract the tar.gz file - with 
tarfile.open(tmpdir/"ytc.en.tar.gz", "r:gz") as tar: + with tarfile.open(tmpdir / "ytc.en.tar.gz", "r:gz") as tar: tar.extractall(tmpdir) audio_files = Path(tmpdir).glob("audios/*") @@ -45,6 +46,7 @@ def get_tts_ytc_data(tmpdir: str): return tmpdir + def test_tts_sdp_end_to_end(get_tts_ytc_data): data_dir = get_tts_ytc_data assert os.path.exists(data_dir) @@ -71,13 +73,13 @@ def test_tts_sdp_end_to_end(get_tts_ytc_data): output_data = load_manifest(cfg.final_manifest, encoding="utf8") for item in output_data: output_file_data[item["audio_item_id"]] = item - + reference_file_data = {} reference_data = load_manifest(reference_manifest_file, encoding="utf8") for item in reference_data: reference_file_data[item["audio_item_id"]] = item - + assert len(output_file_data) == len(reference_file_data) assert len(output_file_data) == 2 for audio_item_id in output_file_data: - assert output_file_data[audio_item_id]["segments"] == reference_file_data[audio_item_id]["segments"] \ No newline at end of file + assert output_file_data[audio_item_id]["segments"] == reference_file_data[audio_item_id]["segments"]
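
---

Reviewer note on the shared test harness: throughout the reformatted tests/test_data_to_data.py, each entry of test_params_list is a (processor_class, constructor_kwargs, input_entry, expected_output_entries) tuple, and the test body instantiates the processor with output_manifest_file=None so nothing is written to disk. A minimal standalone sketch of that pattern follows; the harness lines mirror test_data_to_data() from the diff, but the SubMakeLowercase example tuple is illustrative (not one of the cases in this patch) and assumes the processor's default "text" field.

# Sketch of the parameterized pattern from tests/test_data_to_data.py.
# The example case below is hypothetical; the harness lines match the diff.
from sdp.processors.modify_manifest.data_to_data import SubMakeLowercase

case = (
    SubMakeLowercase,           # processor class under test
    {},                         # constructor kwargs (default text field assumed)
    {"text": "HELLO WORLD"},    # input manifest entry
    [{"text": "hello world"}],  # expected output entries
)

test_class, class_kwargs, test_input, expected_output = case
# output_manifest_file=None keeps the processor in-memory, as in the diff.
processor = test_class(**class_kwargs, output_manifest_file=None)
result = [entry.data for entry in processor.process_dataset_entry(test_input)]
assert result == expected_output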