diff --git a/dataset_configs/portuguese/unlabeled/config.yaml b/dataset_configs/portuguese/unlabeled/config.yaml new file mode 100644 index 00000000..b10e64ac --- /dev/null +++ b/dataset_configs/portuguese/unlabeled/config.yaml @@ -0,0 +1,108 @@ +documentation: | + Unlabeled Data Processing Pipeline + ################################## + + This pipeline processes unlabeled data for iterative pseudo-labeling training. + + The pipeline performs the following steps: + 1. Creates an initial manifest by searching for all WAV files in the `raw_data_dir` folder. + 2. Counts the duration of each WAV file. + 3. Identifies the language using the `langid_ambernet` NeMo model. + 4. Filters out audios that are tagged with a different language. + 5. Filters out audios that are too long to be processed. + 6. Applies the VAD algorithm from the NeMo repository. + 7. Forms segments by joining adjacent segments up to a duration threshold. + 8. Splits long audios into shorter segments. + 9. Removes empty files and extra fields from the manifest. + + **Required inputs**: + - `workspace_dir`: Directory for intermediate files, containing the following subfolders: + - `${workspace_dir}/wavs/` - Folder with source long files. + - `${workspace_dir}/sdp/` - Folder to store manifests. + - `${workspace_dir}/sdp/vad/` - Folder to store temporary files from the VAD algorithm. + - `${workspace_dir}/splited_wavs/` - Folder to store split short files. + + - `language_short`: Two-letter language code. + - `nemo_path`: Path to NeMo installation. + - `final_manifest`: Path to the final output manifest. + +processors_to_run: "0:" +workspace_dir: ??? +manifest_dir: ${workspace_dir}/sdp +language_short: pt +nemo_path: ??? +final_manifest: ${manifest_dir}/final_manifest.json + +processors: + - _target_: sdp.processors.CreateInitialManifestByExt + raw_data_dir: ${workspace_dir}/wavs + extension: wav + output_file_key: audio_filepath + output_manifest_file: ${manifest_dir}/manifest0.json + + - _target_: sdp.processors.GetAudioDuration + audio_filepath_key: audio_filepath + duration_key: duration + output_manifest_file: ${manifest_dir}/manifest1.json + + - _target_: sdp.processors.AudioLid + output_manifest_file: ${manifest_dir}/manifest2.json + input_audio_key: audio_filepath + output_lang_key: audio_lang + should_run: False + device: cuda + pretrained_model: "langid_ambernet" + segment_duration: 20 + num_segments: 3 + + - _target_: sdp.processors.PreserveByValue + output_manifest_file: ${manifest_dir}/manifest3.json + input_value_key: audio_lang + should_run: False + target_value: ${language_short} + + - _target_: sdp.processors.PreserveByValue + output_manifest_file: ${manifest_dir}/manifest4.json + input_value_key: duration + operator: le + target_value: 20000.0 + + - _target_: sdp.processors.Subprocess + cmd: 'rm -rf ${manifest_dir}/vad/*' + + - _target_: sdp.processors.Subprocess + input_manifest_file: ${manifest_dir}/manifest4.json + output_manifest_file: ${manifest_dir}/vad + input_manifest_arg: "manifest_filepath" + output_manifest_arg: "output_dir" + cmd: 'python sdp/processors/nemo/speech_to_text_with_vad.py audio_type=wav vad_model=vad_multilingual_frame_marblenet vad_config=sdp/processors/nemo/frame_vad_infer_postprocess.yaml' + + - _target_: sdp.processors.RenameFields + input_manifest_file: ${manifest_dir}/vad/temp_manifest_vad_rttm-onset0.3-offset0.3-pad_onset0.2-pad_offset0.2-min_duration_on0.2-min_duration_off0.2-filter_speech_firstTrue.json + output_manifest_file: ${manifest_dir}/manifest7.json + rename_fields: 
{"audio_filepath":"source_filepath"} + + - _target_: sdp.processors.nemo.rttm.GetRttmSegments + output_manifest_file: ${manifest_dir}/manifest8.json + rttm_key: rttm_file + output_file_key: audio_segments + duration_key: duration + duration_threshold: 20.0 + + - _target_: sdp.processors.nemo.rttm.SplitAudioFile + output_manifest_file: ${manifest_dir}/manifest9.json + splited_audio_dir: ${workspace_dir}/splited_wavs/ + segments_key: audio_segments + duration_key: duration + input_file_key: source_filepath + output_file_key: audio_filepath + + - _target_: sdp.processors.PreserveByValue + output_manifest_file: ${manifest_dir}/manifest10.json + input_value_key: duration + operator: gt + target_value: 0.0 + + - _target_: sdp.processors.KeepOnlySpecifiedFields + output_manifest_file: ${final_manifest} + fields_to_keep: ["audio_filepath", "duration"] \ No newline at end of file diff --git a/docs/src/sdp/existing_configs.rst b/docs/src/sdp/existing_configs.rst index 37922108..d151a848 100644 --- a/docs/src/sdp/existing_configs.rst +++ b/docs/src/sdp/existing_configs.rst @@ -408,8 +408,19 @@ HiFiTTS-2 config-docs/english/hifitts2/config_44khz config-docs/english/hifitts2/config_bandwidth + +Unlabeled Portuguese Data +~~~~~~~~~~~~~~~~~~~~~~~~~ + +`config `__ | +:doc:`documentation ` + +.. toctree:: + :hidden: + + config-docs/portuguese/unlabeled/config + NemoRunIPL -~~~~~~~~~~ **Supported configs**. @@ -419,13 +430,13 @@ NemoRunIPL * **NeMoRun**: `config `__ | :doc:`documentation ` - + .. toctree:: :hidden: config-docs/ipl/config config-docs/ipl/nemo_run_config - + Earnings21/22 ~~~~~~~~~~~~~ @@ -438,4 +449,4 @@ Earnings21/22 .. toctree:: :hidden: - config-docs/english/earnings/config \ No newline at end of file + config-docs/english/earnings/config diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py index 9e363592..ab0d05ed 100644 --- a/sdp/processors/__init__.py +++ b/sdp/processors/__init__.py @@ -32,9 +32,8 @@ CreateInitialManifestFleurs, ) from sdp.processors.datasets.hifitts2.download_dataset import DownloadHiFiTTS2 -from sdp.processors.datasets.hifitts2.remove_failed_chapters import RemovedFailedChapters -from sdp.processors.datasets.uzbekvoice.create_initial_manifest import ( - CreateInitialManifestUzbekvoice, +from sdp.processors.datasets.hifitts2.remove_failed_chapters import ( + RemovedFailedChapters, ) from sdp.processors.datasets.ksc2.create_initial_manifest import ( CreateInitialManifestKSC2, @@ -44,13 +43,15 @@ CreateInitialManifestLibrispeech, ) from sdp.processors.datasets.masc import ( - CreateInitialManifestMASC, AggregateSegments, + CreateInitialManifestMASC, + GetCaptionFileSegments, RegExpVttEntries, - GetCaptionFileSegments ) -from sdp.processors.datasets.mediaspeech.create_initial_manifest import CreateInitialManifestMediaSpeech from sdp.processors.datasets.mcv.create_initial_manifest import CreateInitialManifestMCV +from sdp.processors.datasets.mediaspeech.create_initial_manifest import ( + CreateInitialManifestMediaSpeech, +) from sdp.processors.datasets.mls.create_initial_manifest import CreateInitialManifestMLS from sdp.processors.datasets.mls.restore_pc import RestorePCForMLS from sdp.processors.datasets.mtedx.create_initial_manifest import ( @@ -67,18 +68,20 @@ CreateInitialManifestSLR140, CustomDataSplitSLR140, ) +from sdp.processors.datasets.uzbekvoice.create_initial_manifest import ( + CreateInitialManifestUzbekvoice, +) from sdp.processors.datasets.voxpopuli.create_initial_manifest import ( CreateInitialManifestVoxpopuli, ) from 
sdp.processors.datasets.voxpopuli.normalize_from_non_pc_text import ( NormalizeFromNonPCTextVoxpopuli, ) -from sdp.processors.datasets.ytc.create_initial_manifest import ( - CreateInitialManifestYTC, +from sdp.processors.datasets.ytc.create_initial_manifest import CreateInitialManifestYTC +from sdp.processors.huggingface.create_initial_manifest import ( + CreateInitialManifestHuggingFace, ) from sdp.processors.huggingface.speech_recognition import ASRTransformers -from sdp.processors.huggingface.create_initial_manifest import CreateInitialManifestHuggingFace - from sdp.processors.modify_manifest.common import ( AddConstantFields, ApplyInnerJoin, @@ -89,7 +92,9 @@ RenameFields, SortManifest, SplitOnFixedDuration, + Subprocess, DropSpecifiedFields, + ) from sdp.processors.modify_manifest.create_manifest import ( CreateCombinedManifests, @@ -104,8 +109,8 @@ GetWER, InsIfASRInsertion, InverseNormalizeText, - NormalizeText, MakeSentence, + NormalizeText, ReadDocxLines, ReadTxtLines, SplitLineBySentence, @@ -130,8 +135,8 @@ DropLowWordMatchRate, DropNonAlphabet, DropOnAttribute, - PreserveByValue, DropRepeatedFields, + PreserveByValue, ) from sdp.processors.modify_manifest.make_letters_uppercase_after_period import ( MakeLettersUppercaseAfterPeriod, @@ -148,6 +153,7 @@ ) from sdp.processors.nemo.asr_inference import ASRInference from sdp.processors.nemo.estimate_bandwidth import EstimateBandwidth +from sdp.processors.nemo.lid_inference import AudioLid from sdp.processors.nemo.pc_inference import PCInference from sdp.processors.toloka.accept_if import AcceptIfWERLess from sdp.processors.toloka.create_pool import CreateTolokaPool diff --git a/sdp/processors/modify_manifest/common.py b/sdp/processors/modify_manifest/common.py index ea2fdf67..c94c72bb 100644 --- a/sdp/processors/modify_manifest/common.py +++ b/sdp/processors/modify_manifest/common.py @@ -14,12 +14,14 @@ import json import os +import subprocess from pathlib import Path -from typing import Dict, List, Union, Optional +from typing import Dict, List, Optional, Union import pandas as pd from tqdm import tqdm +from sdp.logging import logger from sdp.processors.base_processor import ( BaseParallelProcessor, BaseProcessor, @@ -28,6 +30,71 @@ ) from sdp.utils.common import load_manifest + +class Subprocess(BaseProcessor): + """ + Processor for handling subprocess execution with additional features for managing input and output manifests. + + Args: + cmd (str): The command to be executed as a subprocess. + input_manifest_arg (str, optional): The argument specifying the input manifest. Defaults to an empty string. + output_manifest_arg (str, optional): The argument specifying the output manifest. Defaults to an empty string. + arg_separator (str, optional): The separator used between argument and value. Defaults to "=". + **kwargs: Additional keyword arguments to be passed to the base class. 
+
+    Example:
+
+        _target_: sdp.processors.Subprocess
+        output_manifest_file: /workspace/manifest.json
+        input_manifest_arg: "--manifest"
+        output_manifest_arg: "--output_filename"
+        arg_separator: "="
+        cmd: "python /workspace/NeMo-text-processing/nemo_text_processing/text_normalization/normalize_with_audio.py \
+        --language=en --n_jobs=-1 --batch_size=600 --manifest_text_field=text --cache_dir=${workspace_dir}/cache --overwrite_cache \
+        --whitelist=/workspace/NeMo-text-processing/nemo_text_processing/text_normalization/en/data/whitelist/asr_with_pc.tsv"
+
+    """
+
+    def __init__(
+        self,
+        cmd: str,
+        input_manifest_arg: str = "",
+        output_manifest_arg: str = "",
+        arg_separator: str = "=",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.input_manifest_arg = input_manifest_arg
+        self.output_manifest_arg = output_manifest_arg
+        self.arg_separator = arg_separator
+        self.cmd = cmd
+
+    def process(self):
+        os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True)
+        if self.cmd.find(self.input_manifest_file) != -1 or self.cmd.find(self.output_manifest_file) != -1:
+            logger.error(
+                "input_manifest_file "
+                + self.input_manifest_file
+                + " and output_manifest_file "
+                + self.output_manifest_file
+                + " should be excluded from cmd line!"
+            )
+            raise ValueError
+        process_args = [x for x in self.cmd.split(" ") if x]
+        if self.arg_separator == " ":
+            if self.input_manifest_arg:
+                process_args.extend([self.input_manifest_arg, self.input_manifest_file])
+            if self.output_manifest_arg:
+                process_args.extend([self.output_manifest_arg, self.output_manifest_file])
+        else:
+            if self.input_manifest_arg:
+                process_args.extend([self.input_manifest_arg + self.arg_separator + self.input_manifest_file])
+            if self.output_manifest_arg:
+                process_args.extend([self.output_manifest_arg + self.arg_separator + self.output_manifest_file])
+        subprocess.run(" ".join(process_args), shell=True)
+
+
+
 class CombineSources(BaseParallelProcessor):
     """Can be used to create a single field from two alternative sources.
@@ -104,24 +171,24 @@ class AddConstantFields(BaseParallelProcessor):
     This processor adds constant fields to all manifest entries using Dask BaseParallelProcessor.
     It is useful when you want to attach fixed information (e.g., a language label or metadata)
     to each entry for downstream tasks such as language identification model training.
-
+
     Args:
         fields (dict): A dictionary containing key-value pairs of fields
             to add to each manifest entry. For example::
-
+
             {
                 "label": "en",
                 "metadata": "mcv-11.0-2022-09-21"
            }
-
+
     Returns:
         dict: The same data as in the input manifest with the added constant fields
         as specified in the ``fields`` dictionary.
-
+
     Example:
-
+
         .. code-block:: yaml
-
+
             - _target_: sdp.processors.modify_manifest.common.AddConstantFields
               input_manifest_file: ${workspace_dir}/input_manifest.json
               output_manifest_file: ${workspace_dir}/output_manifest.json
@@ -139,7 +206,6 @@ def process_dataset_entry(self, data_entry: Dict):
         return [DataEntry(data=data_entry)]
 
 
-
 class DuplicateFields(BaseParallelProcessor):
     """This processor duplicates fields in all manifest entries.
@@ -154,8 +220,8 @@ class DuplicateFields(BaseParallelProcessor):
     Returns:
         The same data as in the input manifest with duplicated fields
-        as specified in the ``duplicate_fields`` input dictionary.
-
+        as specified in the ``duplicate_fields`` input dictionary.
+
     Example:
 
        .. 
code-block:: yaml @@ -165,6 +231,7 @@ class DuplicateFields(BaseParallelProcessor): duplicate_fields: {"text":"answer"} """ + def __init__( self, duplicate_fields: Dict, diff --git a/sdp/processors/modify_manifest/create_manifest.py b/sdp/processors/modify_manifest/create_manifest.py index 1e416571..1248f3b0 100644 --- a/sdp/processors/modify_manifest/create_manifest.py +++ b/sdp/processors/modify_manifest/create_manifest.py @@ -17,10 +17,7 @@ import pandas -from sdp.processors.base_processor import ( - BaseParallelProcessor, - DataEntry, -) +from sdp.processors.base_processor import BaseParallelProcessor, DataEntry class CreateInitialManifestByExt(BaseParallelProcessor): @@ -61,17 +58,18 @@ def process_dataset_entry(self, data_entry): class CreateCombinedManifests(BaseParallelProcessor): """Reads JSON lines from specified files and creates a combined manifest. - This processor iterates over files listed in `manifest_list`, reads each file line by line, + This processor iterates over files listed in `manifest_list`, reads each file line by line, and yields the parsed JSON data from each line. Args: - manifest_list (list(str)): A list of file paths or directories to process. The processor will + manifest_list (list(str)): A list of file paths or directories to process. The processor will recursively read files within the directories and expect each file to contain JSON data. **kwargs: Additional keyword arguments passed to the base class `BaseParallelProcessor`. Returns: A generator that yields parsed JSON data from each line in the files listed in `manifest_list`. """ + def __init__( self, manifest_list: list[str], @@ -88,6 +86,3 @@ def read_manifest(self): def process_dataset_entry(self, data_entry): return [DataEntry(data=data_entry)] - - - diff --git a/sdp/processors/modify_manifest/data_to_dropbool.py b/sdp/processors/modify_manifest/data_to_dropbool.py index ff675e0a..eeeebd1e 100644 --- a/sdp/processors/modify_manifest/data_to_dropbool.py +++ b/sdp/processors/modify_manifest/data_to_dropbool.py @@ -14,9 +14,8 @@ import collections import json +import os import re -import os -import json from operator import eq, ge, gt, le, lt, ne from typing import List, Union @@ -76,7 +75,7 @@ def __init__( 'Operator must be one from the list: "lt" (less than), "le" (less than or equal to), "eq" (equal to), "ne" (not equal to), "ge" (greater than or equal to), "gt" (greater than)' ) - def process_dataset_entry(self, data_entry): + def process_dataset_entry(self, data_entry): input_value = data_entry[self.input_value_key] target = self.target_value if self.operator(input_value, target): @@ -808,9 +807,9 @@ class DropRepeatedFields(BaseParallelProcessor): """Drops utterances from the current manifest if their text fields are present in other manifests. This class processes multiple manifest files and removes entries from the current manifest if the text field - matches any entry in the other manifests. It allows for optional punctuation removal from the text fields + matches any entry in the other manifests. It allows for optional punctuation removal from the text fields before performing the check. - + .. note:: It is better to process Test/Dev/Train and then Other.tsv @@ -819,19 +818,21 @@ class DropRepeatedFields(BaseParallelProcessor): current_manifest_file (str): Path to the current manifest file to be processed. punctuations (str): (Optional): String of punctuation characters to be removed from the text fields before checking for duplicates. Defaults to None. 
text_key (str): The key in the manifest entries that contains the text field. Defaults to "text". - + Returns: The same data as in the input manifest with some entries dropped. """ - def __init__(self, - manifests_paths: List[str], - current_manifest_file: str, - punctuations: str = None, - text_key: str = "text", - **kwargs - ): - super().__init__( **kwargs) + + def __init__( + self, + manifests_paths: List[str], + current_manifest_file: str, + punctuations: str = None, + text_key: str = "text", + **kwargs, + ): + super().__init__(**kwargs) self.manifests_paths = manifests_paths self.current_manifest_file = current_manifest_file self.text_key = text_key @@ -851,10 +852,10 @@ def load_data(self): if self.punctuations is not None and len(self.punctuations) > 0: line_text = self.remove_punctuation(line_text) self.text_set.add(line_text) - + def remove_punctuation(self, text): return re.sub(fr'[{self.punctuations}]', '', text) - + def process_dataset_entry(self, data_entry) -> List: text_for_check = data_entry[self.text_key] if self.punctuations is not None and len(self.punctuations) > 0: @@ -862,7 +863,7 @@ def process_dataset_entry(self, data_entry) -> List: if text_for_check in self.text_set: return [DataEntry(data=None, metrics=1)] return [DataEntry(data=data_entry, metrics=0)] - + def finalize(self, metrics: List): total_counter = 0 for counter in metrics: diff --git a/sdp/processors/nemo/frame_vad_infer_postprocess.yaml b/sdp/processors/nemo/frame_vad_infer_postprocess.yaml new file mode 100644 index 00000000..1d00eca6 --- /dev/null +++ b/sdp/processors/nemo/frame_vad_infer_postprocess.yaml @@ -0,0 +1,39 @@ +name: &name "vad_inference_postprocessing" + +input_manifest: null # Path of json file of evaluation data. Audio files should have unique names +output_dir: null # Path to output directory where results will be stored +num_workers: 12 +sample_rate: 16000 +evaluate: false # whether to get AUROC and DERs, the manifest must contains groundtruth if enabled + +prepare_manifest: + auto_split: true # whether to automatically split manifest entry by split_duration to avoid potential CUDA out of memory issue. + split_duration: 400 # max length in seconds, try smaller number if you still have CUDA memory issue + +vad: + model_path: "vad_multilingual_frame_marblenet" #.nemo local model path or pretrained model name or none + use_rttm: True # set True to output as RTTM format + parameters: # Parameters not tuned on large datasets, please use default parameters with caution + normalize_audio_db: null # set to non null value to normalize RMS DB of audio before preprocessing + window_length_in_sec: 0.0 # window length in sec for VAD context input, must be 0 for frame-VAD + shift_length_in_sec: 0.02 # frame-length in seconds for frame-VAD, must be 0.02 for the pretrained NeMo VAD model + smoothing: False # Deprecated for Frame-VAD. false or type of smoothing method (eg: median, mean) + overlap: 0.875 # Deprecated for Frame-VAD. overlap ratio for overlapped mean/median smoothing filter. If smoothing=False, ignore this value. + postprocessing: + onset: 0.3 # onset threshold for detecting the beginning and end of a speech + offset: 0.3 # offset threshold for detecting the end of a speech. 
+    pad_onset: 0.2 # padding added before each speech segment
+    pad_offset: 0.2 # padding added after each speech segment
+    min_duration_on: 0.2 # threshold for short speech segment deletion
+    min_duration_off: 0.2 # threshold for short non-speech segment deletion
+    filter_speech_first: True
+
+prepared_manifest_vad_input: null # if not specified, it will automatically be set to "manifest_vad_input.json"
+frame_out_dir: "vad_frame_outputs"
+smoothing_out_dir: null # if not specified, it will automatically be set to frame_out_dir + "/overlap_smoothing_output" + "_" + smoothing_method + "_" + str(overlap)
+rttm_out_dir: null # if not specified, it will automatically be set to frame_out_dir + "/seg_output_" + the key-value pairs of the postprocessing params
+out_manifest_filepath: null # if not specified, it will automatically be set to "manifest_vad_out.json"
+
+
+# json manifest line example
+# {"audio_filepath": "/path/to/audio_file.wav", "offset": 0, "duration": 1.23, "label": "infer", "text": "-"}
diff --git a/sdp/processors/nemo/lid_inference.py b/sdp/processors/nemo/lid_inference.py
new file mode 100644
index 00000000..0f1b871f
--- /dev/null
+++ b/sdp/processors/nemo/lid_inference.py
@@ -0,0 +1,77 @@
+import json
+from pathlib import Path
+
+import numpy as np
+from tqdm import tqdm
+
+from sdp.logging import logger
+from sdp.processors.base_processor import BaseProcessor
+from sdp.utils.common import load_manifest
+
+
+class AudioLid(BaseProcessor):
+    """
+    Processor for language identification (LID) of audio files using a pre-trained LID model.
+
+    Args:
+        input_audio_key (str): The key in the dataset containing the path to the audio files for language identification.
+        pretrained_model (str): The name of the pre-trained NeMo LID model.
+        output_lang_key (str): The key to store the identified language for each audio file.
+        device (str): The device to run the LID model on (e.g., 'cuda', 'cpu'). If None, it automatically selects the available GPU if present; otherwise, it uses the CPU.
+        segment_duration (float): Duration of each randomly sampled audio segment in seconds. Default is np.inf.
+        num_segments (int): Number of segments per file to use for the majority vote. Default is 1.
+        random_seed (int): Seed for sampling the starting position of each segment. Default is None.
+        **kwargs: Additional keyword arguments to be passed to the base class `BaseProcessor`.
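+
+    Example (mirrors the usage in ``dataset_configs/portuguese/unlabeled/config.yaml``;
+    the manifest keys and device are illustrative and should match your pipeline):
+
+        .. code-block:: yaml
+
+            - _target_: sdp.processors.AudioLid
+              input_audio_key: audio_filepath
+              output_lang_key: audio_lang
+              pretrained_model: "langid_ambernet"
+              device: cuda
+              segment_duration: 20
+              num_segments: 3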
+ + """ + + def __init__( + self, + input_audio_key: str, + pretrained_model: str, + output_lang_key: str, + device: str, + segment_duration: float = np.inf, + num_segments: int = 1, + random_seed: int = None, + **kwargs, + ): + super().__init__(**kwargs) + self.input_audio_key = input_audio_key + self.pretrained_model = pretrained_model + self.output_lang_key = output_lang_key + self.segment_duration = segment_duration + self.num_segments = num_segments + self.random_seed = random_seed + self.device = device + + def process(self): + import nemo.collections.asr as nemo_asr + import torch # importing after nemo to make sure users first install nemo, instead of torch, then nemo + + model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name=self.pretrained_model) + + if self.device is None: + if torch.cuda.is_available(): + model = model.cuda() + else: + model = model.cpu() + else: + model = model.to(self.device) + + manifest = load_manifest(Path(self.input_manifest_file)) + + Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True) + with Path(self.output_manifest_file).open('w') as f: + for item in tqdm(manifest): + audio_file = item[self.input_audio_key] + + try: + lang = model.get_label(audio_file, self.segment_duration, self.num_segments) + except Exception as e: + logger.warning("AudioLid " + audio_file + " " + str(e)) + lang = None + + if lang: + item[self.output_lang_key] = lang + f.write(json.dumps(item, ensure_ascii=False) + '\n') diff --git a/sdp/processors/nemo/rttm.py b/sdp/processors/nemo/rttm.py new file mode 100644 index 00000000..394014ca --- /dev/null +++ b/sdp/processors/nemo/rttm.py @@ -0,0 +1,171 @@ +import os +from typing import Dict, List, Union + +import soundfile as sf +from tqdm import tqdm + +from sdp.logging import logger +from sdp.processors.base_processor import BaseParallelProcessor, DataEntry +from sdp.utils.common import load_manifest + + +class GetRttmSegments(BaseParallelProcessor): + """This processor extracts audio segments based on RTTM (Rich Transcription Time Marked) files. + + The class reads an RTTM file specified by the `rttm_key` in the input data entry and + generates a list of audio segment start times. It ensures that segments longer than a specified + duration threshold are split into smaller segments. The resulting segments are stored in the + output data entry under the `output_file_key`. + + Args: + rttm_key (str): The key in the manifest that contains the path to the RTTM file. + output_file_key (str, optional): The key in the data entry where the list of audio segment + start times will be stored. Defaults to "audio_segments". + duration_key (str, optional): The key in the data entry that contains the total duration + of the audio file. Defaults to "duration". + duration_threshold (float, optional): The maximum duration for a segment before it is split. + Segments longer than this threshold will be divided into smaller segments. Defaults to 20.0 seconds. + + Returns: + A list containing a single `DataEntry` object with the updated data entry, which includes + the `output_file_key` containing the sorted list of audio segment start times. 
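+
+    Example (mirrors the usage in ``dataset_configs/portuguese/unlabeled/config.yaml``,
+    where RTTM files produced by the VAD step are turned into segment boundaries):
+
+        .. code-block:: yaml
+
+            - _target_: sdp.processors.nemo.rttm.GetRttmSegments
+              rttm_key: rttm_file
+              output_file_key: audio_segments
+              duration_key: duration
+              duration_threshold: 20.0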
+ """ + + def __init__( + self, + rttm_key: str, + output_file_key: str = "audio_segments", + duration_key: str = "duration", + duration_threshold: float = 20.0, + **kwargs, + ): + super().__init__(**kwargs) + self.rttm_key = rttm_key + self.duration_threshold = duration_threshold + self.duration_key = duration_key + self.output_file_key = output_file_key + + def split_long_segment(self, slices, duration, last_slice): + duration0 = self.duration_threshold + while duration0 < duration: + slices.append(last_slice + duration0) + duration0 += self.duration_threshold + if duration0 > duration: + duration0 = duration + slices.append(last_slice + duration0) + return slices, last_slice + duration0 + + def process_dataset_entry(self, data_entry: Dict): + file_duration = data_entry[self.duration_key] + rttm_file = data_entry[self.rttm_key] + + starts = [] + with open(rttm_file, "r") as f: + for line in f: + starts.append(float(line.split(" ")[3])) + starts.append(file_duration) + + slices = [0] + last_slice, last_start, last_duration, duration = 0, 0, 0, 0 + for start in starts: + duration = start - last_slice + + if duration <= self.duration_threshold: + pass + elif duration > self.duration_threshold and last_duration < self.duration_threshold: + slices.append(last_start) + last_slice = last_start + last_start = start + last_duration = duration + duration = start - last_slice + if duration <= self.duration_threshold: + slices.append(start) + last_slice = start + else: + slices, last_slice = self.split_long_segment(slices, duration, last_slice) + + else: + slices.append(start) + last_slice = start + last_start = start + last_duration = duration + + data_entry[self.output_file_key] = sorted(list(set(slices))) + + return [DataEntry(data=data_entry)] + + +class SplitAudioFile(BaseParallelProcessor): + """This processor splits audio files into segments based on provided timestamps. + + The class reads an audio file specified by the `input_file_key` and splits it into segments + based on the timestamps provided in the `segments_key` field of the input data entry. + The split audio segments are saved as individual WAV files in the specified `splited_audio_dir` + directory. The `output_file_key` field of the data entry is updated with the path to the + corresponding split audio file, and the `duration_key` field is updated with the duration + of the split audio segment. + + Args: + splited_audio_dir (str): The directory where the split audio files will be saved. + segments_key (str, optional): The key in the manifest that contains the list of + timestamps for splitting the audio. Defaults to "audio_segments". + duration_key (str, optional): The key in the manifest where the duration of the + split audio segment will be stored. Defaults to "duration". + input_file_key (str, optional): The key in the manifest that contains the path + to the input audio file. Defaults to "source_filepath". + output_file_key (str, optional): The key in the manifest where the path to the + split audio file will be stored. Defaults to "audio_filepath". + + Returns: + A list of data entries, where each entry represents a split audio segment with + the corresponding file path and duration updated in the data entry. 
+ """ + + def __init__( + self, + splited_audio_dir: str, + segments_key: str = "audio_segments", + duration_key: str = "duration", + input_file_key: str = "source_filepath", + output_file_key: str = "audio_filepath", + **kwargs, + ): + super().__init__(**kwargs) + self.splited_audio_dir = splited_audio_dir + self.segments_key = segments_key + self.duration_key = duration_key + self.input_file_key = input_file_key + self.output_file_key = output_file_key + + def write_segment(self, data, samplerate, start_sec, end_sec, input_file): + wav_save_file = os.path.join( + self.splited_audio_dir, + os.path.splitext(os.path.split(input_file)[1])[0], + str(int(start_sec * 100)) + "-" + str(int(end_sec * 100)) + ".wav", + ) + if not os.path.isfile(wav_save_file): + data_sample = data[int(start_sec * samplerate) : int(end_sec * samplerate)] + duration = len(data_sample) / samplerate + os.makedirs(os.path.split(wav_save_file)[0], exist_ok=True) + sf.write(wav_save_file, data_sample, samplerate) + return wav_save_file, duration + else: + try: + data, samplerate = sf.read(wav_save_file) + duration = data.shape[0] / samplerate + except Exception as e: + logger.warning(str(e) + " file: " + wav_save_file) + duration = -1.0 + return wav_save_file, duration + + def process_dataset_entry(self, data_entry: Dict): + slices = data_entry[self.segments_key] + input_file = data_entry[self.input_file_key] + input_data, samplerate = sf.read(input_file) + data_entries = [] + for i in range(len(slices[:-1])): + wav_save_file, duration = self.write_segment(input_data, samplerate, slices[i], slices[i + 1], input_file) + data_entry[self.output_file_key] = wav_save_file + data_entry[self.duration_key] = duration + data_entries.append(DataEntry(data=data_entry.copy())) + return data_entries diff --git a/sdp/processors/nemo/speech_to_text_with_vad.py b/sdp/processors/nemo/speech_to_text_with_vad.py new file mode 100644 index 00000000..6fdd183d --- /dev/null +++ b/sdp/processors/nemo/speech_to_text_with_vad.py @@ -0,0 +1,649 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file provides the ASR+VAD inference pipeline, with the option to perform only ASR or VAD alone. + +There are two types of input, the first one is a manifest passed to `manifest_filepath`, +and the other one is to pass a directory containing audios to `audio_dir` and specify `audio_type`. + +The input manifest must be a manifest json file, where each line is a Python dictionary. The fields ["audio_filepath", "offset", "duration", "text"] are required. 
An example of a manifest file is: +``` +{"audio_filepath": "/path/to/audio_file1", "offset": 0, "duration": 10000, "text": "a b c d e"} +{"audio_filepath": "/path/to/audio_file2", "offset": 0, "duration": 10000, "text": "f g h i j"} +``` + +To run the code with ASR+VAD default settings: + +```bash +python speech_to_text_with_vad.py \ + manifest_filepath=/PATH/TO/MANIFEST.json \ + vad_model=vad_multilingual_frame_marblenet\ + asr_model=stt_en_conformer_ctc_large \ + vad_config=../conf/vad/frame_vad_inference_postprocess.yaml +``` + +To use only ASR and disable VAD, set `vad_model=None` and `use_rttm=False`. + +To use only VAD, set `asr_model=None` and specify both `vad_model` and `vad_config`. + +To enable profiling, set `profiling=True`, but this will significantly slow down the program. + +To use or disable feature masking/droping based on RTTM files, set `use_rttm` to `True` or `False`. +There are two ways to use RTTM files, either by masking the features (`rttm_mode=mask`) or by dropping the features (`rttm_mode=drop`). +For audios that have long non-speech audios between speech segments, dropping frames is recommended. + +To normalize feature before masking, set `normalize=pre_norm`, +and set `normalize=post_norm` for masking before normalization. + +To use a specific value for feature masking, set `feat_mask_val` to the desired value. +Default is `feat_mask_val=None`, where -16.635 will be used for `post_norm` and 0 will be used for `pre_norm`. + +See more options in the `InferenceConfig` class. +""" + + +import contextlib +import json +import os + +import time +from dataclasses import dataclass, is_dataclass, field +from pathlib import Path +from typing import Callable, Optional + +import torch +import torch.amp +import yaml +from omegaconf import DictConfig, OmegaConf +from torch.profiler import ProfilerActivity, profile, record_function +from tqdm import tqdm + +from nemo.collections.asr.data import feature_to_text_dataset +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.asr.models import ASRModel, EncDecClassificationModel +from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig +from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest +from nemo.collections.asr.parts.utils.vad_utils import ( + generate_overlap_vad_seq, + generate_vad_segment_table, + get_vad_stream_status, + init_frame_vad_model, + init_vad_model, +) +from nemo.core.config import hydra_runner +from nemo.utils import logging + + +@dataclass +class InferenceConfig: + # Required configs + asr_model: Optional[str] = None # Path to a .nemo file or a pretrained NeMo model on NGC + vad_model: Optional[str] = None # Path to a .nemo file or a pretrained NeMo model on NGC + vad_config: Optional[str] = None # Path to a yaml file containing VAD post-processing configs + manifest_filepath: Optional[str] = None # Path to dataset's JSON manifest + audio_dir: Optional[str] = None # Path to a directory containing audio files, use this if no manifest is provided + + use_rttm: bool = True # whether to use RTTM + rttm_mode: str = "mask" # how to use RTTM files, choices=[`mask`, `drop`] + feat_mask_val: Optional[float] = None # value used to mask features based on RTTM, set None to use defaults + normalize: Optional[str] = ( + "post_norm" # whether and where to normalize audio feature, choices=[None, `pre_norm`, `post_norm`] + ) + normalize_type: str = 
"per_feature" # how to determine mean and std used for normalization + normalize_audio_db: Optional[float] = None # set to normalize RMS DB of audio before extracting audio features + + profiling: bool = False # whether to enable pytorch profiling + + # General configs + batch_size: int = 1 # batch size for ASR. Feature extraction and VAD only support single sample per batch. + num_workers: int = 8 + sample_rate: int = 16000 + frame_unit_time_secs: float = ( + 0.01 # unit time per frame in seconds, equal to `window_stride` in ASR configs, typically 10ms. + ) + audio_type: str = "wav" + + # Output settings, no need to change + output_dir: Optional[str] = None # will be automatically set by the program + output_filename: Optional[str] = None # will be automatically set by the program + pred_name_postfix: Optional[str] = None # If you need to use another model name, other than the standard one. + + # Set to True to output language ID information + compute_langs: bool = False + + # Decoding strategy for CTC models + ctc_decoding: CTCDecodingConfig = field(default_factory=CTCDecodingConfig) + + # Decoding strategy for RNNT models + rnnt_decoding: RNNTDecodingConfig = field(default_factory=lambda: RNNTDecodingConfig(fused_batch_size=-1)) + + # VAD model type + vad_type: str = "frame" # which type of VAD to use, choices=[`frame`, `segment`] + + +@hydra_runner(config_name="InferenceConfig", schema=InferenceConfig) +def main(cfg): + + if is_dataclass(cfg): + cfg = OmegaConf.structured(cfg) + + if cfg.output_dir is None: + cfg.output_dir = "./outputs" + output_dir = Path(cfg.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + # setup profiling, note that profiling will significantly increast the total runtime + if cfg.profiling: + logging.info("Profiling enabled") + profile_fn = profile + record_fn = record_function + else: + logging.info("Profiling disabled") + + @contextlib.contextmanager + def profile_fn(*args, **kwargs): + yield + + @contextlib.contextmanager + def record_fn(*args, **kwargs): + yield + + input_manifest_file = prepare_inference_manifest(cfg) + + if cfg.manifest_filepath is None: + cfg.manifest_filepath = str(input_manifest_file) + + with profile_fn( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True + ) as prof: + + input_manifest_file = extract_audio_features(input_manifest_file, cfg, record_fn) + + if cfg.vad_model is not None: + logging.info(f"Running VAD with model: {cfg.vad_model}") + input_manifest_file = run_vad_inference(input_manifest_file, cfg, record_fn) + + if cfg.asr_model is not None: + logging.info(f"Running ASR with model: {cfg.asr_model}") + run_asr_inference(input_manifest_file, cfg, record_fn) + + if cfg.profiling: + print("--------------------------------------------------------------------\n") + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15)) + print("--------------------------------------------------------------------\n") + logging.info("Done.") + + +def prepare_inference_manifest(cfg: DictConfig) -> str: + + if cfg.audio_dir is not None and cfg.manifest_filepath is None: + manifest_data = [] + for audio_file in Path(cfg.audio_dir).glob(f"**/*.{cfg.audio_type}"): + item = {"audio_filepath": str(audio_file.absolute()), "duration": 1000000, "offset": 0} + manifest_data.append(item) + parent_dir = Path(cfg.audio_dir) + else: + manifest_data = read_manifest(cfg.manifest_filepath) + parent_dir = 
Path(cfg.manifest_filepath).parent + + new_manifest_data = [] + + for item in manifest_data: + audio_file = Path(item["audio_filepath"]) + if len(str(audio_file)) < 255 and not audio_file.is_file() and not audio_file.is_absolute(): + new_audio_file = parent_dir / audio_file + if new_audio_file.is_file(): + item["audio_filepath"] = str(new_audio_file.absolute()) + else: + item["audio_filepath"] = os.path.expanduser(str(audio_file)) + else: + item["audio_filepath"] = os.path.expanduser(str(audio_file)) + item["label"] = "infer" + item["text"] = "-" + new_manifest_data.append(item) + + new_manifest_filepath = str(Path(cfg.output_dir) / Path("temp_manifest_input.json")) + write_manifest(new_manifest_filepath, new_manifest_data) + return new_manifest_filepath + + +def extract_audio_features(manifest_filepath: str, cfg: DictConfig, record_fn: Callable) -> str: + file_list = [] + manifest_data = [] + out_dir = Path(cfg.output_dir) / Path("features") + new_manifest_filepath = str(Path(cfg.output_dir) / Path("temp_manifest_input_feature.json")) + + if Path(new_manifest_filepath).is_file(): + logging.info("Features already exist in output_dir, skipping feature extraction.") + return new_manifest_filepath + + has_feat = False + with open(manifest_filepath, 'r', encoding='utf-8') as fin: + for line in fin.readlines(): + item = json.loads(line.strip()) + manifest_data.append(item) + file_list.append(Path(item['audio_filepath']).stem) + if "feature_file" in item: + has_feat = True + if has_feat: + logging.info("Features already exist in manifest, skipping feature extraction.") + return manifest_filepath + + out_dir.mkdir(parents=True, exist_ok=True) + torch.set_grad_enabled(False) + if cfg.vad_model: + vad_model = init_frame_vad_model(cfg.vad_model) + else: + vad_model = EncDecClassificationModel.from_pretrained("vad_multilingual_marblenet") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + vad_model = vad_model.to(device) + vad_model.eval() + vad_model.setup_test_data( + test_data_config={ + 'batch_size': 1, + 'vad_stream': False, + 'sample_rate': cfg.sample_rate, + 'manifest_filepath': manifest_filepath, + 'labels': [ + 'infer', + ], + 'num_workers': cfg.num_workers, + 'shuffle': False, + 'normalize_audio_db': cfg.normalize_audio_db, + } + ) + + logging.info(f"Extracting features on {len(file_list)} audio files...") + with record_fn("feat_extract_loop"): + for i, test_batch in enumerate(tqdm(vad_model.test_dataloader(), total=len(vad_model.test_dataloader()))): + test_batch = [x.to(vad_model.device) for x in test_batch] + with torch.amp.autocast(vad_model.device.type): + with record_fn("feat_extract_infer"): + processed_signal, processed_signal_length = vad_model.preprocessor( + input_signal=test_batch[0], + length=test_batch[1], + ) + with record_fn("feat_extract_other"): + processed_signal = processed_signal.squeeze(0)[:, :processed_signal_length] + processed_signal = processed_signal.cpu() + outpath = os.path.join(out_dir, file_list[i] + ".pt") + outpath = str(Path(outpath).absolute()) + torch.save(processed_signal, outpath) + manifest_data[i]["feature_file"] = outpath + del test_batch + + logging.info(f"Features saved at: {out_dir}") + write_manifest(new_manifest_filepath, manifest_data) + return new_manifest_filepath + + +def run_vad_inference(manifest_filepath: str, cfg: DictConfig, record_fn: Callable) -> str: + logging.info("Start VAD inference pipeline...") + if cfg.vad_type == "segment": + vad_model = init_vad_model(cfg.vad_model) + elif cfg.vad_type == "frame": + 
vad_model = init_frame_vad_model(cfg.vad_model) + else: + raise ValueError(f"Unknown VAD type: {cfg.vad_type}, supported types: ['segment', 'frame']") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + vad_model = vad_model.to(device) + vad_model.eval() + + vad_yaml = Path(cfg.vad_config) + if not vad_yaml.is_file(): + raise ValueError(f"VAD config file not found: {cfg.vad_config}") + + with vad_yaml.open("r") as fp: + vad_cfg = yaml.safe_load(fp) + vad_cfg = DictConfig(vad_cfg) + + test_data_config = { + 'vad_stream': True, + 'manifest_filepath': manifest_filepath, + 'labels': [ + 'infer', + ], + 'num_workers': cfg.num_workers, + 'shuffle': False, + 'window_length_in_sec': vad_cfg.vad.parameters.window_length_in_sec, + 'shift_length_in_sec': vad_cfg.vad.parameters.shift_length_in_sec, + } + vad_model.setup_test_data(test_data_config=test_data_config, use_feat=True) + + pred_dir = Path(cfg.output_dir) / Path("vad_frame_pred") + if pred_dir.is_dir(): + logging.info(f"VAD frame-level prediction already exists: {pred_dir}, skipped") + else: + logging.info("Generating VAD frame-level prediction") + pred_dir.mkdir(parents=True) + t0 = time.time() + pred_dir = generate_vad_frame_pred( + vad_model=vad_model, + window_length_in_sec=vad_cfg.vad.parameters.window_length_in_sec, + shift_length_in_sec=vad_cfg.vad.parameters.shift_length_in_sec, + manifest_vad_input=manifest_filepath, + out_dir=str(pred_dir), + use_feat=True, + record_fn=record_fn, + ) + t1 = time.time() + logging.info(f"Time elapsed: {t1 - t0: .2f} seconds") + logging.info( + f"Finished generating VAD frame level prediction with window_length_in_sec={vad_cfg.vad.parameters.window_length_in_sec} and shift_length_in_sec={vad_cfg.vad.parameters.shift_length_in_sec}" + ) + + frame_length_in_sec = vad_cfg.vad.parameters.shift_length_in_sec + # overlap smoothing filter + if vad_cfg.vad.parameters.smoothing: + # Generate predictions with overlapping input segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple segments. 
+ # smoothing_method would be either in majority vote (median) or average (mean) + logging.info("Generating predictions with overlapping input segments") + t0 = time.time() + smoothing_pred_dir = generate_overlap_vad_seq( + frame_pred_dir=pred_dir, + smoothing_method=vad_cfg.vad.parameters.smoothing, + overlap=vad_cfg.vad.parameters.overlap, + window_length_in_sec=vad_cfg.vad.parameters.window_length_in_sec, + shift_length_in_sec=vad_cfg.vad.parameters.shift_length_in_sec, + num_workers=cfg.num_workers, + out_dir=vad_cfg.smoothing_out_dir, + ) + logging.info( + f"Finish generating predictions with overlapping input segments with smoothing_method={vad_cfg.vad.parameters.smoothing} and overlap={vad_cfg.vad.parameters.overlap}" + ) + t1 = time.time() + logging.info(f"Time elapsed: {t1 - t0: .2f} seconds") + pred_dir = smoothing_pred_dir + frame_length_in_sec = 0.01 + + # Turn frame-wise prediction into speech intervals + logging.info(f"Generating segment tables with postprocessing params: {vad_cfg.vad.parameters.postprocessing}") + segment_dir_name = "vad_rttm" + for key, val in vad_cfg.vad.parameters.postprocessing.items(): + segment_dir_name = segment_dir_name + "-" + str(key) + str(val) + + segment_dir = Path(cfg.output_dir) / Path(segment_dir_name) + if segment_dir.is_dir(): + logging.info(f"VAD speech segments already exists: {segment_dir}, skipped") + else: + segment_dir.mkdir(parents=True) + t0 = time.time() + segment_dir = generate_vad_segment_table( + vad_pred_dir=pred_dir, + postprocessing_params=vad_cfg.vad.parameters.postprocessing, + frame_length_in_sec=frame_length_in_sec, + num_workers=cfg.num_workers, + out_dir=segment_dir, + use_rttm=True, + ) + t1 = time.time() + logging.info(f"Time elapsed: {t1 - t0: .2f} seconds") + logging.info("Finished generating RTTM files from VAD predictions.") + + rttm_map = {} + for filepath in Path(segment_dir).glob("*.rttm"): + rttm_map[filepath.stem] = str(filepath.absolute()) + + manifest_data = read_manifest(manifest_filepath) + for i in range(len(manifest_data)): + key = Path(manifest_data[i]["audio_filepath"]).stem + manifest_data[i]["rttm_file"] = rttm_map[key] + + new_manifest_filepath = str(Path(cfg.output_dir) / Path(f"temp_manifest_{segment_dir_name}.json")) + write_manifest(new_manifest_filepath, manifest_data) + return new_manifest_filepath + + +def generate_vad_frame_pred( + vad_model: EncDecClassificationModel, + window_length_in_sec: float, + shift_length_in_sec: float, + manifest_vad_input: str, + out_dir: str, + use_feat: bool = False, + record_fn: Callable = None, +) -> str: + """ + Generate VAD frame level prediction and write to out_dir + """ + time_unit = int(window_length_in_sec / shift_length_in_sec) + trunc = int(time_unit / 2) + trunc_l = time_unit - trunc + all_len = 0 + + data = [] + with open(manifest_vad_input, 'r', encoding='utf-8') as fin: + for line in fin.readlines(): + file = json.loads(line)['audio_filepath'].split("/")[-1] + data.append(file.split(".wav")[0]) + logging.info(f"Inference on {len(data)} audio files/json lines!") + + status = get_vad_stream_status(data) + + with record_fn("vad_infer_loop"): + for i, test_batch in enumerate(tqdm(vad_model.test_dataloader(), total=len(vad_model.test_dataloader()))): + test_batch = [x.to(vad_model.device) for x in test_batch] + with torch.amp.autocast(vad_model.device.type): + with record_fn("vad_infer_model"): + if use_feat: + log_probs = vad_model(processed_signal=test_batch[0], processed_signal_length=test_batch[1]) + else: + log_probs = 
vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1]) + + with record_fn("vad_infer_other"): + probs = torch.softmax(log_probs, dim=-1) + if len(probs.shape) == 3: + # squeeze the batch dimension, since batch size is 1 + probs = probs.squeeze(0) # [1,T,C] -> [T,C] + pred = probs[:, 1] + + if window_length_in_sec == 0: + to_save = pred + elif status[i] == 'start': + to_save = pred[:-trunc] + elif status[i] == 'next': + to_save = pred[trunc:-trunc_l] + elif status[i] == 'end': + to_save = pred[trunc_l:] + else: + to_save = pred + + to_save = to_save.cpu().tolist() + all_len += len(to_save) + + outpath = os.path.join(out_dir, data[i] + ".frame") + with open(outpath, "a", encoding='utf-8') as fout: + for p in to_save: + fout.write(f'{p:0.4f}\n') + + del test_batch + if status[i] == 'end' or status[i] == 'single': + all_len = 0 + return out_dir + + +def init_asr_model(model_path: str) -> ASRModel: + if model_path.endswith('.nemo'): + logging.info(f"Using local ASR model from {model_path}") + asr_model = ASRModel.restore_from(restore_path=model_path) + elif model_path.endswith('.ckpt'): + asr_model = ASRModel.load_from_checkpoint(checkpoint_path=model_path) + else: + logging.info(f"Using NGC ASR model {model_path}") + asr_model = ASRModel.from_pretrained(model_name=model_path) + return asr_model + + +def run_asr_inference(manifest_filepath, cfg, record_fn) -> str: + logging.info("Start ASR inference pipeline...") + asr_model = init_asr_model(cfg.asr_model) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + asr_model = asr_model.to(device) + asr_model.eval() + + # Setup decoding strategy + decode_function = None + decoder_type = cfg.get("decoder_type", None) + if not hasattr(asr_model, 'change_decoding_strategy'): + raise ValueError(f"ASR model {cfg.asr_model} does not support decoding strategy.") + if decoder_type is not None: # Hybrid model + if decoder_type == 'rnnt': + cfg.rnnt_decoding.fused_batch_size = -1 + cfg.rnnt_decoding.compute_langs = cfg.compute_langs + asr_model.change_decoding_strategy(cfg.rnnt_decoding, decoder_type=decoder_type) + decode_function = asr_model.decoding.rnnt_decoder_predictions_tensor + elif decoder_type == 'ctc': + asr_model.change_decoding_strategy(cfg.ctc_decoding, decoder_type=decoder_type) + decode_function = asr_model.decoding.ctc_decoder_predictions_tensor + else: + raise ValueError( + f"Unknown decoder type for hybrid model: {decoder_type}, supported types: ['rnnt', 'ctc']" + ) + elif hasattr(asr_model, 'joint'): # RNNT model + cfg.rnnt_decoding.fused_batch_size = -1 + cfg.rnnt_decoding.compute_langs = cfg.compute_langs + asr_model.change_decoding_strategy(cfg.rnnt_decoding) + decode_function = asr_model.decoding.rnnt_decoder_predictions_tensor + else: + asr_model.change_decoding_strategy(cfg.ctc_decoding) + decode_function = asr_model.decoding.ctc_decoder_predictions_tensor + + # Compute output filename + if cfg.output_filename is None: + # create default output filename + if cfg.pred_name_postfix is not None: + cfg.output_filename = cfg.manifest_filepath.replace('.json', f'_{cfg.pred_name_postfix}.json') + else: + tag = f"{cfg.normalize}_{cfg.normalize_type}" + if cfg.use_rttm: + vad_tag = Path(manifest_filepath).stem + vad_tag = vad_tag[len("temp_manifest_vad_rttm_") :] + if cfg.rttm_mode == "mask": + tag += f"-mask{cfg.feat_mask_val}-{vad_tag}" + else: + tag += f"-dropframe-{vad_tag}" + cfg.output_filename = cfg.manifest_filepath.replace('.json', f'-{Path(cfg.asr_model).stem}-{tag}.json') + cfg.output_filename = 
Path(cfg.output_dir) / Path(cfg.output_filename).name + + logging.info("Setting up dataloader for ASR...") + data_config = { + "manifest_filepath": manifest_filepath, + "normalize": cfg.normalize, + "normalize_type": cfg.normalize_type, + "use_rttm": cfg.use_rttm, + "rttm_mode": cfg.rttm_mode, + "feat_mask_val": cfg.feat_mask_val, + "frame_unit_time_secs": cfg.frame_unit_time_secs, + } + logging.info(f"use_rttm = {cfg.use_rttm}, rttm_mode = {cfg.rttm_mode}, feat_mask_val = {cfg.feat_mask_val}") + + if hasattr(asr_model, "tokenizer"): + dataset = feature_to_text_dataset.get_bpe_dataset(config=data_config, tokenizer=asr_model.tokenizer) + else: + data_config["labels"] = asr_model.decoder.vocabulary + dataset = feature_to_text_dataset.get_char_dataset(config=data_config) + + dataloader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=cfg.batch_size, + collate_fn=dataset._collate_fn, + drop_last=False, + shuffle=False, + num_workers=cfg.get('num_workers', 0), + pin_memory=cfg.get('pin_memory', False), + ) + + logging.info("Start transcribing...") + hypotheses = [] + all_hypotheses = [] + t0 = time.time() + with torch.amp.autocast(asr_model.device.type): + with torch.no_grad(): + with record_fn("asr_infer_loop"): + for test_batch in tqdm(dataloader, desc="Transcribing"): + with record_fn("asr_infer_model"): + outputs = asr_model.forward( + processed_signal=test_batch[0].to(device), + processed_signal_length=test_batch[1].to(device), + ) + + with record_fn("asr_infer_other"): + logits, logits_len = outputs[0], outputs[1] + + current_hypotheses, all_hyp = decode_function( + logits, + logits_len, + return_hypotheses=False, + ) + if isinstance(current_hypotheses, tuple) and len(current_hypotheses) == 2: + current_hypotheses = current_hypotheses[0] # handle RNNT output + + hypotheses += current_hypotheses + if all_hyp is not None: + all_hypotheses += all_hyp + else: + all_hypotheses += current_hypotheses + + del logits + del test_batch + t1 = time.time() + logging.info(f"Time elapsed: {t1 - t0: .2f} seconds") + + logging.info("Finished transcribing.") + # Save output to manifest + input_manifest_data = read_manifest(manifest_filepath) + manifest_data = read_manifest(cfg.manifest_filepath) + + if "text" not in manifest_data[0]: + has_groundtruth = False + else: + has_groundtruth = True + + groundtruth = [] + for i in range(len(manifest_data)): + if has_groundtruth: + groundtruth.append(manifest_data[i]["text"]) + manifest_data[i]["pred_text"] = hypotheses[i] + manifest_data[i]["feature_file"] = input_manifest_data[i]["feature_file"] + if "rttm_file" in input_manifest_data[i]: + manifest_data[i]["feature_file"] = input_manifest_data[i]["feature_file"] + + write_manifest(cfg.output_filename, manifest_data) + + if not has_groundtruth: + hypotheses = " ".join(hypotheses) + words = hypotheses.split() + chars = "".join(words) + logging.info("-----------------------------------------") + logging.info(f"Number of generated characters={len(chars)}") + logging.info(f"Number of generated words={len(words)}") + logging.info("-----------------------------------------") + else: + wer_score = word_error_rate(hypotheses=hypotheses, references=groundtruth) + cer_score = word_error_rate(hypotheses=hypotheses, references=groundtruth, use_cer=True) + logging.info("-----------------------------------------") + logging.info(f"WER={wer_score:.4f}, CER={cer_score:.4f}") + logging.info("-----------------------------------------") + + logging.info(f"ASR output saved at {cfg.output_filename}") + return 
cfg.output_filename
+
+
+if __name__ == "__main__":
+    main()
diff --git a/sdp/run_processors.py b/sdp/run_processors.py
index 8c498cf2..6ddf27f4 100644
--- a/sdp/run_processors.py
+++ b/sdp/run_processors.py
@@ -214,7 +214,6 @@ def run_processors(cfg):
             use_dask_flag = global_use_dask
         else:
             use_dask_flag = flag
-
         processor = hydra.utils.instantiate(processor_cfg)
         processor.use_dask = use_dask_flag
         # running runtime tests to fail right-away if something is not
@@ -222,10 +221,10 @@
         processor.test()
         processors.append(processor)
 
-    # Start Dask client if any processor requires it
     dask_client = None
     if any(p.use_dask for p in processors):
+
        try:
            num_cpus = psutil.cpu_count(logical=False) or 4
            logger.info(f"Starting Dask client with {num_cpus} workers")
diff --git a/tests/test_cfg_end_to_end_tests.py b/tests/test_cfg_end_to_end_tests.py
index 9d860ce9..05fdfcb3 100644
--- a/tests/test_cfg_end_to_end_tests.py
+++ b/tests/test_cfg_end_to_end_tests.py
@@ -88,6 +88,22 @@ def data_check_fn_uzbekvoice(raw_data_dir: str) -> None:
     else:
         raise ValueError(f"No such file {str(expected_file)} at {str(raw_data_dir)}")
 
+def data_check_fn_unlabeled(raw_data_dir: str) -> None:
+    """Checks for the unlabeled test data and unpacks it if needed.
+
+    Args:
+        raw_data_dir: Directory where the ``unlabeled`` folder (or the
+            ``unlabeled.tar.gz`` archive) is expected to be located.
+    """
+    # Skip extraction if the data has already been unpacked.
+    if (Path(raw_data_dir) / "unlabeled").exists():
+        return
+    expected_file = Path(raw_data_dir) / "unlabeled.tar.gz"
+    if not expected_file.exists():
+        raise ValueError(f"No such file {str(expected_file)}")
+    with tarfile.open(expected_file, 'r:gz') as tar:
+        tar.extractall(path=raw_data_dir)
+
 def data_check_fn_armenian_toloka_pipeline_start(raw_data_dir: str) -> None:
     """Checks for the Armenian Toloka test data.
@@ -122,129 +138,134 @@ def data_check_fn_armenian_toloka_pipeline_get_final_res(raw_data_dir: str) -> N def get_test_cases() -> List[Tuple[str, Callable]]: return [ TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/spanish/mls/config.yaml", - data_check_fn=partial(data_check_fn_mls, language="spanish"), - ), + config_path=f"{DATASET_CONFIGS_ROOT}/spanish/mls/config.yaml", + data_check_fn=partial(data_check_fn_mls, language="spanish"), + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/spanish_pc/mcv12/config.yaml", - data_check_fn=partial(data_check_fn_mcv, archive_file_stem="cv-corpus-12.0-2022-12-07-es") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/spanish_pc/mcv12/config.yaml", + data_check_fn=partial(data_check_fn_mcv, archive_file_stem="cv-corpus-12.0-2022-12-07-es") + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/italian/voxpopuli/config.yaml", - data_check_fn=data_check_fn_voxpopuli - ), + config_path=f"{DATASET_CONFIGS_ROOT}/italian/voxpopuli/config.yaml", + data_check_fn=data_check_fn_voxpopuli + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/italian/mls/config.yaml", - data_check_fn=partial(data_check_fn_mls, language="italian") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/italian/mls/config.yaml", + data_check_fn=partial(data_check_fn_mls, language="italian") + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/mls/config.yaml", - data_check_fn=partial(data_check_fn_mls, language="portuguese") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/mls/config.yaml", + data_check_fn=partial(data_check_fn_mls, language="portuguese") + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/mcv/config.yaml", - data_check_fn=partial(data_check_fn_mcv, archive_file_stem="cv-corpus-15.0-2023-09-08-pt") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/mcv/config.yaml", + data_check_fn=partial(data_check_fn_mcv, archive_file_stem="cv-corpus-15.0-2023-09-08-pt") + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/mtedx/config.yaml", - data_check_fn=partial(data_check_fn_mtedx, language_id="pt") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/mtedx/config.yaml", + data_check_fn=partial(data_check_fn_mtedx, language_id="pt") + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/coraa/config.yaml", - data_check_fn=data_check_fn_coraa - ), + config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/coraa/config.yaml", + data_check_fn=data_check_fn_coraa + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/english/slr83/config.yaml", - data_check_fn=lambda raw_data_dir: True - ), + config_path=f"{DATASET_CONFIGS_ROOT}/english/slr83/config.yaml", + data_check_fn=lambda raw_data_dir: True + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/english/coraal/config.yaml", - data_check_fn=lambda raw_data_dir: True - ), + config_path=f"{DATASET_CONFIGS_ROOT}/english/coraal/config.yaml", + data_check_fn=lambda raw_data_dir: True + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/english/librispeech/config.yaml", - data_check_fn=data_check_fn_librispeech - ), + config_path=f"{DATASET_CONFIGS_ROOT}/english/librispeech/config.yaml", + data_check_fn=data_check_fn_librispeech + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/armenian/fleurs/config.yaml", - data_check_fn=data_check_fn_fleurs - ), + config_path=f"{DATASET_CONFIGS_ROOT}/armenian/fleurs/config.yaml", + data_check_fn=data_check_fn_fleurs + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/armenian/text_mcv/config.yaml", - data_check_fn=lambda raw_data_dir: True - ), + 
config_path=f"{DATASET_CONFIGS_ROOT}/armenian/text_mcv/config.yaml", + data_check_fn=lambda raw_data_dir: True + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/armenian/audio_books/config.yaml", - data_check_fn=lambda raw_data_dir: True, - fields_to_ignore=['text'], - ), + config_path=f"{DATASET_CONFIGS_ROOT}/armenian/audio_books/config.yaml", + data_check_fn=lambda raw_data_dir: True, + fields_to_ignore=['text'], + ), TestCase( - f"{DATASET_CONFIGS_ROOT}/kazakh/mcv/config.yaml", - partial(data_check_fn_mcv, archive_file_stem="mcv_kk") - ), + f"{DATASET_CONFIGS_ROOT}/kazakh/mcv/config.yaml", + partial(data_check_fn_mcv, archive_file_stem="mcv_kk") + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/kazakh/slr140/config.yaml", - data_check_fn=data_check_fn_slr140 - ), + config_path=f"{DATASET_CONFIGS_ROOT}/kazakh/slr140/config.yaml", + data_check_fn=data_check_fn_slr140 + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/kazakh/slr102/config.yaml", - data_check_fn=partial(data_check_fn_generic, file_name="slr102_kk.tar.gz") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/kazakh/slr102/config.yaml", + data_check_fn=partial(data_check_fn_generic, file_name="slr102_kk.tar.gz") + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/kazakh/ksc2/config.yaml", - data_check_fn=partial(data_check_fn_generic, file_name="ksc2_kk.tar.gz") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/kazakh/ksc2/config.yaml", + data_check_fn=partial(data_check_fn_generic, file_name="ksc2_kk.tar.gz") + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/uzbek/mcv/config.yaml", - data_check_fn=partial(data_check_fn_mcv, archive_file_stem="mcv_uz") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/uzbek/mcv/config.yaml", + data_check_fn=partial(data_check_fn_mcv, archive_file_stem="mcv_uz") + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/uzbek/uzbekvoice/config.yaml", - data_check_fn=data_check_fn_uzbekvoice - ), + config_path=f"{DATASET_CONFIGS_ROOT}/uzbek/uzbekvoice/config.yaml", + data_check_fn=data_check_fn_uzbekvoice + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/uzbek/fleurs/config.yaml", - data_check_fn=data_check_fn_fleurs - ), + config_path=f"{DATASET_CONFIGS_ROOT}/uzbek/fleurs/config.yaml", + data_check_fn=data_check_fn_fleurs + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/arabic/masc/config.yaml", - data_check_fn=partial(data_check_fn_generic, file_name="masc.tar.gz") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/arabic/masc/config.yaml", + data_check_fn=partial(data_check_fn_generic, file_name="masc.tar.gz") + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/arabic/masc/config_filter_noisy_train.yaml", - data_check_fn=partial(data_check_fn_generic, file_name="masc.tar.gz"), - reference_manifest_filename="test_data_reference_filter.json" - ), + config_path=f"{DATASET_CONFIGS_ROOT}/arabic/masc/config_filter_noisy_train.yaml", + data_check_fn=partial(data_check_fn_generic, file_name="masc.tar.gz"), + reference_manifest_filename="test_data_reference_filter.json" + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/arabic/mcv/config.yaml", - data_check_fn=partial(data_check_fn_mcv, archive_file_stem="mcv.ar") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/arabic/mcv/config.yaml", + data_check_fn=partial(data_check_fn_mcv, archive_file_stem="mcv.ar") + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/arabic/fleurs/config.yaml", - data_check_fn=data_check_fn_fleurs - ), + config_path=f"{DATASET_CONFIGS_ROOT}/arabic/fleurs/config.yaml", + data_check_fn=data_check_fn_fleurs + ), TestCase( - 
config_path=f"{DATASET_CONFIGS_ROOT}/arabic/mediaspeech/config.yaml", - data_check_fn=partial(data_check_fn_generic, file_name="AR.tar.gz") - ), + config_path=f"{DATASET_CONFIGS_ROOT}/arabic/mediaspeech/config.yaml", + data_check_fn=partial(data_check_fn_generic, file_name="AR.tar.gz") + ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/arabic/everyayah/config.yaml", - data_check_fn=partial(data_check_fn_generic, file_name="everyayah.hf") + config_path=f"{DATASET_CONFIGS_ROOT}/arabic/everyayah/config.yaml", + data_check_fn=partial(data_check_fn_generic, file_name="everyayah.hf") ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/armenian/toloka/pipeline_start.yaml", - data_check_fn=data_check_fn_armenian_toloka_pipeline_start, - fields_to_ignore=['source_filepath'], - processors_to_run="2:14", - reference_manifest_filename="pipeline_start/test_data_reference.json" + config_path=f"{DATASET_CONFIGS_ROOT}/armenian/toloka/pipeline_start.yaml", + data_check_fn=data_check_fn_armenian_toloka_pipeline_start, + fields_to_ignore=['source_filepath'], + processors_to_run="2:14", + reference_manifest_filename="pipeline_start/test_data_reference.json" ), TestCase( - config_path=f"{DATASET_CONFIGS_ROOT}/armenian/toloka/pipeline_get_final_res.yaml", - data_check_fn=data_check_fn_armenian_toloka_pipeline_get_final_res, - reference_manifest_filename="pipeline_get_final_res/test_data_reference.json", - fields_to_ignore=['audio_filepath', 'duration'], - processors_to_run="1:6" + config_path=f"{DATASET_CONFIGS_ROOT}/armenian/toloka/pipeline_get_final_res.yaml", + data_check_fn=data_check_fn_armenian_toloka_pipeline_get_final_res, + reference_manifest_filename="pipeline_get_final_res/test_data_reference.json", + fields_to_ignore=['audio_filepath', 'duration'], + processors_to_run="1:6" ), + TestCase( + config_path=f"{DATASET_CONFIGS_ROOT}/portuguese/unlabeled/config.yaml", + data_check_fn=partial(data_check_fn_unlabeled), + fields_to_ignore=['duration'], + ), TestCase( config_path=f"{DATASET_CONFIGS_ROOT}/english/hifitts2/config_22khz.yaml", data_check_fn=partial(data_check_fn_generic, file_name="manifest_22khz.json"), @@ -438,4 +459,4 @@ def test_get_e2e_test_data_path(tmp_path): assert mock_bucket.download_file.call_count == 2 if __name__ == "__main__": - pytest.main([__file__, "-v", "--durations=0"]) \ No newline at end of file + pytest.main([__file__, "-v", "--durations=0"])