diff --git a/.github/workflows/doc-build.yml b/.github/workflows/doc-build.yml index 694cc7f8..a85181f5 100644 --- a/.github/workflows/doc-build.yml +++ b/.github/workflows/doc-build.yml @@ -33,6 +33,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + pip install -r requirements/main.txt pip install -r requirements/docs.txt - name: Build docs with sphinx run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index aed66873..4af78f3b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,6 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + pip install -r requirements/main.txt pip install -r requirements/docs.txt pip install nemo_text_processing python -m pip cache purge diff --git a/dataset_configs/commoncrawl/big_de.yaml b/dataset_configs/commoncrawl/big_de.yaml new file mode 100644 index 00000000..82fb85c7 --- /dev/null +++ b/dataset_configs/commoncrawl/big_de.yaml @@ -0,0 +1,210 @@ +processors_to_run: "0:" +lang: de +base_dir: /path/to/dataset/folder +workspace_dir: ${base_dir}/${lang} + +processors: + - _target_: sdp.processors.datasets.commoncrawl.PreserveByValue + input_manifest_file: /path/to/dataset/folder/manifest11.json + output_manifest_file: ${workspace_dir}/manifest0.json + input_field: audio_lang + target_value: ${lang} + + - _target_: sdp.processors.datasets.commoncrawl.PreserveByValue + output_manifest_file: ${workspace_dir}/manifest1.json + input_field: text_lang + target_value: ${lang} + + - _target_: sdp.processors.ASRInference + output_manifest_file: ${workspace_dir}/manifest2.json + pretrained_model: nvidia/stt_${lang}_fastconformer_hybrid_large_pc + batch_size: 64 + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest3.json + duplicate_fields: {"text":"orig_text"} + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest4.json + text_key: text + regex_params_list: + - {"pattern": '\[(.*?)\]', "repl": ' '} + - {"pattern": '\((.*?)\)', "repl": ' '} + - {"pattern": "^[\\s]*\\*(.*?)\\*[\\s]*$", "repl": "\\1"} + - {"pattern": 'î', "repl": "i"} + - {"pattern": 'ì', "repl": "i"} + - {"pattern": 'í', "repl": "i"} + - {"pattern": '‚', "repl": ","} + - {"pattern": "’", "repl": "'"} + - {"pattern": "[-–—]", "repl": " "} + - {"pattern": '―', "repl": "-"} + - {"pattern": '—', "repl": "-"} + - {"pattern": '⁺', "repl": "+"} + - {"pattern": '“', "repl": '"'} + - {"pattern": '”', "repl": '"'} + - {"pattern": '…', "repl": '.'} + - {"pattern": '‘', "repl": "'"} + - {"pattern": '′', "repl": "'"} + - {"pattern": '`', "repl": "'"} + - {"pattern": '⁻', "repl": "-"} + - {"pattern": '‑', "repl": "-"} + - {"pattern": '¶', "repl": ' '} + - {"pattern": '«', "repl": '"'} + - {"pattern": '»', "repl": '"'} + - {"pattern": '„', "repl": '"'} + - {"pattern": '®', "repl": ' '} + - {"pattern": '@', "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DropHighLowWordrate + output_manifest_file: ${workspace_dir}/manifest5.json + high_wordrate_threshold: 100 + low_wordrate_threshold: 0.01 + + - _target_: sdp.processors.DropIfRegexMatch + output_manifest_file: ${workspace_dir}/manifest6.json + text_key: text + regex_patterns: + - "^\\s*$" + + - _target_: sdp.processors.datasets.commoncrawl.Subprocess + output_manifest_file: ${workspace_dir}/manifest7.json + input_manifest_arg: "--manifest" + output_manifest_arg: "--output_filename" + arg_separator: "=" + cmd: "python 
/home/nkarpov/workspace/NeMo-text-processing/nemo_text_processing/text_normalization/normalize_with_audio.py \ + --language=${lang} --n_jobs=-1 --batch_size=600 --manifest_text_field=text --cache_dir=${workspace_dir}/cache \ + --whitelist=/home/nkarpov/workspace/NeMo-text-processing/nemo_text_processing/text_normalization/${lang}/data/whitelist.tsv" + + - _target_: sdp.processors.RenameFields + output_manifest_file: ${workspace_dir}/manifest8.json + rename_fields: {"normalized":"text"} + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest9.json + text_key: text + regex_params_list: + - {"pattern": '\s(\\x[a-h][0-9]){1,}\s', "repl": ' '} + - {"pattern": '(\\x[a-h][0-9]){1,}', "repl": ''} + - {"pattern": '\.{3}', "repl": '.'} + - {"pattern": '\$', "repl": ""} + - {"pattern": "'", "repl": " "} + - {"pattern": "[^a-zA-ZäöüÄÖÜßẞ.,?]", "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest10.json + duplicate_fields: {"text":"text_pc"} + + - _target_: sdp.processors.SubMakeLowercase + output_manifest_file: ${workspace_dir}/manifest11.json + text_key: text + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest12.json + text_key: text + regex_params_list: + - {"pattern": "[\\?\\.]", "repl": " "} + - {"pattern": ",", "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DropIfRegexMatch + output_manifest_file: ${workspace_dir}/manifest13.json + text_key: text + regex_patterns: + - "^\\s*$" + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest14.json + duplicate_fields: {"pred_text":"pred_text_pc"} + + - _target_: sdp.processors.SubMakeLowercase + output_manifest_file: ${workspace_dir}/manifest15.json + text_key: pred_text + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest16.json + text_key: pred_text + regex_params_list: + - {"pattern": "[\\?\\.]", "repl": " "} + - {"pattern": ",", "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DropHighWER + output_manifest_file: ${workspace_dir}/manifest17.json + text_key: text + pred_text_key: pred_text + wer_threshold: 75 + + - _target_: sdp.processors.DropHighCER + output_manifest_file: ${workspace_dir}/manifest18.json + text_key: text + pred_text_key: pred_text + cer_threshold: 30 + + - _target_: sdp.processors.KeepOnlySpecifiedFields + fields_to_keep: ["audio_filepath", "duration", "text_pc"] + + - _target_: sdp.processors.RenameFields + rename_fields: {"text_pc":"text"} + + - _target_: sdp.processors.SubRegex + text_key: text + regex_params_list: + - {"pattern": "\\s+\\?", "repl": "?"} + - {"pattern": "\\s+\\.", "repl": "."} + - {"pattern": "\\s+,", "repl": ","} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.datasets.commoncrawl.ManifestToUtf8 + + - _target_: sdp.processors.AddConstantFields + output_manifest_file: ${workspace_dir}/manifest_${lang}.json + fields: {"lang": '${lang}'} + + - _target_: sdp.processors.datasets.commoncrawl.TrainDevTestSplitCC + output_manifest_file: ${workspace_dir}/manifest_${lang}_train.json + lang: ${lang} + data_split: train + + - _target_: sdp.processors.datasets.commoncrawl.CopyFiles + file_field: audio_filepath + path_to_copy: ${base_dir}/splited_manifests/${lang}_train/ + path_levels: 2 + + - _target_: sdp.processors.datasets.commoncrawl.DropAbsPath + output_manifest_file: 
${base_dir}/splited_manifests/${lang}_train.json + path_key: audio_filepath + abs_path_to_drop: ${base_dir}/splited_manifests/ + + - _target_: sdp.processors.datasets.commoncrawl.TrainDevTestSplitCC + input_manifest_file: ${workspace_dir}/manifest_${lang}.json + output_manifest_file: ${workspace_dir}/manifest_${lang}_dev.json + lang: ${lang} + data_split: dev + + - _target_: sdp.processors.datasets.commoncrawl.CopyFiles + file_field: audio_filepath + path_to_copy: ${base_dir}/splited_manifests/${lang}_dev/ + path_levels: 2 + + - _target_: sdp.processors.datasets.commoncrawl.DropAbsPath + output_manifest_file: ${base_dir}/splited_manifests/${lang}_dev.json + path_key: audio_filepath + abs_path_to_drop: ${base_dir}/splited_manifests/ + + - _target_: sdp.processors.datasets.commoncrawl.TrainDevTestSplitCC + input_manifest_file: ${workspace_dir}/manifest_${lang}.json + output_manifest_file: ${workspace_dir}/manifest_${lang}_test.json + lang: ${lang} + data_split: test + + - _target_: sdp.processors.datasets.commoncrawl.CopyFiles + file_field: audio_filepath + path_to_copy: ${base_dir}/splited_manifests/${lang}_test/ + path_levels: 2 + + - _target_: sdp.processors.datasets.commoncrawl.DropAbsPath + output_manifest_file: ${base_dir}/splited_manifests/${lang}_test.json + path_key: audio_filepath + abs_path_to_drop: ${base_dir}/splited_manifests/ \ No newline at end of file diff --git a/dataset_configs/commoncrawl/big_en.yaml b/dataset_configs/commoncrawl/big_en.yaml new file mode 100644 index 00000000..bc755739 --- /dev/null +++ b/dataset_configs/commoncrawl/big_en.yaml @@ -0,0 +1,365 @@ +processors_to_run: "0:" +lang: en +base_dir: /path/to/dataset/folder +workspace_dir: ${base_dir}/${lang} + +processors: + - _target_: sdp.processors.datasets.commoncrawl.PreserveByValue + input_manifest_file: /path/to/dataset/folder/manifest11.json + output_manifest_file: ${workspace_dir}/manifest0.json + input_field: audio_lang + target_value: ${lang} + + - _target_: sdp.processors.datasets.commoncrawl.PreserveByValue + output_manifest_file: ${workspace_dir}/manifest1.json + input_field: text_lang + target_value: ${lang} + + - _target_: sdp.processors.ASRInference + output_manifest_file: ${workspace_dir}/manifest2.json + pretrained_model: nvidia/stt_${lang}_fastconformer_hybrid_large_pc + batch_size: 64 + + - _target_: sdp.processors.DropIfRegexMatch + output_manifest_file: ${workspace_dir}/manifest3.json + regex_patterns: + # - '://' + # - "(\\s)+(www)\\.[a-zA-Z0-9/]+(\\s|$)+" + # - '\\x' + - "www\\.wiki\\s" + - "www\\.usgs\\.\\s" + # - 'é' + # - 'ô' + # - '×' + # - 'š' + # - 'ö' + # - 'ß' + # - 'ä' + # - 'ü' + # - '\u202a' + # - 'č' + # - 'ć' + # - 'á' + # - 'ã' + # - 'â' + # - 'ï' + # - '\u2060' + # - 'ñ' + # - 'ŵ' + # - 'à' + # - 'ù' + # - 'ò' + # - 'ó' + # - 'ő' + # - 'ê' + # - 'ă' + # - 'ú' + # - 'µ' + # - '¿' + # - '¡' + # - 'ë' + # - "è" + # - "é" + # - "È" + # - "É" + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest4.json + duplicate_fields: {"text":"orig_text"} + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest5.json + regex_params_list: + - {"pattern": '\[(.*?)\]', "repl": ' '} + - {"pattern": '\((.*?)\)', "repl": ' '} + - {"pattern": "^[\\s]*\\*(.*?)\\*[\\s]*$", "repl": "\\1"} + - {"pattern": 'î', "repl": "i"} + - {"pattern": 'ì', "repl": "i"} + - {"pattern": 'í', "repl": "i"} + - {"pattern": '‚', "repl": ","} + - {"pattern": "’", "repl": "'"} + - {"pattern": "[-–—]", "repl": " "} + - {"pattern": '―', "repl": "-"} + 
- {"pattern": '—', "repl": "-"} + - {"pattern": '⁺', "repl": "+"} + - {"pattern": '“', "repl": '"'} + - {"pattern": '”', "repl": '"'} + - {"pattern": '…', "repl": '.'} + - {"pattern": '‘', "repl": "'"} + - {"pattern": '′', "repl": "'"} + - {"pattern": '`', "repl": "'"} + - {"pattern": '⁻', "repl": "-"} + - {"pattern": '‑', "repl": "-"} + - {"pattern": '¶', "repl": ' '} + - {"pattern": '«', "repl": '"'} + - {"pattern": '»', "repl": '"'} + - {"pattern": '„', "repl": '"'} + - {"pattern": '®', "repl": ' '} + # - {"pattern": "%", "repl": ' '} + - {"pattern": '@', "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DropHighLowWordrate + output_manifest_file: ${workspace_dir}/manifest6.json + high_wordrate_threshold: 100 + low_wordrate_threshold: 0.01 + + - _target_: sdp.processors.DropIfRegexMatch + output_manifest_file: ${workspace_dir}/manifest7.json + regex_patterns: + - "^\\s*$" + + - _target_: sdp.processors.datasets.commoncrawl.Subprocess + output_manifest_file: ${workspace_dir}/manifest8.json + input_manifest_arg: "--manifest" + output_manifest_arg: "--output_filename" + arg_separator: "=" + cmd: "python /home/nkarpov/workspace/NeMo-text-processing/nemo_text_processing/text_normalization/normalize_with_audio.py \ + --language=en --n_jobs=-1 --batch_size=600 --manifest_text_field=text --cache_dir=${workspace_dir}/cache --overwrite_cache \ + --whitelist=/home/nkarpov/workspace/NeMo-text-processing/nemo_text_processing/text_normalization/${lang}/data/whitelist/asr_with_pc.tsv" + + - _target_: sdp.processors.RenameFields + output_manifest_file: ${workspace_dir}/manifest9.json + rename_fields: {"normalized":"text"} + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest10.json + text_key: text + regex_params_list: + - {"pattern": "^\\s*'+\\s(.*?)\\s*'+\\s*$", "repl": "\\1"} + - {"pattern": "^\\s*'*\\s*", "repl": ""} + - {"pattern": "'{2,}", "repl": "'"} + - {"pattern": '\s(\\x[a-h][0-9]){1,}\s', "repl": ' '} + - {"pattern": '(\\x[a-h][0-9]){1,}', "repl": ''} + - {"pattern": '\.{3}', "repl": '.'} + - {"pattern": '!', "repl": '.'} + - {"pattern": '\$', "repl": ""} + - {"pattern": "[^A-Za-z'.,?]", "repl": " "} + - {"pattern": "\\s+", "repl": " "} + test_cases: + - {input: {text: "' jupiter and venus both shining in the golden rosy sky"}, output: {text: "jupiter and venus both shining in the golden rosy sky"}} + - {input: {text: "' may all the gold i have ever dreamed of be yours '"}, output: {text: "may all the gold i have ever dreamed of be yours"}} + - {input: {text: "''cause it''s an adult novel versus ya"}, output: {text: "cause it's an adult novel versus ya"}} + + - _target_: sdp.processors.DropHighLowWordrate + output_manifest_file: ${workspace_dir}/manifest11.json + high_wordrate_threshold: 100 + low_wordrate_threshold: 0.01 + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest12.json + duplicate_fields: {"text":"text_pc"} + + - _target_: sdp.processors.SubMakeLowercase + output_manifest_file: ${workspace_dir}/manifest13.json + text_key: text + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest14.json + text_key: text + regex_params_list: + - {"pattern": "[\\?\\.]", "repl": " "} + - {"pattern": ",", "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest15.json + duplicate_fields: {"pred_text":"pred_text_pc"} + + - _target_: sdp.processors.SubMakeLowercase + 
output_manifest_file: ${workspace_dir}/manifest16.json + text_key: pred_text + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest17.json + text_key: pred_text + regex_params_list: + - {"pattern": "[\\?\\.]", "repl": " "} + - {"pattern": ",", "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DropIfRegexMatch + output_manifest_file: ${workspace_dir}/manifest18.json + regex_patterns: + - "^\\s*$" + + - _target_: sdp.processors.DropHighWER + output_manifest_file: ${workspace_dir}/manifest19.json + text_key: text + pred_text_key: pred_text + wer_threshold: 75 + + - _target_: sdp.processors.DropHighCER + output_manifest_file: ${workspace_dir}/manifest20.json + text_key: text + pred_text_key: pred_text + cer_threshold: 30 + + - _target_: sdp.processors.RenameFields + input_manifest_file: ${workspace_dir}/manifest18.json + output_manifest_file: ${workspace_dir}/manifest21.json + rename_fields: {"audios":"source_audio"} + + - _target_: sdp.processors.datasets.commoncrawl.AlignerSubprocess + output_manifest_file: ${workspace_dir}/manifest22.json + input_manifest_arg: "manifest_filepath" + output_field: "alignment" + cmd: "python3 /home/nkarpov/workspace/NeMo/tools/nemo_forced_aligner/align.py pretrained_name=stt_en_fastconformer_hybrid_large_pc \ + output_dir=${workspace_dir} batch_size=1 additional_segment_grouping_separator=|" + + - _target_: sdp.processors.datasets.commoncrawl.SplitByAligner + output_manifest_file: ${workspace_dir}/manifest23.json + splited_audio_dir: ${workspace_dir}/nfa + input_field: source_audio + output_field: nfa_filepath + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest24.json + duplicate_fields: {"audio_filepath":"audio_filepath_base"} + + - _target_: sdp.processors.RenameFields + output_manifest_file: ${workspace_dir}/manifest25.json + rename_fields: {"nfa_filepath":"audio_filepath"} + + - _target_: sdp.processors.DropHighLowDuration + output_manifest_file: ${workspace_dir}/manifest26.json + high_duration_threshold: 60 + low_duration_threshold: 0.01 + duration_key: nfa_duration + + - _target_: sdp.processors.ASRInference + output_manifest_file: ${workspace_dir}/manifest27.json + pretrained_model: nvidia/stt_${lang}_fastconformer_hybrid_large_pc + batch_size: 64 + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest28.json + duplicate_fields: {"pred_text":"pred_text_pc"} + + - _target_: sdp.processors.SubMakeLowercase + output_manifest_file: ${workspace_dir}/manifest29.json + text_key: pred_text + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest30.json + text_key: pred_text + regex_params_list: + - {"pattern": "[\\?\\.]", "repl": " "} + - {"pattern": ",", "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DropHighWER + output_manifest_file: ${workspace_dir}/manifest31.json + text_key: text + pred_text_key: pred_text + wer_threshold: 75 + + - _target_: sdp.processors.DropHighCER + output_manifest_file: ${workspace_dir}/manifest32.json + text_key: text + pred_text_key: pred_text + cer_threshold: 30 + + + - _target_: sdp.processors.datasets.commoncrawl.JoinBy + input_manifest_file: ${workspace_dir}/manifest21.json + output_manifest_file: ${workspace_dir}/manifest33.json + input_field: source_audio + + + - _target_: sdp.processors.datasets.commoncrawl.Subprocess + output_manifest_file: ${workspace_dir}/manifest34.json + input_manifest_arg: 
"--data_manifest" + output_manifest_arg: "--out_manifest" + arg_separator: "=" + cmd: "python /home/nkarpov/workspace/NvLLMOps/nvllmops/stages/asr/data_segmentation/ds_align/ds_align.py \ + --splits_dir=/mnt/ssd8/cc_sdp/en/dsa \ + --stt-model-path=/home/nkarpov/ckpts/en/stt_en_conformer_ctc_large_1.1/stt_en_conformer_ctc_large.nemo \ + --stt-model-type=CTC \ + --min-audio-duration=2 \ + --max-audio-duration=40 \ + --asr-batch-size=32" + + - _target_: sdp.processors.DropIfRegexMatch + output_manifest_file: ${workspace_dir}/manifest35.json + text_key: text + regex_patterns: + - "^\\s*$" + + - _target_: sdp.processors.DropHighWER + output_manifest_file: ${workspace_dir}/manifest36.json + text_key: text + pred_text_key: text_asr_pred + wer_threshold: 75 + + - _target_: sdp.processors.DropHighCER + output_manifest_file: ${workspace_dir}/manifest37.json + text_key: text + pred_text_key: text_asr_pred + cer_threshold: 30 + + - _target_: sdp.processors.KeepOnlySpecifiedFields + input_manifest_file: ${workspace_dir}/manifest20.json + fields_to_keep: ["audio_filepath", "duration", "text_pc"] + + - _target_: sdp.processors.RenameFields + rename_fields: {"text_pc":"text"} + + - _target_: sdp.processors.SubRegex + text_key: text + regex_params_list: + - {"pattern": "\\s+\\?", "repl": "?"} + - {"pattern": "\\s+\\.", "repl": "."} + - {"pattern": "\\s+,", "repl": ","} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.AddConstantFields + output_manifest_file: ${workspace_dir}/manifest_${lang}.json + fields: {"lang": '${lang}'} + + - _target_: sdp.processors.datasets.commoncrawl.TrainDevTestSplitCC + output_manifest_file: ${workspace_dir}/manifest_${lang}_train.json + lang: ${lang} + data_split: train + + - _target_: sdp.processors.datasets.commoncrawl.CopyFiles + file_field: audio_filepath + path_to_copy: ${base_dir}/splited_manifests/${lang}_train/ + path_levels: 2 + + - _target_: sdp.processors.datasets.commoncrawl.DropAbsPath + output_manifest_file: ${base_dir}/splited_manifests/${lang}_train.json + path_key: audio_filepath + abs_path_to_drop: ${base_dir}/splited_manifests/ + + - _target_: sdp.processors.datasets.commoncrawl.TrainDevTestSplitCC + input_manifest_file: ${workspace_dir}/manifest_${lang}.json + output_manifest_file: ${workspace_dir}/manifest_${lang}_dev.json + lang: ${lang} + data_split: dev + + - _target_: sdp.processors.datasets.commoncrawl.CopyFiles + file_field: audio_filepath + path_to_copy: ${base_dir}/splited_manifests/${lang}_dev/ + path_levels: 2 + + - _target_: sdp.processors.datasets.commoncrawl.DropAbsPath + output_manifest_file: ${base_dir}/splited_manifests/${lang}_dev.json + path_key: audio_filepath + abs_path_to_drop: ${base_dir}/splited_manifests/ + + - _target_: sdp.processors.datasets.commoncrawl.TrainDevTestSplitCC + input_manifest_file: ${workspace_dir}/manifest_${lang}.json + output_manifest_file: ${workspace_dir}/manifest_${lang}_test.json + lang: ${lang} + data_split: test + + - _target_: sdp.processors.datasets.commoncrawl.CopyFiles + file_field: audio_filepath + path_to_copy: ${base_dir}/splited_manifests/${lang}_test/ + path_levels: 2 + + - _target_: sdp.processors.datasets.commoncrawl.DropAbsPath + output_manifest_file: ${base_dir}/splited_manifests/${lang}_test.json + path_key: audio_filepath + abs_path_to_drop: ${base_dir}/splited_manifests/ \ No newline at end of file diff --git a/dataset_configs/commoncrawl/big_fr.yaml b/dataset_configs/commoncrawl/big_fr.yaml new file mode 100644 index 00000000..92e958b1 --- /dev/null +++ 
b/dataset_configs/commoncrawl/big_fr.yaml @@ -0,0 +1,216 @@ +processors_to_run: "0:" +lang: fr +base_dir: /path/to/dataset/folder +workspace_dir: ${base_dir}/${lang} + +processors: + - _target_: sdp.processors.datasets.commoncrawl.PreserveByValue + input_manifest_file: ${base_dir}/manifest11.json + output_manifest_file: ${workspace_dir}/manifest0.json + input_field: audio_lang + target_value: ${lang} + + - _target_: sdp.processors.datasets.commoncrawl.PreserveByValue + output_manifest_file: ${workspace_dir}/manifest1.json + input_field: text_lang + target_value: ${lang} + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest2.json + duplicate_fields: {"text":"orig_text"} + + - _target_: sdp.processors.ASRInference + output_manifest_file: ${workspace_dir}/manifest3.json + pretrained_model: nvidia/stt_fr_fastconformer_hybrid_large_pc + batch_size: 32 + + - _target_: sdp.processors.DropIfRegexMatch + output_manifest_file: ${workspace_dir}/manifest4.json + regex_patterns: + - '\\x' + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest5.json + text_key: text + regex_params_list: + - {"pattern": '\[(.*?)\]', "repl": ' '} + - {"pattern": '\((.*?)\)', "repl": ' '} + - {"pattern": "^[\\s]*\\*(.*?)\\*[\\s]*$", "repl": "\\1"} + - {"pattern": "\\\\x[a-f\\d]{1,}", "repl": " "} + - {"pattern": '‚', "repl": ","} + - {"pattern": "’", "repl": "'"} + - {"pattern": "[-–—]", "repl": " "} + - {"pattern": '―', "repl": "-"} + - {"pattern": '—', "repl": "-"} + - {"pattern": '⁺', "repl": "+"} + - {"pattern": '“', "repl": '"'} + - {"pattern": '”', "repl": '"'} + - {"pattern": '…', "repl": '.'} + - {"pattern": '‘', "repl": "'"} + - {"pattern": '′', "repl": "'"} + - {"pattern": '`', "repl": "'"} + - {"pattern": '⁻', "repl": "-"} + - {"pattern": '‑', "repl": "-"} + - {"pattern": '¶', "repl": ' '} + - {"pattern": '«', "repl": '"'} + - {"pattern": '»', "repl": '"'} + - {"pattern": '„', "repl": '"'} + - {"pattern": '®', "repl": ' '} + - {"pattern": '@', "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DropHighLowWordrate + output_manifest_file: ${workspace_dir}/manifest6.json + text_key: text + high_wordrate_threshold: 100 + low_wordrate_threshold: 0.01 + + - _target_: sdp.processors.DropIfRegexMatch + output_manifest_file: ${workspace_dir}/manifest7.json + text_key: text + regex_patterns: + - "^\\s*$" + + - _target_: sdp.processors.datasets.commoncrawl.Subprocess + output_manifest_file: ${workspace_dir}/manifest8.json + input_manifest_arg: "--manifest" + output_manifest_arg: "--output_filename" + arg_separator: "=" + cmd: "python /home/nkarpov/workspace/NeMo-text-processing/nemo_text_processing/text_normalization/normalize_with_audio.py \ + --language=${lang} --n_jobs=-1 --batch_size=600 --manifest_text_field=text --cache_dir=${workspace_dir}/cache \ + --whitelist=/home/nkarpov/workspace/NeMo-text-processing/nemo_text_processing/text_normalization/${lang}/data/whitelist.tsv" + + - _target_: sdp.processors.RenameFields + output_manifest_file: ${workspace_dir}/manifest9.json + rename_fields: {"normalized":"text"} + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest10.json + regex_params_list: + - {"pattern": "^\\s*'+\\s(.*?)\\s*'+\\s*$", "repl": "\\1"} + - {"pattern": "^\\s*'*\\s*", "repl": ""} + - {"pattern": "'{2,}", "repl": "'"} + - {"pattern": '\s(\\x[a-h][0-9]){1,}\s', "repl": ' '} + - {"pattern": '(\\x[a-h][0-9]){1,}', "repl": ''} + - {"pattern": '\.{3}', "repl": '.'} + - 
{"pattern": '\$', "repl": ""} + - {"pattern": "[^a-zA-ZàâçéèêëîïôûùüÿæœÀÂÇÉÈÊËÎÏÔÛÙÜŸÆŒ.,?'-]", "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest11.json + duplicate_fields: {"text":"text_pc"} + + - _target_: sdp.processors.SubMakeLowercase + output_manifest_file: ${workspace_dir}/manifest12.json + text_key: text + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest13.json + text_key: text + regex_params_list: + - {"pattern": "[\\?\\.]", "repl": " "} + - {"pattern": ",", "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DropIfRegexMatch + output_manifest_file: ${workspace_dir}/manifest14.json + text_key: text + regex_patterns: + - "^\\s*$" + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest15.json + duplicate_fields: {"pred_text":"pred_text_pc"} + + - _target_: sdp.processors.SubMakeLowercase + output_manifest_file: ${workspace_dir}/manifest16.json + text_key: pred_text + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest17.json + text_key: pred_text + regex_params_list: + - {"pattern": "[\\?\\.]", "repl": " "} + - {"pattern": ",", "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DropHighWER + output_manifest_file: ${workspace_dir}/manifest18.json + text_key: text + pred_text_key: pred_text + wer_threshold: 75 + + - _target_: sdp.processors.DropHighCER + output_manifest_file: ${workspace_dir}/manifest19.json + text_key: text + pred_text_key: pred_text + cer_threshold: 30 + + - _target_: sdp.processors.KeepOnlySpecifiedFields + fields_to_keep: ["audio_filepath", "duration", "text_pc"] + + - _target_: sdp.processors.RenameFields + rename_fields: {"text_pc":"text"} + + - _target_: sdp.processors.SubRegex + input_manifest_file: ${workspace_dir}/manifest_${lang}.json + text_key: text + regex_params_list: + - {"pattern": "\\s+\\?", "repl": "?"} + - {"pattern": "\\s+\\.", "repl": "."} + - {"pattern": "\\s+,", "repl": ","} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.datasets.commoncrawl.ManifestToUtf8 + + - _target_: sdp.processors.AddConstantFields + output_manifest_file: ${workspace_dir}/manifest_${lang}.json + fields: {"lang": '${lang}'} + + - _target_: sdp.processors.datasets.commoncrawl.TrainDevTestSplitCC + output_manifest_file: ${workspace_dir}/manifest_${lang}_train.json + lang: ${lang} + data_split: train + + - _target_: sdp.processors.datasets.commoncrawl.CopyFiles + file_field: audio_filepath + path_to_copy: ${base_dir}/splited_manifests/${lang}_train/ + path_levels: 2 + + - _target_: sdp.processors.datasets.commoncrawl.DropAbsPath + output_manifest_file: ${base_dir}/splited_manifests/${lang}_train.json + path_key: audio_filepath + abs_path_to_drop: ${base_dir}/splited_manifests/ + + - _target_: sdp.processors.datasets.commoncrawl.TrainDevTestSplitCC + input_manifest_file: ${workspace_dir}/manifest_${lang}.json + output_manifest_file: ${workspace_dir}/manifest_${lang}_dev.json + lang: ${lang} + data_split: dev + + - _target_: sdp.processors.datasets.commoncrawl.CopyFiles + file_field: audio_filepath + path_to_copy: ${base_dir}/splited_manifests/${lang}_dev/ + path_levels: 2 + + - _target_: sdp.processors.datasets.commoncrawl.DropAbsPath + output_manifest_file: ${base_dir}/splited_manifests/${lang}_dev.json + path_key: audio_filepath + abs_path_to_drop: ${base_dir}/splited_manifests/ + + - 
_target_: sdp.processors.datasets.commoncrawl.TrainDevTestSplitCC + input_manifest_file: ${workspace_dir}/manifest_${lang}.json + output_manifest_file: ${workspace_dir}/manifest_${lang}_test.json + lang: ${lang} + data_split: test + + - _target_: sdp.processors.datasets.commoncrawl.CopyFiles + file_field: audio_filepath + path_to_copy: ${base_dir}/splited_manifests/${lang}_test/ + path_levels: 2 + + - _target_: sdp.processors.datasets.commoncrawl.DropAbsPath + output_manifest_file: ${base_dir}/splited_manifests/${lang}_test.json + path_key: audio_filepath + abs_path_to_drop: ${base_dir}/splited_manifests/ \ No newline at end of file diff --git a/dataset_configs/commoncrawl/big_pl.yaml b/dataset_configs/commoncrawl/big_pl.yaml new file mode 100644 index 00000000..ec1d6d96 --- /dev/null +++ b/dataset_configs/commoncrawl/big_pl.yaml @@ -0,0 +1,196 @@ +processors_to_run: "0:" +lang: pl +base_dir: /path/to/dataset/folder +workspace_dir: ${base_dir}/${lang} + +processors: + - _target_: sdp.processors.datasets.commoncrawl.PreserveByValue + input_manifest_file: ${base_dir}/manifest11.json + output_manifest_file: ${workspace_dir}/manifest0.json + input_field: audio_lang + target_value: ${lang} + + - _target_: sdp.processors.datasets.commoncrawl.PreserveByValue + output_manifest_file: ${workspace_dir}/manifest1.json + input_field: text_lang + target_value: ${lang} + + - _target_: sdp.processors.ASRInference + output_manifest_file: ${workspace_dir}/manifest2.json + pretrained_model: nvidia/stt_pl_fastconformer_hybrid_large_pc + batch_size: 64 + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest3.json + duplicate_fields: {"text":"orig_text"} + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest4.json + text_key: text + regex_params_list: + - {"pattern": '\[(.*?)\]', "repl": ' '} + - {"pattern": '\((.*?)\)', "repl": ' '} + - {"pattern": "^[\\s]*\\*(.*?)\\*[\\s]*$", "repl": "\\1"} + - {"pattern": '‚', "repl": ","} + - {"pattern": "’", "repl": "'"} + - {"pattern": "[-–—]", "repl": " "} + - {"pattern": '―', "repl": "-"} + - {"pattern": '—', "repl": "-"} + - {"pattern": '⁺', "repl": "+"} + - {"pattern": '“', "repl": '"'} + - {"pattern": '”', "repl": '"'} + - {"pattern": '…', "repl": '.'} + - {"pattern": '‘', "repl": "'"} + - {"pattern": '′', "repl": "'"} + - {"pattern": '`', "repl": "'"} + - {"pattern": '⁻', "repl": "-"} + - {"pattern": '‑', "repl": "-"} + - {"pattern": '¶', "repl": ' '} + - {"pattern": '«', "repl": '"'} + - {"pattern": '»', "repl": '"'} + - {"pattern": '„', "repl": '"'} + - {"pattern": '®', "repl": ' '} + - {"pattern": '@', "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DropHighLowWordrate + output_manifest_file: ${workspace_dir}/manifest5.json + high_wordrate_threshold: 100 + low_wordrate_threshold: 0.01 + + - _target_: sdp.processors.DropIfRegexMatch + output_manifest_file: ${workspace_dir}/manifest6.json + text_key: text + regex_patterns: + - "^\\s*$" + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest7.json + text_key: text + regex_params_list: + - {"pattern": "^\\s*'+\\s(.*?)\\s*'+\\s*$", "repl": "\\1"} + - {"pattern": "^\\s*'*\\s*", "repl": ""} + - {"pattern": "'{2,}", "repl": "'"} + - {"pattern": '\s(\\x[a-h][0-9]){1,}\s', "repl": ' '} + - {"pattern": '(\\x[a-h][0-9]){1,}', "repl": ''} + - {"pattern": '\.{3}', "repl": '.'} + - {"pattern": '\$', "repl": ""} + - {"pattern": "[^a-pr-uwy-zA-PR-UWY-ZąćęłńóśźżĄĆĘŁŃÓŚŹŻ.,?]", "repl": " "} 
+ - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest8.json + duplicate_fields: {"text":"text_pc"} + + - _target_: sdp.processors.SubMakeLowercase + output_manifest_file: ${workspace_dir}/manifest9.json + text_key: text + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest10.json + text_key: text + regex_params_list: + - {"pattern": "[\\?\\.]", "repl": " "} + - {"pattern": ",", "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DropIfRegexMatch + output_manifest_file: ${workspace_dir}/manifest11.json + text_key: text + regex_patterns: + - "^\\s*$" + + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest12.json + duplicate_fields: {"pred_text":"pred_text_pc"} + + - _target_: sdp.processors.SubMakeLowercase + output_manifest_file: ${workspace_dir}/manifest13.json + text_key: pred_text + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest14.json + text_key: pred_text + regex_params_list: + - {"pattern": "[\\?\\.]", "repl": " "} + - {"pattern": ",", "repl": " "} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.DropHighWER + output_manifest_file: ${workspace_dir}/manifest15.json + text_key: text + pred_text_key: pred_text + wer_threshold: 75 + + - _target_: sdp.processors.DropHighCER + output_manifest_file: ${workspace_dir}/manifest16.json + text_key: text + pred_text_key: pred_text + cer_threshold: 30 + + - _target_: sdp.processors.KeepOnlySpecifiedFields + fields_to_keep: ["audio_filepath", "duration", "text_pc"] + + - _target_: sdp.processors.RenameFields + rename_fields: {"text_pc":"text"} + + - _target_: sdp.processors.SubRegex + text_key: text + regex_params_list: + - {"pattern": "\\s+\\?", "repl": "?"} + - {"pattern": "\\s+\\.", "repl": "."} + - {"pattern": "\\s+,", "repl": ","} + - {"pattern": "\\s+", "repl": " "} + + - _target_: sdp.processors.datasets.commoncrawl.ManifestToUtf8 + + - _target_: sdp.processors.AddConstantFields + output_manifest_file: ${workspace_dir}/manifest_${lang}.json + fields: {"lang": '${lang}'} + + - _target_: sdp.processors.datasets.commoncrawl.TrainDevTestSplitCC + output_manifest_file: ${workspace_dir}/manifest_${lang}_train.json + lang: ${lang} + data_split: train + + - _target_: sdp.processors.datasets.commoncrawl.CopyFiles + file_field: audio_filepath + path_to_copy: ${base_dir}/splited_manifests/${lang}_train/ + path_levels: 2 + + - _target_: sdp.processors.datasets.commoncrawl.DropAbsPath + output_manifest_file: ${base_dir}/splited_manifests/${lang}_train.json + path_key: audio_filepath + abs_path_to_drop: ${base_dir}/splited_manifests/ + + - _target_: sdp.processors.datasets.commoncrawl.TrainDevTestSplitCC + input_manifest_file: ${workspace_dir}/manifest_${lang}.json + output_manifest_file: ${workspace_dir}/manifest_${lang}_dev.json + lang: ${lang} + data_split: dev + + - _target_: sdp.processors.datasets.commoncrawl.CopyFiles + file_field: audio_filepath + path_to_copy: ${base_dir}/splited_manifests/${lang}_dev/ + path_levels: 2 + + - _target_: sdp.processors.datasets.commoncrawl.DropAbsPath + output_manifest_file: ${base_dir}/splited_manifests/${lang}_dev.json + path_key: audio_filepath + abs_path_to_drop: ${base_dir}/splited_manifests/ + + - _target_: sdp.processors.datasets.commoncrawl.TrainDevTestSplitCC + input_manifest_file: ${workspace_dir}/manifest_${lang}.json + output_manifest_file: 
${workspace_dir}/manifest_${lang}_test.json + lang: ${lang} + data_split: test + + - _target_: sdp.processors.datasets.commoncrawl.CopyFiles + file_field: audio_filepath + path_to_copy: ${base_dir}/splited_manifests/${lang}_test/ + path_levels: 2 + + - _target_: sdp.processors.datasets.commoncrawl.DropAbsPath + output_manifest_file: ${base_dir}/splited_manifests/${lang}_test.json + path_key: audio_filepath + abs_path_to_drop: ${base_dir}/splited_manifests/ \ No newline at end of file diff --git a/dataset_configs/commoncrawl/big_sentence.yaml b/dataset_configs/commoncrawl/big_sentence.yaml new file mode 100644 index 00000000..173ed633 --- /dev/null +++ b/dataset_configs/commoncrawl/big_sentence.yaml @@ -0,0 +1,102 @@ +processors_to_run: "0:" +workspace_dir: /mnt/md1/common_crawl/cc_sdp +workspace_dir_s: /mnt/md0/common_crawl/cc_sdp + +processors: + - _target_: sdp.processors.datasets.commoncrawl.CreateInitialManifestCC + raw_data_dir: /mnt/md0/common_crawl/output/video_output2 + output_manifest_file: ${workspace_dir_s}/manifest0.json + video_key: "source_video" + text_key: "texts" + id_key: "key" + + - _target_: sdp.processors.datasets.commoncrawl.ReadParquet + raw_data_dir: /mnt/md0/common_crawl/output/video_output2 + output_manifest_file: ${workspace_dir_s}/manifest1.json + output_video_key: video_url + output_caption_key: caption_url + id_key: key + + - _target_: sdp.processors.FfmpegConverts + output_manifest_file: ${workspace_dir_s}/manifest2.json #${workspace_dir_s}/manifest_urls.json + resampled_audio_dir: ${workspace_dir}/audio + target_samplerate: 16000 + target_nchannels: 1 + input_file_key: "source_video" + output_file_key: "source_audio" + id_key: "key" + + - _target_: sdp.processors.GetAudioDuration + output_manifest_file: ${workspace_dir_s}/manifest3.json + audio_file_key: source_audio + duration_key: duration + + - _target_: sdp.processors.PreserveByValue + output_manifest_file: ${workspace_dir_s}/manifest4.json + input_value_key: duration + target_value: 0 + operator: gt + + - _target_: sdp.processors.datasets.commoncrawl.TxtToVtt + output_manifest_file: ${workspace_dir_s}/manifest5.json + vtt_files_dir: ${workspace_dir_s}/vtts + id_key: "key" + text_key: "texts" + vtt_key: "vtt_filepath" + + - _target_: sdp.processors.datasets.commoncrawl.AllVttText + output_manifest_file: ${workspace_dir_s}/manifest6.json + input_filepath_key: vtt_filepath + output_text_key: vtt_text + + - _target_: sdp.processors.datasets.commoncrawl.TextLid + output_manifest_file: ${workspace_dir_s}/manifest7.json + input_text_key: vtt_text + output_lang_key: text_lang + device: cuda + pretrained_model: "jb2k/bert-base-multilingual-cased-language-detection" + drop_text_duplicates: True + + - _target_: sdp.processors.datasets.commoncrawl.Lang2Iso + output_manifest_file: ${workspace_dir_s}/manifest8.json + input_lang_key: text_lang + output_lang_key: text_lang + + - _target_: sdp.processors.datasets.commoncrawl.AudioLid + output_manifest_file: ${workspace_dir_s}/manifest9.json + input_audio_key: source_audio + output_lang_key: audio_lang + device: cuda + pretrained_model: "langid_ambernet" + + - _target_: sdp.processors.datasets.commoncrawl.SplitByVttSentence + output_manifest_file: ${workspace_dir_s}/manifest10.json + splited_audio_dir: ${workspace_dir_s}/splited/ + source_audio_key: source_audio + target_audio_key: audio_filepath + duration_key: duration + text_key: text + vtt_key: vtt_filepath + proxy_keys: [audio_lang, text_lang, source_audio] + duration_threshold: 10.0 + + - _target_: 
sdp.processors.DropHighLowDuration + output_manifest_file: ${workspace_dir_s}/manifest11.json + high_duration_threshold: 60 + low_duration_threshold: 0.01 + + - _target_: sdp.processors.KeepOnlySpecifiedFields + output_manifest_file: ${workspace_dir_s}/manifest12.json + fields_to_keep: ["audio_filepath", "duration", "text", "audio_lang", "text_lang", "source_audio"] + + - _target_: sdp.processors.datasets.commoncrawl.EvalBandwidth + input_manifest_file: ${workspace_dir_s}/manifest5.json + output_manifest_file: ${workspace_dir_s}/manifest5a.json + input_file_key: source_audio + bandwidth_key: bandwidth + + - _target_: sdp.processors.datasets.commoncrawl.GetSpecificFiles + input_manifest_file: ${workspace_dir_s}/manifest6.json + output_manifest_file: ${workspace_dir_s}/long_dev_test/manifest6.json + input_file_key: source_audio + path_to_copy: ${workspace_dir_s}/long_dev_test \ No newline at end of file diff --git a/docs/src/sdp/existing_configs.rst b/docs/src/sdp/existing_configs.rst index 3b6b5e67..0bece2d0 100644 --- a/docs/src/sdp/existing_configs.rst +++ b/docs/src/sdp/existing_configs.rst @@ -301,4 +301,4 @@ UzbekVoice .. toctree:: :hidden: - config-docs/uzbek/uzbekvoice/config \ No newline at end of file + config-docs/uzbek/uzbekvoice/config diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py index 23079d84..01ec2ce9 100644 --- a/sdp/processors/__init__.py +++ b/sdp/processors/__init__.py @@ -24,9 +24,6 @@ from sdp.processors.datasets.fleurs.create_initial_manifest import ( CreateInitialManifestFleurs, ) -from sdp.processors.datasets.uzbekvoice.create_initial_manifest import ( - CreateInitialManifestUzbekvoice, -) from sdp.processors.datasets.ksc2.create_initial_manifest import ( CreateInitialManifestKSC2, ) @@ -51,13 +48,20 @@ CreateInitialManifestSLR140, CustomDataSplitSLR140, ) +from sdp.processors.datasets.uzbekvoice.create_initial_manifest import ( + CreateInitialManifestUzbekvoice, +) from sdp.processors.datasets.voxpopuli.create_initial_manifest import ( CreateInitialManifestVoxpopuli, ) from sdp.processors.datasets.voxpopuli.normalize_from_non_pc_text import ( NormalizeFromNonPCTextVoxpopuli, ) -from sdp.processors.huggingface.speech_recognition import ASRTransformers +from sdp.processors.huggingface.speech_recognition import ( + ASRTransformers, + ASRWhisper, + LangIdWhisper, +) from sdp.processors.modify_manifest.common import ( AddConstantFields, ApplyInnerJoin, @@ -74,6 +78,10 @@ CountNumWords, FfmpegConvert, GetAudioDuration, + GetCER, + GetEdgeCER, + GetLenDiffRatio, + GetWER, InsIfASRInsertion, InverseNormalizeText, NormalizeText, @@ -98,11 +106,11 @@ DropLowWordMatchRate, DropNonAlphabet, DropOnAttribute, - PreserveByValue, DropRepeatedFields, + PreserveByValue, ) from sdp.processors.modify_manifest.make_letters_uppercase_after_period import ( MakeLettersUppercaseAfterPeriod, ) -from sdp.processors.nemo.asr_inference import ASRInference +from sdp.processors.nemo.asr_inference import ASRInference, ASRInferenceParallel from sdp.processors.nemo.pc_inference import PCInference diff --git a/sdp/processors/datasets/commoncrawl/__init__.py b/sdp/processors/datasets/commoncrawl/__init__.py index 3d8406bf..655e5893 100644 --- a/sdp/processors/datasets/commoncrawl/__init__.py +++ b/sdp/processors/datasets/commoncrawl/__init__.py @@ -12,4 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .commoncrawl import SplitByVttSentence \ No newline at end of file +from .commoncrawl import ( + ASR_HF, + AlignerSubprocess, + AllVttText, + AudioLid, + BLEUScore, + CopyFiles, + CreateInitialManifestCC, + DropAbsPath, + EvalBandwidth, + GetSpecificFiles, + JoinBy, + Lang2Iso, + ManifestToUtf8, + NmtSubprocess, + ReadParquet, + SplitByAligner, + SplitByVtt, + SplitByVttSentence, + Subprocess, + TextLid, + TrainDevTestSplitCC, + TxtToVtt, + UseSonar, +) diff --git a/sdp/processors/datasets/commoncrawl/commoncrawl.py b/sdp/processors/datasets/commoncrawl/commoncrawl.py index 8a5cc2c6..101f49ed 100644 --- a/sdp/processors/datasets/commoncrawl/commoncrawl.py +++ b/sdp/processors/datasets/commoncrawl/commoncrawl.py @@ -1,10 +1,1153 @@ +import json +import math import os -from typing import List +import re +import shutil +import subprocess +from pathlib import Path +from typing import Dict, List, Union +import librosa +import numpy as np +import pandas as pd import soundfile as sf -from sdp.processors.base_processor import BaseParallelProcessor, DataEntry -from sdp.processors.datasets.commoncrawl.harv_utils import split_by_vtt +from sacrebleu import BLEU +from scipy.spatial import distance +from tqdm import tqdm +from sdp.logging import logger +from sdp.processors.base_processor import ( + BaseParallelProcessor, + BaseProcessor, + DataEntry, +) +from sdp.processors.datasets.commoncrawl.harv_utils import ( + audio_duration, + get_vtt_text, + load_manifest, + make_trans_list, + read_jsonl, + split_by_vtt, + split_by_vtt_new, + text2lid, + txt2vtt, + write_jsonl, +) +from sdp.processors.datasets.youtube.utils import Sample, parse_srt + + +class ManifestToUtf8(BaseProcessor): + """ + Processor to convert a manifest file to UTF-8 encoding (writes JSON without ASCII escaping). + """ + + def process(self): + with open(self.output_manifest_file, "w") as wout, open(self.input_manifest_file) as win: + for line in win: + print(json.dumps(json.loads(line), ensure_ascii=False), file=wout) + + +class DropAbsPath(BaseParallelProcessor): + """ + Drop an absolute path prefix from a file path field. + + Args: + path_key (str): key that holds the path to the wav file. + abs_path_to_drop (str): prefix to drop from the beginning of the wav file path.
+ """ + + def __init__( + self, + path_key: str, + abs_path_to_drop: str, + **kwargs, + ): + super().__init__(**kwargs) + self.path_key = path_key + self.abs_path_to_drop = abs_path_to_drop + + def process_dataset_entry(self, data_entry): + audio_filepath = data_entry[self.path_key] + data_entry[self.path_key] = audio_filepath[len(self.abs_path_to_drop) :] + return [DataEntry(data=data_entry)] + + +class CopyFiles(BaseParallelProcessor): + def __init__( + self, + file_field: str, + path_to_copy: str, + path_levels: int = 1, + **kwargs, + ): + super().__init__(**kwargs) + self.file_field = file_field + self.path_to_copy = path_to_copy + self.path_levels = path_levels + + def prepare(self): + os.makedirs(self.path_to_copy, exist_ok=True) + + def process_dataset_entry(self, data_entry): + rel_file_path = "/".join(data_entry[self.file_field].split("/")[-self.path_levels :]) + new_file_path = os.path.join(self.path_to_copy, rel_file_path) + + if not os.path.isfile(new_file_path): + os.makedirs(os.path.split(new_file_path)[0], exist_ok=True) + shutil.copyfile(data_entry[self.file_field], new_file_path) + data_entry[self.file_field] = new_file_path + return [DataEntry(data=data_entry)] + + +class GetSpecificFiles(BaseParallelProcessor): + def __init__( + self, + input_file_key: str, + path_to_copy: str, + **kwargs, + ): + super().__init__(**kwargs) + self.input_file_key = input_file_key + self.path_to_copy = path_to_copy + + self.split_map = set( + [ + '0634236', + '0693626', + '0029743', + '0881322', + '0357427', + '0455788', + '0198472', + '0496259', + '0812890', + '0142281', + '0076612', + '0629004', + '0931592', + '0577447', + '0768107', + '0907768', + '0963898', + '0671754', + '0851569', + '0896715', + '0366790', + '0837221', + '0733702', + '0278253', + '0738313', + '0437256', + '0558223', + '0292533', + '0777911', + '0826607', + '0544257', + '0744206', + '0576248', + '0307575', + '0307577', + '0879895', + '0006783', + '0006755', + '0125649', + '0896701', + ] + ) + + def prepare(self): + os.makedirs(self.path_to_copy, exist_ok=True) + + def process_dataset_entry(self, data_entry): + file_id = os.path.splitext(data_entry[self.input_file_key])[0].split("/")[-1] + if file_id in self.split_map: + shutil.copyfile(data_entry[self.input_file_key], os.path.join(self.path_to_copy, file_id + ".wav")) + return [DataEntry(data=data_entry)] + else: + return [] + + +class TrainDevTestSplitCC(BaseParallelProcessor): + """Custom train-dev-test split for the CommonCrawl dataset. + + Split is done by source file id, so segments from the same source + recording don't appear in different splits. + + Args: + data_split (str): train, dev or test. + lang (str): language to process. + + Returns: + All the same fields as in the input manifest, but only a subset of + the data is retained.
+ """ + + def __init__( + self, + data_split: str, + lang: str, + **kwargs, + ): + super().__init__(**kwargs) + if data_split not in ["train", "dev", "test"]: + raise ValueError("data_split has to be either train, dev or test") + self.data_split = data_split + self.lang = lang + + self.split_map = {} + self.split_map["en"] = {} + self.split_map["en"]["dev"] = set( + [ + '0634236', + '0693626', + '0029743', + '0881322', + '0357427', + '0455788', + '0198472', + '0496259', + '0812890', + '0142281', + '0076612', + '0629004', + '0931592', + '0577447', + '0768107', + '0907768', + '0963898', + '0671754', + '0851569', + '0896715', + ] + ) + self.split_map["en"]["test"] = set( + [ + '0366790', + '0837221', + '0733702', + '0278253', + '0738313', + '0437256', + '0558223', + '0292533', + '0777911', + '0826607', + '0544257', + '0744206', + '0576248', + '0307575', + '0307577', + '0879895', + '0006783', + '0006755', + '0125649', + '0896701', + ] + ) + self.split_map["de"] = {} + self.split_map["de"]["dev"] = set( + [ + '0383522', + '0327835', + '0327898', + '0619871', + '0387103', + '0854766', + '0738911', + '0739038', + '0854558', + '0505561', + '0735963', + '0086041', + '0967593', + '0114210', + '0098270', + '0387140', + '0917035', + '0327745', + '0914212', + '0739071', + ] + ) + self.split_map["de"]["test"] = set( + [ + '0076939', + '0589098', + '0916988', + '0268959', + '0085896', + '0327813', + '0085897', + '0739103', + '0502188', + '0034822', + '0327729', + '0572412', + '0327680', + '0027277', + '0324720', + '0209876', + '0027226', + '0268926', + '0209776', + '0738970', + ] + ) + self.split_map["pl"] = {} + self.split_map["pl"]["dev"] = set( + [ + '0977373', + '0949141', + '0455759', + '0357429', + '0401864', + '0714974', + '0422716', + '0363476', + '0714976', + '0927100', + ] + ) + self.split_map["pl"]["test"] = set( + [ + '0157903', + '0115644', + '0774572', + '0688432', + '0258376', + '0396163', + '0456013', + '0571489', + '0157653', + '0062567', + ] + ) + self.split_map["fr"] = {} + self.split_map["fr"]["dev"] = set( + [ + '0588135', + '0706751', + '0533213', + '0920924', + '0355413', + '0985711', + '0113477', + '0533044', + '0089551', + '0944509', + '0944576', + '0766533', + '0263084', + '0113490', + '0647104', + '0273918', + '0473607', + '0706753', + '0800223', + '0300105', + '0944416', + '0566712', + '0533102', + '0177064', + '0029651', + '0215767', + '0054412', + '0236920', + '0885068', + '0296098', + '0113592', + '0706610', + '0473383', + '0330163', + '0681542', + '0272523', + '0985709', + '0564446', + '0944481', + '0587986', + '0804060', + '0236908', + '0969694', + '0054058', + '0800671', + '0236923', + '0986025', + '0770086', + '0825692', + '0968870', + '0152315', + '0533147', + '0647027', + '0029342', + '0272698', + '0153863', + '0355323', + '0988779', + '0985959', + '0237013', + '0338134', + '0885097', + '0507678', + '0507687', + '0944485', + '0825768', + '0742440', + '0969664', + '0885089', + '0117211', + '0296044', + '0985958', + '0214384', + '0021267', + '0565392', + '0388467', + '0151715', + '0861950', + '0112768', + '0113596', + '0621657', + '0236860', + '0647128', + '0058479', + '0803614', + '0177501', + '0533110', + '0566787', + '0944496', + '0859701', + '0885165', + '0212639', + '0054532', + '0919263', + '0740701', + ] + ) + self.split_map["fr"]["test"] = set( + [ + '0473649', + '0390470', + '0296024', + '0355365', + '0314592', + '0682498', + '0534637', + '0270580', + '0532999', + '0373977', + '0622032', + '0825761', + '0923303', + '0113485', + '0825868', + '0473710', + 
'0511698', + '0844353', + '0801733', + '0091695', + '0452351', + '0825872', + '0969173', + '0986055', + '0970208', + '0141266', + '0149629', + '0296117', + '0153112', + '0801752', + '0030816', + '0508766', + '0029390', + '0825877', + '0271152', + '0388655', + '0743376', + '0177466', + '0153032', + '0329945', + '0473606', + '0986015', + '0096178', + '0089561', + '0440564', + '0741466', + '0499703', + '0272514', + '0944571', + '0919512', + '0646950', + '0533215', + '0760703', + '0733028', + '0113488', + '0825739', + '0492402', + '0214463', + '0154278', + '0801877', + '0825675', + '0675029', + '0801729', + '0414446', + '0054425', + '0279176', + '0296100', + '0355317', + '0733026', + '0089548', + '0177502', + '0851638', + '0851640', + '0448606', + '0803096', + '0766603', + '0507914', + '0092173', + '0647061', + '0473564', + '0706765', + '0766538', + '0295994', + '0851630', + '0029358', + '0647062', + '0825838', + '0153786', + '0944526', + '0944484', + '0588046', + '0706820', + '0177465', + '0622092', + '0332657', + '0944480', + ] + ) + + def process_dataset_entry(self, data_entry): + file_id = os.path.splitext(data_entry["audio_filepath"])[0].split("/")[-2] + if self.data_split == "train": + if file_id not in self.split_map[self.lang]["dev"] and file_id not in self.split_map[self.lang]["test"]: + return [DataEntry(data=data_entry)] + else: + if file_id in self.split_map[self.lang][self.data_split]: + return [DataEntry(data=data_entry)] + return [] + + +class JoinBy(BaseProcessor): + """ + This processor join several lines into one using key input_field + + Args: + input_field (str): where to get path to wav file. + text_field (str): where to put resulted text. + audio_field (str): where to put resulted wav file. + + Returns: + All the same fields as in the input manifest plus audio_field + """ + + def __init__( + self, + input_field: str, + text_field: str = "text", + audio_field: str = 'audio_filepath', + **kwargs, + ): + super().__init__(**kwargs) + self.input_field = input_field + self.text_field = text_field + self.audio_field = audio_field + + def process(self): + df1 = read_jsonl(self.input_manifest_file) + pattern = re.compile("\s{2,}") + df1[self.text_field] = df1[self.text_field].apply(lambda x: pattern.sub(" ", x).strip()) + # df1["source"] = df1["audio_filepath"].apply(lambda x: x.split("/")[-2]) + + df2 = pd.DataFrame( + df1.groupby(self.input_field).apply(lambda in_df: " ".join(in_df[self.text_field].tolist())), + columns=[self.text_field], + ).reset_index() + df2[self.audio_field] = df2[self.input_field] + write_jsonl(df2[[self.audio_field, self.text_field]], self.output_manifest_file) + + +class EvalBandwidth(BaseParallelProcessor): + """ + Count audio bandwidth using audio file path from input_field + + Args: + input_file_key (str): where to get path to wav file. + bandwidth_key (str): where to put to frequency bandwidth. + threshold (str): power threshold (in dB relative to peak power in spectrum bin) to estimate frequency bandwidth. + + Returns: + All the same fields as in the input manifest plus output_field. 
+ """ + + def __init__( + self, + input_file_key: str, + bandwidth_key: str, + threshold: int = -50, + **kwargs, + ): + super().__init__(**kwargs) + self.input_file_key = input_file_key + self.bandwidth_key = bandwidth_key + self.threshold = threshold + + def process_dataset_entry(self, data_entry): + audio_filepath = data_entry[self.input_file_key] + data, samplerate = sf.read(audio_filepath) + freqband = self.eval_bandwidth(data, samplerate, threshold=self.threshold) + data_entry[self.bandwidth_key] = freqband + return [DataEntry(data=data_entry)] + + def eval_bandwidth(self, signal, sr, threshold=-50): + time_stride = 0.01 + hop_length = int(sr * time_stride) + n_fft = 512 + spectrogram = np.mean( + np.abs(librosa.stft(y=signal, n_fft=n_fft, hop_length=hop_length, window='blackmanharris')) ** 2, axis=1 + ) + power_spectrum = librosa.power_to_db(S=spectrogram, ref=np.max, top_db=100) + freqband = 0 + for idx in range(len(power_spectrum) - 1, -1, -1): + if power_spectrum[idx] > threshold: + freqband = idx / n_fft * sr + break + return freqband + + +class SplitByAligner(BaseParallelProcessor): + """ + Split a wav file using the NFA aligner fields: nfa_start, nfa_duration + + Args: + input_field (str): field with the source wav file names. + output_field (str): field to put the split wav file names. + splited_audio_dir (str): where to save the split wav files. + Returns: + All the same fields as in the input manifest plus output_field. + """ + + def __init__( + self, + input_field: str, + output_field: str, + splited_audio_dir: str, + **kwargs, + ): + super().__init__(**kwargs) + self.input_field = input_field + self.output_field = output_field + self.splited_audio_dir = splited_audio_dir + + def prepare(self): + os.makedirs(self.splited_audio_dir, exist_ok=True) + + def process_dataset_entry(self, data_entry): + audio_filepath = data_entry[self.input_field] + + # print(data_entry) + data, samplerate = sf.read(audio_filepath) + nfa_start = data_entry["nfa_start"] + nfa_duration = data_entry["nfa_duration"] + + if math.isnan(nfa_start) or math.isnan(nfa_duration) or math.isnan(samplerate): + print(audio_filepath, nfa_start, nfa_duration) + data_entry[self.output_field] = data_entry['audio_filepath'] + else: + start = int(nfa_start * samplerate) + duration = int(nfa_duration * samplerate) + + data_sample = data[start : start + duration] + + wav_save_file = os.path.join( + self.splited_audio_dir, + '/'.join(os.path.splitext(audio_filepath)[0].split('/')[-2:]), + str(int(start * 1000 / samplerate)) + "-" + str(int((start + duration) * 1000 / samplerate)) + ".wav", + ) + if not os.path.isfile(wav_save_file): + os.makedirs(os.path.split(wav_save_file)[0], exist_ok=True) + sf.write(wav_save_file, data_sample, samplerate) + data_entry[self.output_field] = wav_save_file + return [DataEntry(data=data_entry)] + + +class ASR_HF(BaseProcessor): + """ + Transcribe using an ASR model from HuggingFace. + + Args: + pretrained_model (str): name of the pretrained model on HuggingFace. + output_text_field (str): field to save the transcription result. + device (str): Inference device. + batch_size (int): Inference batch size. + Returns: + All the same fields as in the input manifest plus output_text_field.
+ """ + + def __init__( + self, + pretrained_model: str, + output_text_field: str, + device: str = None, + batch_size: int = 1, + **kwargs, + ): + super().__init__(**kwargs) + self.pretrained_model = pretrained_model + self.output_text_field = output_text_field + self.device = device + self.batch_size = batch_size + + def process(self): + import torch + from huggingsound import SpeechRecognitionModel + + if self.device is None: + if torch.cuda.is_available(): + self.device = "cuda" + else: + self.device = "cpu" + + model = SpeechRecognitionModel(self.pretrained_model, device=self.device, letter_case=None) + + manifest, key_dict = load_manifest(Path(self.input_manifest_file), keys=["audio_filepath"]) + audio_paths = key_dict["audio_filepath"] + + Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True) + + transcriptions = model.transcribe(paths=audio_paths, batch_size=self.batch_size, decoder=None) + + with Path(self.output_manifest_file).open('w') as f: + for item, transcription in tqdm(zip(manifest, transcriptions)): + item[self.output_text_field] = transcription["transcription"] + f.write(json.dumps(item, ensure_ascii=False) + '\n') + + +class UseSonar(BaseProcessor): + """ + Compute the distance between text and audio embeddings using the SONAR library. + + Args: + input_text_field (str): field with text to process. + input_audio_field (str): field with audio file path to process. + output_field (str): field to save distance. + speech_encoder_model (str): name of pretrained speech encoder model. + text_encoder_lang (str): language of text. + text_encoder_model (str): name of pretrained text encoder model. + batch_size (int): batch size for inference. + device (str): device to run inference on. + Returns: + All the same fields as in the input manifest plus output_field.
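+
+    Example (illustrative config sketch; the SONAR checkpoint names and the FLORES-200 language code are assumptions and should be replaced with the models available in your environment):
+
+        _target_: sdp.processors.datasets.commoncrawl.UseSonar
+        output_manifest_file: ${workspace_dir}/manifest_sonar.json
+        input_text_field: text
+        input_audio_field: audio_filepath
+        output_field: sonar_dist
+        speech_encoder_model: sonar_speech_encoder_deu
+        text_encoder_model: text_sonar_basic_encoder
+        text_encoder_lang: deu_Latn
+        batch_size: 64
+        device: cuda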
+ """ + + def __init__( + self, + input_text_field: str, + input_audio_field: str, + output_field: str, + speech_encoder_model: str, + text_encoder_lang: str, + text_encoder_model: str, + batch_size: int = 64, + device: str = "cuda", + **kwargs, + ): + super().__init__(**kwargs) + import torch # importing after nemo to make sure users first install nemo, instead of torch, then nemo + from sonar.inference_pipelines.speech import SpeechToEmbeddingModelPipeline + from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline + from sonar.models.sonar_speech.loader import load_sonar_speech_model + from sonar.models.sonar_text import ( + load_sonar_text_decoder_model, + load_sonar_text_encoder_model, + load_sonar_tokenizer, + ) + from torch.nn import PairwiseDistance + + self.output_field = output_field + self.input_text_field = input_text_field + self.input_audio_field = input_audio_field + self.batch_size = batch_size + self.device = device + self.text_encoder_lang = text_encoder_lang + self.text_encoder_model = load_sonar_text_encoder_model(text_encoder_model, device=self.device).eval() + self.text_tokenizer = load_sonar_tokenizer(text_encoder_model) + self.speech_encoder_model = load_sonar_speech_model(speech_encoder_model, device=self.device).eval() + self.pdist = PairwiseDistance(p=2) + self.s2vec_model = SpeechToEmbeddingModelPipeline(encoder=self.speech_encoder_model) + self.text_embedding_pipeline = TextToEmbeddingModelPipeline(self.text_encoder_model, self.text_tokenizer) + + def process(self): + manifest = load_manifest(Path(self.input_manifest_file)) + + Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True) + with Path(self.output_manifest_file).open('w') as f: + for item in tqdm(manifest): + input_texts = [item[self.input_text_field]] + input_audios = [item[self.input_audio_field]] + dist = self.get_pdist(input_texts, input_audios) + item[self.output_field] = dist + f.write(json.dumps(item, ensure_ascii=False) + '\n') + + def get_pdist(self, input_texts, input_audios): + text_emb = self.text_embedding_pipeline.predict( + input=input_texts, batch_size=1, source_lang=self.text_encoder_lang + ) + + audio_emb = self.s2vec_model.predict( + input=input_audios, + batch_size=1, + n_parallel=1, + pad_idx=0, + n_prefetched_batches=1, + ) + # pdist = self.pdist(text_emb, audio_emb).numpy().squeeze().astype(float).tolist() + pdist = ( + distance.cdist(text_emb.numpy().astype(float), audio_emb.numpy().astype(float), 'sqeuclidean') + .squeeze() + .tolist() + ) + return pdist + + def process_batch(self): + manifest, dict_list = load_manifest( + Path(self.input_manifest_file), keys=[self.input_audio_field, self.input_text_field] + ) + manifest_len = len(manifest) + Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True) + with Path(self.output_manifest_file).open('w') as f: + for start in tqdm(range(0, manifest_len, self.batch_size)): + stop = start + self.batch_size + input_texts = dict_list[self.input_text_field][start:stop] + input_audios = dict_list[self.input_audio_field][start:stop] + manifest_batch = manifest[start:stop] + + dists = self.get_pdist(input_texts, input_audios) + for item, dist in zip(manifest_batch, dists): + item[self.output_field] = dist + f.write(json.dumps(item, ensure_ascii=False) + '\n') + + +class BLEUScore(BaseParallelProcessor): + """ + Count BLEU Score. 
+ + Args: + ref_field (str): field with reference texts + hyp_field (str): field with hypotheses + output_field (str): field to save BLEU Score + Returns: + All the same fields as in the input manifest plus output_field. + """ + + def __init__( + self, + ref_field: str, + hyp_field: str, + output_field: str, + **kwargs, + ): + super().__init__(**kwargs) + self.ref_field = ref_field + self.hyp_field = hyp_field + self.output_field = output_field + self.scorer = BLEU(effective_order=True) + + def process_dataset_entry(self, data_entry): + ref = data_entry[self.ref_field] + hyp = data_entry[self.hyp_field] + + res = self.scorer.sentence_score(hypothesis=hyp, references=[ref]) + data_entry[self.output_field] = res.score + return [DataEntry(data=data_entry)] + + +class Subprocess(BaseProcessor): + """ + Processor for handling subprocess execution with additional features for managing input and output manifests. + + Args: + cmd (str): The command to be executed as a subprocess. + input_manifest_arg (str, optional): The argument specifying the input manifest. Defaults to an empty string. + output_manifest_arg (str, optional): The argument specifying the output manifest. Defaults to an empty string. + arg_separator (str, optional): The separator used between argument and value. Defaults to "=". + shell (bool, optional): The argument specifies whether to use shell for subprocess.run(). Defaults to False. + dont_wait (bool, optional): The argument specifies whether to wait while the subprocess finishes. . Defaults to False. + **kwargs: Additional keyword arguments to be passed to the base class. + + Example: + + _target_: sdp.processors.datasets.commoncrawl.Subprocess + output_manifest_file: /workspace/manifest.json + input_manifest_arg: "--manifest" + output_manifest_arg: "--output_filename" + arg_separator: "=" + cmd: "python /workspace/NeMo-text-processing/nemo_text_processing/text_normalization/normalize_with_audio.py \ + --language=en --n_jobs=-1 --batch_size=600 --manifest_text_field=text --cache_dir=${workspace_dir}/cache --overwrite_cache \ + --whitelist=/workspace/NeMo-text-processing/nemo_text_processing/text_normalization/en/data/whitelist/asr_with_pc.tsv" + """ + + def __init__( + self, + cmd: str, + input_manifest_arg: str = "", + output_manifest_arg: str = "", + arg_separator: str = "=", + shell: bool = False, + dont_wait: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.input_manifest_arg = input_manifest_arg + self.output_manifest_arg = output_manifest_arg + self.arg_separator = arg_separator + self.cmd = cmd + self.shell = shell + self.dont_wait = dont_wait + + def process(self): + os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) + if (self.cmd.find(self.input_manifest_file) != -1 and self.input_manifest_arg != "") \ + or (self.cmd.find(self.output_manifest_file) != -1 and self.output_manifest_arg != ""): + raise ValueError("input_manifest_file " + + self.input_manifest_file + + " and output_manifest_file " + + self.output_manifest_file + + " should be exluded from cmd line: " + + self.cmd) + process_args = [x for x in self.cmd.split(" ") if x] + if self.arg_separator == " ": + if self.input_manifest_arg: + process_args.extend([self.input_manifest_arg, self.input_manifest_file]) + if self.output_manifest_arg: + process_args.extend([self.output_manifest_arg, self.output_manifest_file]) + else: + if self.input_manifest_arg: + process_args.extend([self.input_manifest_arg + self.arg_separator + self.input_manifest_file]) + if 
self.output_manifest_arg: + process_args.extend([self.output_manifest_arg + self.arg_separator + self.output_manifest_file]) + if self.shell: + process_args = " ".join(process_args) + logger.info("subprocess shell: " + process_args) + + if self.dont_wait: + logger.warning("dont_wait flag is True, no logs captures!") + subprocess.Popen(process_args, shell=self.shell, stdin=None, stdout=None, stderr=None, close_fds=True) + else: + subprocess.run(process_args, shell=self.shell) + + +class NmtSubprocess(Subprocess): + """ + A class for executing Neural Machine Translation (NMT) subprocess with enhanced functionality for managing input and output fields. + + Parameters: + input_field (str): The field in the input manifest containing the source text for translation. + output_field (str): The field to store the translated output in the output manifest. + srctext_file (str): The file path to store the source text for translation. + tgtout_file (str): The file path to store the translated output. + **kwargs: Additional keyword arguments to be passed to the base class `Subprocess`. + + """ + + def __init__( + self, + input_field: str, + output_field: str, + srctext_file: str, + tgtout_file: str, + **kwargs, + ): + super().__init__(**kwargs) + self.input_field = input_field + self.output_field = output_field + self.srctext_file = srctext_file + self.tgtout_file = tgtout_file + self.cmd = ( + self.cmd + + " --srctext" + + self.arg_separator + + self.srctext_file + + " --tgtout" + + self.arg_separator + + self.tgtout_file + ) + + def process(self): + df1 = read_jsonl(self.input_manifest_file) + with Path(self.srctext_file).open('w') as f: + for input_field in df1[self.input_field]: + f.write(input_field + "\n") + + super().process() + + with Path(self.tgtout_file).open('r') as f: + tgtout = [l.strip() for l in f] + df1[self.output_field] = tgtout + write_jsonl(df1, self.output_manifest_file) + + +class AlignerSubprocess(Subprocess): + """ + A class for aligning audio transcripts using an aligner subprocess with additional features for managing output fields. + + Parameters: + output_field (str): The field in the output manifest to store the aligned transcripts. + duration_threshold (int, optional): The maximum duration threshold for audio files in seconds. Files exceeding this threshold are excluded from alignment. Defaults to 5000. + **kwargs: Additional keyword arguments to be passed to the base class `Subprocess`. 
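+
+    Example (illustrative config sketch; the align.py path, its Hydra-style arguments and the pretrained model name are assumptions about a local NeMo Forced Aligner checkout):
+
+        _target_: sdp.processors.datasets.commoncrawl.AlignerSubprocess
+        output_manifest_file: ${workspace_dir}/manifest_aligned.json
+        output_field: text
+        duration_threshold: 5000
+        input_manifest_arg: "manifest_filepath"
+        output_manifest_arg: "output_dir"
+        arg_separator: "="
+        cmd: "python /path/to/NeMo/tools/nemo_forced_aligner/align.py pretrained_name=stt_de_fastconformer_hybrid_large_pc"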
+ + """ + + def __init__( + self, + output_field: str, + duration_threshold: int = 5000, + **kwargs, + ): + super().__init__(**kwargs) + self.output_field = output_field + self.duration_threshold = duration_threshold + + def process(self): + df1 = read_jsonl(self.input_manifest_file) + pattern = re.compile("\s{2,}") + df1["text"] = df1["text"].apply(lambda x: pattern.sub(" ", x).strip()) + df1["source"] = df1["audio_filepath"].apply(lambda x: x.split("/")[-2]) + + df2 = pd.DataFrame( + df1.groupby("source_audio").apply(lambda in_df: "|".join(in_df["text"].tolist())), columns=["text"] + ).reset_index() + df2['audio_filepath'] = df2['source_audio'] + df2['duration'] = df2['audio_filepath'].apply(audio_duration) + df2 = df2[df2['duration'] < self.duration_threshold] + + self.input_manifest_file = os.path.join(os.path.split(self.input_manifest_file)[0], 'tmp.json') + write_jsonl(df2[['audio_filepath', 'text']], self.input_manifest_file) + + super().process() + manifest_path, manifest_name = os.path.split(self.input_manifest_file) + manifest_name = os.path.splitext(manifest_name)[0] + aligner_path = os.path.join(manifest_path, manifest_name + "_with_output_file_paths.json") + df3 = read_jsonl(aligner_path) + pattern = re.compile("") + df4 = pd.DataFrame() + + for ctm_filepath in tqdm(df3["segments_level_ctm_filepath"]): + source = os.path.splitext(ctm_filepath)[0].split('/')[-1] + df6 = df1[df1["source"] == source].reset_index() + df5 = pd.read_csv(ctm_filepath, sep=' ', header=None, dtype={0: str}) + df5["text"] = df5[4].apply(lambda x: pattern.sub(" ", x)) + df5["nfa_start"] = df5[2] + df5["nfa_duration"] = df5[3] + if df5.shape[0] == df6.shape[0]: + df7 = df5[["nfa_start", "nfa_duration", "text"]].merge(df6, how="right") + else: + raise ValueError(ctm_filepath) + + df4 = pd.concat([df4, df7]) + + write_jsonl(df4, self.output_manifest_file) + + +class Lang2Iso(BaseParallelProcessor): + """ + A class for converting language names to ISO language codes in a dataset. + + Parameters: + input_lang_key (str): The field in the dataset containing language names to be converted. + output_lang_key (str): The field to store the corresponding ISO language codes. 
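+
+    Example (illustrative config sketch; the field names are placeholders):
+
+        _target_: sdp.processors.datasets.commoncrawl.Lang2Iso
+        output_manifest_file: ${workspace_dir}/manifest_iso.json
+        input_lang_key: text_lang
+        output_lang_key: text_lang_iso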
+ + """ + + def __init__( + self, + input_lang_key: str, + output_lang_key: str, + **kwargs, + ): + super().__init__(**kwargs) + self.input_lang_key = input_lang_key + self.output_lang_key = output_lang_key + self.iso_m = { + 'English': 'en', + 'Spanish': 'es', + 'Basque': 'eu', + 'Dutch': 'nl', + 'Welsh': 'cy', + 'Italian': 'it', + 'Catalan': 'ca', + 'Maltese': 'mt', + 'Swedish': 'sv', + 'French': 'fr', + 'German': 'de', + 'Chuvash': 'cv', + 'Kinyarwanda': 'rw', + 'Polish': 'pl', + 'Kabyle': 'kab', + 'Interlingua': 'ua', + 'Portuguese': 'pt', + 'Hakha_Chin': 'cnh', + 'Romansh_Sursilvan': 'roh', + 'Breton': 'br', + 'Esperanto': 'epo', + 'Czech': 'ces', + 'Latvian': 'lav', + 'Indonesian': 'ind', + 'Slovenian': 'slv', + 'Turkish': 'tur', + 'Frisian': 'frr', + 'Tatar': 'tat', + 'Persian': 'fas', + 'Estonian': 'est', + 'Romanian': 'rum', + 'Chinese_Hongkong': 'zh', + 'Chinese_Taiwan': 'zh', + 'Chinese_China': 'zh', + 'Georgian': 'kat', + 'Kyrgyz': 'kir', + 'Dhivehi': 'div', + 'Sakha': 'sah', + 'Arabic': 'ar', + 'Japanese': 'ja', + 'Russian': 'ru', + } + + def process_dataset_entry(self, data_entry): + data_entry[self.output_lang_key] = self.iso_m.get(data_entry[self.input_lang_key], None) + return [DataEntry(data=data_entry)] + + +class SplitByVtt(BaseParallelProcessor): + def __init__( + self, + source_audio_key: str, + caption_file_key: str, + duration_key: str = "duration", + output_text_key: str = "orig_text", + **kwargs, + ): + super().__init__(**kwargs) + self.source_audio_key = source_audio_key + self.duration_key = duration_key + self.output_text_key = output_text_key + self.caption_file_key = caption_file_key + + def process_dataset_entry(self, data_entry): + caption_file = data_entry[self.caption_file_key] + audio_file = data_entry[self.source_audio_key] + if not os.path.exists(audio_file): + return [] + segments = parse_srt(caption_file, verify_duration=True, wav_filepath=audio_file) + + if len(segments) > 0: + data_entry['segments'] = [segment.__dict__ for segment in segments] + return [DataEntry(data=data_entry)] class SplitByVttSentence(BaseParallelProcessor): @@ -49,13 +1192,13 @@ def prepare(self): os.makedirs(self.splited_audio_dir, exist_ok=True) def process_dataset_entry(self, data_entry): - vtt_file = data_entry[self.vtt_field] - source_audio = data_entry[self.source_audio_field] + caption_file = data_entry[self.caption_file_key] + source_audio = data_entry[self.source_audio_key] res_list = [] if os.path.isfile(source_audio): data, samplerate = sf.read(source_audio) - text_list, start_s, end_s = split_by_vtt(vtt_file, samplerate) + text_list, start_s, end_s = split_by_vtt_new(caption_file, samplerate) text_c = '' start_c, end_c = 0, 0 if text_list: @@ -67,33 +1210,371 @@ def process_dataset_entry(self, data_entry): pass end_c = end_sr if len(text_c) > 0 and ( - end_c - start_c > self.duration_threshold * samplerate or - text_c[-1] == "." or text_c[-1] == "?"): + end_c - start_c > self.duration_threshold * samplerate + or text_c[-1] == "." + or text_c[-1] == "?" 
+ ): + res_list.append( - self.makeDataEntry(data_entry, data, vtt_file, samplerate, text_c, start_c, end_c)) + self.makeDataEntry(data_entry, data, caption_file, samplerate, text_c, start_c, end_c) + ) + text_c = '' + start_c, end_c = 0, 0 + else: + pass + if len(text_c) > 0 and start_c != 0: - res_list.append(self.makeDataEntry(data_entry, data, vtt_file, samplerate, text_c, start_c, end_c)) + res_list.append( + self.makeDataEntry(data_entry, data, caption_file, samplerate, text_c, start_c, end_c) + ) return res_list def makeDataEntry(self, data_entry, data, vtt_file, samplerate, text_c, start_c, end_c): data_sample = data[start_c:end_c] - wav_save_file = os.path.join(self.splited_audio_dir, '/'.join(os.path.splitext(vtt_file)[0].split('/')[-2:]), - str(int(start_c / (samplerate / 1000))) + "-" + str( - int(end_c / (samplerate / 1000))) + ".wav") + wav_save_file = os.path.join( + self.splited_audio_dir, + '/'.join(os.path.splitext(vtt_file)[0].split('/')[-2:]), + str(int(start_c / (samplerate / 1000))) + "-" + str(int(end_c / (samplerate / 1000))) + ".wav", + ) if not os.path.isfile(wav_save_file): os.makedirs(os.path.split(wav_save_file)[0], exist_ok=True) sf.write(wav_save_file, data_sample, samplerate) - data = {self.target_audio_field: wav_save_file, - self.duration_field: data_sample.shape[0] / samplerate, - self.text_field: text_c.strip(), - } - for field in self.additional_fields: - data[field] = data_entry[field] + data = { + self.target_audio_key: wav_save_file, + self.duration_key: data_sample.shape[0] / samplerate, + self.text_key: text_c.strip(), + } + for proxy_key in self.proxy_keys: + data[proxy_key] = data_entry[proxy_key] return DataEntry(data=data) + +class AudioLid(BaseProcessor): + """ + A class for language identification (LID) of audio files using a pre-trained LID model. + + Args: + input_audio_key (str): The field in the dataset containing the path to the audio files for language identification. + pretrained_model (str): The name of the pre-trained NeMo model used for language identification. + output_lang_key (str): The field to store the identified language for each audio file. + device (str): The device to run the LID model on (e.g., 'cuda', 'cpu'). If None, it automatically selects the available GPU if present; otherwise, it uses the CPU. + segment_duration (float): Random sample duration in seconds. Default is np.inf. + num_segments (int): Number of segments of file to use for majority vote. Default is 1. + random_seed (int): Seed for generating the starting position of the segment. Default is None. + **kwargs: Additional keyword arguments to be passed to the base class `BaseProcessor`.
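+
+    Example (illustrative config sketch; the pretrained model name is an assumption, any NeMo EncDecSpeakerLabelModel language-ID checkpoint can be used):
+
+        _target_: sdp.processors.datasets.commoncrawl.AudioLid
+        output_manifest_file: ${workspace_dir}/manifest_audio_lid.json
+        input_audio_key: audio_filepath
+        output_lang_key: audio_lang
+        pretrained_model: langid_ambernet
+        device: cuda
+        segment_duration: 20
+        num_segments: 3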
+ + """ + + def __init__( + self, + input_audio_key: str, + pretrained_model: str, + output_lang_key: str, + device: str, + segment_duration: float = np.inf, + num_segments: int = 1, + random_seed: int = None, + **kwargs, + ): + super().__init__(**kwargs) + self.input_audio_key = input_audio_key + self.pretrained_model = pretrained_model + self.output_lang_key = output_lang_key + self.segment_duration = segment_duration + self.num_segments = num_segments + self.random_seed = random_seed + self.device = device + + def process(self): + import nemo.collections.asr as nemo_asr + import torch # importing after nemo to make sure users first install nemo, instead of torch, then nemo + + model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name=self.pretrained_model) + + if self.device is None: + if torch.cuda.is_available(): + model = model.cuda() + else: + model = model.cpu() + else: + model = model.to(self.device) + + manifest = load_manifest(Path(self.input_manifest_file)) + + Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True) + with Path(self.output_manifest_file).open('w') as f: + for item in tqdm(manifest): + audio_file = item[self.input_audio_key] + + try: + lang = model.get_label(audio_file, self.segment_duration, self.num_segments) + except Exception as e: + logger.warning("AudioLid " + audio_file + " " + str(e)) + lang = None + + if lang: + item[self.output_lang_key] = lang + f.write(json.dumps(item, ensure_ascii=False) + '\n') + + +class TextLid(BaseProcessor): + """ + A class for language identification (LID) of text using a pre-trained text classification model. + + Args: + input_text_key (str): The field in the dataset containing the text for language identification. + pretrained_model (str): The name or path of the pre-trained text classification model for language identification. + output_lang_key (str): The field to store the identified language for each text. + device (str): The device to run the text classification model on (e.g., 'cuda', 'cpu'). If None, it automatically selects the available GPU if present; otherwise, it uses the CPU. + drop_text_duplicates (bool, optional): If True, drops duplicate texts from the output manifest. Defaults to False. + **kwargs: Additional keyword arguments to be passed to the base class `BaseProcessor`. 
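+
+    Example (illustrative config sketch; the model path is a placeholder for any HuggingFace sequence-classification checkpoint whose labels match the language list used in text2lid):
+
+        _target_: sdp.processors.datasets.commoncrawl.TextLid
+        output_manifest_file: ${workspace_dir}/manifest_text_lid.json
+        input_text_key: text
+        output_lang_key: text_lang
+        pretrained_model: /path/to/text_lid_model
+        device: cuda
+        drop_text_duplicates: True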
+ """ + + def __init__( + self, + input_text_key: str, + pretrained_model: str, + output_lang_key: str, + device: str, + drop_text_duplicates: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.input_text_key = input_text_key + self.pretrained_model = pretrained_model + self.output_lang_key = output_lang_key + self.device = device + self.drop_duplicates = drop_text_duplicates + + def process(self): + import torch # importing after nemo to make sure users first install nemo, instead of torch, then nemo + from transformers import AutoModelForSequenceClassification, AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model) + text_model = AutoModelForSequenceClassification.from_pretrained(self.pretrained_model) + + if self.device is None: + if torch.cuda.is_available(): + text_model = text_model.cuda() + else: + text_model = text_model.cpu() + else: + text_model = text_model.to(self.device) + + manifest = load_manifest(Path(self.input_manifest_file)) + + Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True) + text_set = set() + with Path(self.output_manifest_file).open('w') as f: + for item in tqdm(manifest): + text = item[self.input_text_key] + # keep the entry unless deduplication is enabled and the text was already seen + if not self.drop_duplicates or text not in text_set: + text_set.add(text) + if text: + lid = text2lid(text_model, tokenizer, text) + else: + lid = None + + if lid: + item[self.output_lang_key] = lid + f.write(json.dumps(item, ensure_ascii=False) + '\n') + + +class AllVttText(BaseParallelProcessor): + """ + A class for extracting text content from VTT (WebVTT) files and updating the manifest. + + Args: + output_text_key (str): The field to store the extracted text content in the manifest. + input_filepath_key (str, optional): The field in the manifest containing the path to VTT files. Defaults to "vtt_filepath". + **kwargs: Additional keyword arguments to be passed to the base class `BaseParallelProcessor`. + + Methods: + process_dataset_entry(data_entry): Processes a single dataset entry, extracts text content from the specified VTT file, and updates the manifest. + + """ + + def __init__( + self, + output_text_key: str, + input_filepath_key: str = "vtt_filepath", + **kwargs, + ): + super().__init__(**kwargs) + self.output_text_key = output_text_key + self.input_filepath_key = input_filepath_key + + def process_dataset_entry(self, data_entry): + vtt_file = data_entry[self.input_filepath_key] + res_list = [DataEntry(data=None)] + if os.path.isfile(vtt_file): + try: + data_entry[self.output_text_key] = get_vtt_text(vtt_file) + res_list = [DataEntry(data=data_entry)] + except Exception as e: + logger.warning("AllVttText " + vtt_file + " " + str(e)) + return res_list + + +class TxtToVtt(BaseParallelProcessor): + """ + A class for converting text files to WebVTT (VTT) format and updating the manifest. + + Args: + vtt_files_dir (str): The directory where the generated VTT files will be saved. + id_key (str): The field in the manifest representing the unique key or identifier for each entry. + text_key (str): The field in the manifest containing the path to the source text file to be converted to VTT format. + vtt_key (str): The field to store the generated VTT file paths in the manifest. + **kwargs: Additional keyword arguments to be passed to the base class `BaseParallelProcessor`. + + Methods: + prepare(): Creates the directory for saving the generated VTT files. + process_dataset_entry(data_entry): Processes a single dataset entry, converts the text content to VTT format, and updates the manifest.
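+
+    Example (illustrative config sketch; the field names mirror the output of CreateInitialManifestCC and are placeholders):
+
+        _target_: sdp.processors.datasets.commoncrawl.TxtToVtt
+        output_manifest_file: ${workspace_dir}/manifest_vtt.json
+        vtt_files_dir: ${workspace_dir}/vtt
+        id_key: key
+        text_key: text
+        vtt_key: vtt_filepath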
+ + """ + + def __init__( + self, + vtt_files_dir: str, + id_key: str, + text_key: str, + vtt_key: str, + **kwargs, + ): + super().__init__(**kwargs) + self.vtt_files_dir = vtt_files_dir + self.id_key = id_key + self.text_key = text_key + self.vtt_key = vtt_key + + self.trans_list = make_trans_list() + + def prepare(self): + os.makedirs(self.vtt_files_dir, exist_ok=True) + + def process_dataset_entry(self, data_entry): + key = data_entry[self.id_key] + text_file = data_entry[self.text_key] + os.makedirs(os.path.join(self.vtt_files_dir, key.split("/")[0]), exist_ok=True) + + vtt_file = os.path.join(self.vtt_files_dir, key) + ".vtt" + + txt2vtt(text_file, vtt_file, self.trans_list) + + data_entry[self.vtt_key] = vtt_file + + return [DataEntry(data=data_entry)] + + +class ReadParquet(BaseParallelProcessor): + """ + A class for reading information from Parquet files and updating the manifest with video URLs and captions. + + Args: + output_video_key (str): The field to store the extracted video URLs in the manifest. + output_caption_key (str): The field to store the extracted captions in the manifest. + id_key (str): The field in the manifest representing the unique key or identifier for each entry. + raw_data_dir (str): The directory containing Parquet files with information to be read. + **kwargs: Additional keyword arguments to be passed to the base class `BaseParallelProcessor`. + + """ + + def __init__( + self, + output_video_key: str, + output_caption_key: str, + id_key: str, + raw_data_dir: str, + **kwargs, + ): + super().__init__(**kwargs) + self.output_video_key = output_video_key + self.output_caption_key = output_caption_key + self.id_key = id_key + self.raw_data_dir = Path(raw_data_dir) + + def prepare(self): + parquets = [str(self.raw_data_dir / p) for p in self.raw_data_dir.rglob('*.parquet')] + self.urls = None + for parquet in tqdm(parquets): + try: + df1 = pd.read_parquet(parquet, engine='fastparquet').sort_values("key").set_index("key") + if self.urls is None: + self.urls = df1 + else: + self.urls = pd.concat([self.urls, df1]) + except Exception as e: + logger.warning(str(e) + ", file: " + parquet) + + def process_dataset_entry(self, data_entry): + key = data_entry[self.id_key] + key = key.split("/")[1] + try: + data_entry[self.output_video_key] = self.urls.loc[key]['url'] + data_entry[self.output_caption_key] = self.urls.loc[key]['caption'] + except: + data_entry[self.output_video_key] = "NN" + data_entry[self.output_caption_key] = "NN" + logger.warning("Key without URL or caption: " + key) + return [DataEntry(data=data_entry)] + + +def get_key(x): + key = "/".join(os.path.splitext(x)[0].split("/")[-2:]) + return key + + +class CreateInitialManifestCC(BaseParallelProcessor): + """ + A class for creating an initial dataset manifest from image and text files with common keys. + + Args: + raw_data_dir (str): The directory containing image and text files to include in the initial dataset manifest. + video_key (str): The field to store the paths to the image files in the dataset. + id_key (str): The field to represent the common key or identifier for each entry. + text_key (str): The field to store the paths to the text files in the dataset. + **kwargs: Additional keyword arguments to be passed to the base class `BaseParallelProcessor`. + + Methods: + prepare(): Creates the directory for saving the initial dataset manifest. + read_manifest(): Reads the image and text files, extracts common keys, and creates a DataFrame with video, key, and text fields. 
+ process_dataset_entry(data_entry): Processes a single dataset entry, creating a DataEntry object with video, key, and text fields, and updates the dataset. + + """ + + def __init__( + self, + raw_data_dir: str, + video_key: str, + id_key: str, + text_key: str, + **kwargs, + ): + super().__init__(**kwargs) + self.raw_data_dir = Path(raw_data_dir) + self.video_key = video_key + self.id_key = id_key + self.text_key = text_key + + def prepare(self): + os.makedirs(self.raw_data_dir, exist_ok=True) + + def read_manifest(self): + videos = [str(self.raw_data_dir / video) for video in self.raw_data_dir.rglob('*.jpg')] + texts = [str(self.raw_data_dir / text) for text in self.raw_data_dir.rglob('*.txt')] + v_df = pd.DataFrame({self.video_key: videos}) + t_df = pd.DataFrame({self.text_key: texts}) + + v_df[self.id_key] = v_df[self.video_key].apply(get_key) + t_df[self.id_key] = t_df[self.text_key].apply(get_key) + v_df = v_df.drop_duplicates(self.id_key) + t_df = t_df.drop_duplicates(self.id_key) + vt_df = v_df.merge(t_df, on=self.id_key, how="left") + return vt_df.values + + def process_dataset_entry(self, data_entry): + (video, key, text) = data_entry + + data = {self.video_key: video, self.id_key: key, self.text_key: text} + return [DataEntry(data=data)] diff --git a/sdp/processors/datasets/commoncrawl/harv_utils.py b/sdp/processors/datasets/commoncrawl/harv_utils.py index baf115ec..41d591b0 100644 --- a/sdp/processors/datasets/commoncrawl/harv_utils.py +++ b/sdp/processors/datasets/commoncrawl/harv_utils.py @@ -1,18 +1,85 @@ +import json import os -import torch -# import ffmpeg # pip install ffmpeg-python -import webvtt # pip install webvtt-py -import subprocess, sys -import json, os -import soundfile as sf -from typing import Dict, List, Union +import subprocess +import sys from datetime import datetime -import numpy as np from pathlib import Path +from typing import Dict, List, Union + +import numpy as np import pandas as pd +import soundfile as sf +import torch +import webvtt # pip install webvtt-py + from sdp.logging import logger +def read_jsonl(manifest_file): + rec = [] + with open(manifest_file, 'r') as the_file: + for l in the_file: + rec.append(json.loads(l)) + return pd.DataFrame.from_records(rec) + + +def write_jsonl(df_in, manifest_filename): + with open(manifest_filename, 'w') as the_file: + for i, x in enumerate(df_in.itertuples()): + r_dict = {} + for column in df_in.columns: + r_dict[column] = getattr(x, column) + l1 = json.dumps(r_dict) + the_file.write(l1 + '\n') + + +def load_manifest(manifest: Path, keys: List[str] = []) -> List[Dict[str, Union[str, float]]]: + result = [] + r_dict = dict() + for key in keys: + r_dict[key] = list() + + with manifest.open() as f: + for i, line in enumerate(f): + data = json.loads(line) + result.append(data) + for key in keys: + r_dict[key].append(data[key]) + if keys: + return result, r_dict + else: + return result + + +def get_vtt_text(vtt_file): + text_all = [] + if os.path.splitext(vtt_file)[1] == '.vtt': + webvtt_i = webvtt.read + elif os.path.splitext(vtt_file)[1] == '.srt': + webvtt_i = webvtt.from_srt + else: + raise ValueError("Unsupported extention of file " + vtt_file) + + for caption in webvtt_i(vtt_file): + text = caption.text + if text.find("thumbnails") != -1: + pass + else: + text_all.append(' '.join(text.split('\n'))) + return ' '.join(text_all) + + +def text2lid(text_model, tokenizer, text): + text_langs = "Arabic, Basque, Breton, Catalan, Chinese_China, Chinese_Hongkong, Chinese_Taiwan, Chuvash, Czech, Dhivehi, Dutch, 
English, Esperanto, Estonian, French, Frisian, Georgian, German, Greek, Hakha_Chin, Indonesian, Interlingua, Italian, Japanese, Kabyle, Kinyarwanda, Kyrgyz, Latvian, Maltese, Mongolian, Persian, Polish, Portuguese, Romanian, Romansh_Sursilvan, Russian, Sakha, Slovenian, Spanish, Swedish, Tamil, Tatar, Turkish, Ukranian, Welsh".split( + ", " + ) + inputs = tokenizer(text[:512], return_tensors="pt").to("cuda:0") + with torch.no_grad(): + text_logits = text_model(**inputs).logits + lang_id = text_logits.argmax(1).cpu()[0].numpy() + return text_langs[lang_id] + + def parse_hours(inp): inp_list = inp.split(":") if len(inp_list) == 3 and int(inp_list[0]) >= 24: @@ -30,11 +97,47 @@ def parse_hours(inp): return datetime.strptime(inp, '%H:%M:%S.%f') -def split_by_vtt(vtt_file, samplerate): +def split_by_vtt(vtt_file): try: _begin = datetime.strptime('00:00:00.000', '%H:%M:%S.%f') text_list, start_s, end_s = [], [], [] - for caption in webvtt.read(vtt_file): + if os.path.splitext(vtt_file)[1] == '.vtt': + webvtt_i = webvtt.read + elif os.path.splitext(vtt_file)[1] == '.srt': + webvtt_i = webvtt.from_srt + else: + raise ValueError("Unsupporte extention of file " + vtt_file) + + for caption in webvtt_i(vtt_file): + text = ' '.join(caption.text.split('\n')) + + _start = parse_hours(caption.start) + start = (_start - _begin).total_seconds() + + _end = parse_hours(caption.end) + end = (_end - _begin).total_seconds() + + text_list.append(text.strip()) + start_s.append(start) + end_s.append(end) + return text_list, start_s, end_s + except Exception as e: + logger.warning(str(e) + vtt_file) + return None, None, None + + +def split_by_vtt_new(vtt_file, samplerate): + try: + _begin = datetime.strptime('00:00:00.000', '%H:%M:%S.%f') + text_list, start_s, end_s = [], [], [] + if os.path.splitext(vtt_file)[1] == '.vtt': + webvtt_i = webvtt.read + elif os.path.splitext(vtt_file)[1] == '.srt': + webvtt_i = webvtt.from_srt + else: + raise ValueError("Unsupporte extention of file " + vtt_file) + + for caption in webvtt_i(vtt_file): text = ' '.join(caption.text.split('\n')) _start = parse_hours(caption.start) @@ -53,3 +156,691 @@ def split_by_vtt(vtt_file, samplerate): logger.warning(str(e) + vtt_file) return None, None, None + +def audio_duration(fname): + data, samplerate = sf.read(fname) + return data.shape[0] / samplerate + + +def ffmpeg_convert(jpg: str, wav: str, ar: int = 0, ac: int = 1): + process_args = ["ffmpeg", "-i", jpg, '-ac', str(ac), "-map", "0:a", "-c:a", "pcm_s16le", "-y", wav] + # '-filter_complex', '"[0:a]amerge=inputs=4[a]"', + if ar: + process_args = process_args[:-1] + process_args.extend(["-ar", str(ar), wav]) + return subprocess.run(process_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + +def read_txt(txt_file): + with open(txt_file, "r") as f: + text = f.read() + return text[2:-1].replace("\\n", "\n").replace("\\r", "\r") + + +def translate(txt, trans_list): + for trans in trans_list: + txt = txt.replace(trans[0], trans[1]) + return txt + + +def txt2vtt(txt_file: str, vtt_file: str, trans_list: List): + txt = read_txt(txt_file) + if txt: + if txt[:6] == "WEBVTT": + pass + else: + txt = "WEBVTT" + txt + # print(f"'{txt[:7]}''") + vtt = translate(txt, trans_list) + with open(vtt_file, "w") as f: + f.write(vtt) + + +def make_trans_list(): + t1 = """U+0000   + U+0001 \' \\' + U+0080 \\xc2\\x80 + U+0081 \\xc2\\x81 + U+0082 \\xc2\\x82 + U+0083 \\xc2\\x83 + U+0084 \\xc2\\x84 + U+0085 \\xc2\\x85 + U+0086 \\xc2\\x86 + U+0087 \\xc2\\x87 + U+0088 \\xc2\\x88 + U+0089 \\xc2\\x89 + U+008A 
\\xc2\\x8a + U+008B \\xc2\\x8b + U+008C \\xc2\\x8c + U+008D \\xc2\\x8d + U+008E \\xc2\\x8e + U+008F \\xc2\\x8f + U+0090 \\xc2\\x90 + U+0091 \\xc2\\x91 + U+0092 \\xc2\\x92 + U+0093 \\xc2\\x93 + U+0094 \\xc2\\x94 + U+0095 \\xc2\\x95 + U+0096 \\xc2\\x96 + U+0097 \\xc2\\x97 + U+0098 \\xc2\\x98 + U+0099 \\xc2\\x99 + U+009A \\xc2\\x9a + U+009B \\xc2\\x9b + U+009C \\xc2\\x9c + U+009D \\xc2\\x9d + U+009E \\xc2\\x9e + U+009F \\xc2\\x9f + U+00A0 \\xc2\\xa0 + U+00A1 ¡ \\xc2\\xa1 + U+00A2 ¢ \\xc2\\xa2 + U+00A3 £ \\xc2\\xa3 + U+00A4 ¤ \\xc2\\xa4 + U+00A5 ¥ \\xc2\\xa5 + U+00A6 ¦ \\xc2\\xa6 + U+00A7 § \\xc2\\xa7 + U+00A8 ¨ \\xc2\\xa8 + U+00A9 © \\xc2\\xa9 + U+00AA ª \\xc2\\xaa + U+00AB « \\xc2\\xab + U+00AC ¬ \\xc2\\xac + U+00AD ­ \\xc2\\xad + U+00AE ® \\xc2\\xae + U+00AF ¯ \\xc2\\xaf + U+00B0 ° \\xc2\\xb0 + U+00B1 ± \\xc2\\xb1 + U+00B2 ² \\xc2\\xb2 + U+00B3 ³ \\xc2\\xb3 + U+00B4 ´ \\xc2\\xb4 + U+00B5 µ \\xc2\\xb5 + U+00B6 ¶ \\xc2\\xb6 + U+00B7 · \\xc2\\xb7 + U+00B8 ¸ \\xc2\\xb8 + U+00B9 ¹ \\xc2\\xb9 + U+00BA º \\xc2\\xba + U+00BB » \\xc2\\xbb + U+00BC ¼ \\xc2\\xbc + U+00BD ½ \\xc2\\xbd + U+00BE ¾ \\xc2\\xbe + U+00BF ¿ \\xc2\\xbf + U+00C0 À \\xc3\\x80 + U+00C1 Á \\xc3\\x81 + U+00C2  \\xc3\\x82 + U+00C3 à \\xc3\\x83 + U+00C4 Ä \\xc3\\x84 + U+00C5 Å \\xc3\\x85 + U+00C6 Æ \\xc3\\x86 + U+00C7 Ç \\xc3\\x87 + U+00C8 È \\xc3\\x88 + U+00C9 É \\xc3\\x89 + U+00CA Ê \\xc3\\x8a + U+00CB Ë \\xc3\\x8b + U+00CC Ì \\xc3\\x8c + U+00CD Í \\xc3\\x8d + U+00CE Î \\xc3\\x8e + U+00CF Ï \\xc3\\x8f + U+00D0 Ð \\xc3\\x90 + U+00D1 Ñ \\xc3\\x91 + U+00D2 Ò \\xc3\\x92 + U+00D3 Ó \\xc3\\x93 + U+00D4 Ô \\xc3\\x94 + U+00D5 Õ \\xc3\\x95 + U+00D6 Ö \\xc3\\x96 + U+00D7 × \\xc3\\x97 + U+00D8 Ø \\xc3\\x98 + U+00D9 Ù \\xc3\\x99 + U+00DA Ú \\xc3\\x9a + U+00DB Û \\xc3\\x9b + U+00DC Ü \\xc3\\x9c + U+00DD Ý \\xc3\\x9d + U+00DE Þ \\xc3\\x9e + U+00DF ß \\xc3\\x9f + U+00E0 à \\xc3\\xa0 + U+00E1 á \\xc3\\xa1 + U+00E2 â \\xc3\\xa2 + U+00E3 ã \\xc3\\xa3 + U+00E4 ä \\xc3\\xa4 + U+00E5 å \\xc3\\xa5 + U+00E6 æ \\xc3\\xa6 + U+00E7 ç \\xc3\\xa7 + U+00E8 è \\xc3\\xa8 + U+00E9 é \\xc3\\xa9 + U+00EA ê \\xc3\\xaa + U+00EB ë \\xc3\\xab + U+00EC ì \\xc3\\xac + U+00ED í \\xc3\\xad + U+00EE î \\xc3\\xae + U+00EF ï \\xc3\\xaf + U+00F0 ð \\xc3\\xb0 + U+00F1 ñ \\xc3\\xb1 + U+00F2 ò \\xc3\\xb2 + U+00F3 ó \\xc3\\xb3 + U+00F4 ô \\xc3\\xb4 + U+00F5 õ \\xc3\\xb5 + U+00F6 ö \\xc3\\xb6 + U+00F7 ÷ \\xc3\\xb7 + U+00F8 ø \\xc3\\xb8 + U+00F9 ù \\xc3\\xb9 + U+00FA ú \\xc3\\xba + U+00FB û \\xc3\\xbb + U+00FC ü \\xc3\\xbc + U+00FD ý \\xc3\\xbd + U+00FE þ \\xc3\\xbe + U+00FF ÿ \\xc3\\xbf + U+0100 Ā \\xc4\\x80 + U+0101 ā \\xc4\\x81 + U+0102 Ă \\xc4\\x82 + U+0103 ă \\xc4\\x83 + U+0104 Ą \\xc4\\x84 + U+0105 ą \\xc4\\x85 + U+0106 Ć \\xc4\\x86 + U+0107 ć \\xc4\\x87 + U+0108 Ĉ \\xc4\\x88 + U+0109 ĉ \\xc4\\x89 + U+010A Ċ \\xc4\\x8a + U+010B ċ \\xc4\\x8b + U+010C Č \\xc4\\x8c + U+010D č \\xc4\\x8d + U+010E Ď \\xc4\\x8e + U+010F ď \\xc4\\x8f + U+0110 Đ \\xc4\\x90 + U+0111 đ \\xc4\\x91 + U+0112 Ē \\xc4\\x92 + U+0113 ē \\xc4\\x93 + U+0114 Ĕ \\xc4\\x94 + U+0115 ĕ \\xc4\\x95 + U+0116 Ė \\xc4\\x96 + U+0117 ė \\xc4\\x97 + U+0118 Ę \\xc4\\x98 + U+0119 ę \\xc4\\x99 + U+011A Ě \\xc4\\x9a + U+011B ě \\xc4\\x9b + U+011C Ĝ \\xc4\\x9c + U+011D ĝ \\xc4\\x9d + U+011E Ğ \\xc4\\x9e + U+011F ğ \\xc4\\x9f + U+0120 Ġ \\xc4\\xa0 + U+0121 ġ \\xc4\\xa1 + U+0122 Ģ \\xc4\\xa2 + U+0123 ģ \\xc4\\xa3 + U+0124 Ĥ \\xc4\\xa4 + U+0125 ĥ \\xc4\\xa5 + U+0126 Ħ \\xc4\\xa6 + U+0127 ħ \\xc4\\xa7 + U+0128 Ĩ \\xc4\\xa8 + U+0129 ĩ \\xc4\\xa9 + U+012A Ī \\xc4\\xaa + U+012B ī \\xc4\\xab + U+012C Ĭ \\xc4\\xac + U+012D ĭ \\xc4\\xad + 
U+012E Į \\xc4\\xae + U+012F į \\xc4\\xaf + U+0130 İ \\xc4\\xb0 + U+0131 ı \\xc4\\xb1 + U+0132 IJ \\xc4\\xb2 + U+0133 ij \\xc4\\xb3 + U+0134 Ĵ \\xc4\\xb4 + U+0135 ĵ \\xc4\\xb5 + U+0136 Ķ \\xc4\\xb6 + U+0137 ķ \\xc4\\xb7 + U+0138 ĸ \\xc4\\xb8 + U+0139 Ĺ \\xc4\\xb9 + U+013A ĺ \\xc4\\xba + U+013B Ļ \\xc4\\xbb + U+013C ļ \\xc4\\xbc + U+013D Ľ \\xc4\\xbd + U+013E ľ \\xc4\\xbe + U+013F Ŀ \\xc4\\xbf + U+0140 ŀ \\xc5\\x80 + U+0141 Ł \\xc5\\x81 + U+0142 ł \\xc5\\x82 + U+0143 Ń \\xc5\\x83 + U+0144 ń \\xc5\\x84 + U+0145 Ņ \\xc5\\x85 + U+0146 ņ \\xc5\\x86 + U+0147 Ň \\xc5\\x87 + U+0148 ň \\xc5\\x88 + U+0149 ʼn \\xc5\\x89 + U+014A Ŋ \\xc5\\x8a + U+014B ŋ \\xc5\\x8b + U+014C Ō \\xc5\\x8c + U+014D ō \\xc5\\x8d + U+014E Ŏ \\xc5\\x8e + U+014F ŏ \\xc5\\x8f + U+0150 Ő \\xc5\\x90 + U+0151 ő \\xc5\\x91 + U+0152 Œ \\xc5\\x92 + U+0153 œ \\xc5\\x93 + U+0154 Ŕ \\xc5\\x94 + U+0155 ŕ \\xc5\\x95 + U+0156 Ŗ \\xc5\\x96 + U+0157 ŗ \\xc5\\x97 + U+0158 Ř \\xc5\\x98 + U+0159 ř \\xc5\\x99 + U+015A Ś \\xc5\\x9a + U+015B ś \\xc5\\x9b + U+015C Ŝ \\xc5\\x9c + U+015D ŝ \\xc5\\x9d + U+015E Ş \\xc5\\x9e + U+015F ş \\xc5\\x9f + U+0160 Š \\xc5\\xa0 + U+0161 š \\xc5\\xa1 + U+0162 Ţ \\xc5\\xa2 + U+0163 ţ \\xc5\\xa3 + U+0164 Ť \\xc5\\xa4 + U+0165 ť \\xc5\\xa5 + U+0166 Ŧ \\xc5\\xa6 + U+0167 ŧ \\xc5\\xa7 + U+0168 Ũ \\xc5\\xa8 + U+0169 ũ \\xc5\\xa9 + U+016A Ū \\xc5\\xaa + U+016B ū \\xc5\\xab + U+016C Ŭ \\xc5\\xac + U+016D ŭ \\xc5\\xad + U+016E Ů \\xc5\\xae + U+016F ů \\xc5\\xaf + U+0170 Ű \\xc5\\xb0 + U+0171 ű \\xc5\\xb1 + U+0172 Ų \\xc5\\xb2 + U+0173 ų \\xc5\\xb3 + U+0174 Ŵ \\xc5\\xb4 + U+0175 ŵ \\xc5\\xb5 + U+0176 Ŷ \\xc5\\xb6 + U+0177 ŷ \\xc5\\xb7 + U+0178 Ÿ \\xc5\\xb8 + U+0179 Ź \\xc5\\xb9 + U+017A ź \\xc5\\xba + U+017B Ż \\xc5\\xbb + U+017C ż \\xc5\\xbc + U+017D Ž \\xc5\\xbd + U+017E ž \\xc5\\xbe + U+017F ſ \\xc5\\xbf + U+0180 ƀ \\xc6\\x80 + U+0181 Ɓ \\xc6\\x81 + U+0182 Ƃ \\xc6\\x82 + U+0183 ƃ \\xc6\\x83 + U+0184 Ƅ \\xc6\\x84 + U+0185 ƅ \\xc6\\x85 + U+0186 Ɔ \\xc6\\x86 + U+0187 Ƈ \\xc6\\x87 + U+0188 ƈ \\xc6\\x88 + U+0189 Ɖ \\xc6\\x89 + U+018A Ɗ \\xc6\\x8a + U+018B Ƌ \\xc6\\x8b + U+018C ƌ \\xc6\\x8c + U+018D ƍ \\xc6\\x8d + U+018E Ǝ \\xc6\\x8e + U+018F Ə \\xc6\\x8f + U+0190 Ɛ \\xc6\\x90 + U+0191 Ƒ \\xc6\\x91 + U+0192 ƒ \\xc6\\x92 + U+0193 Ɠ \\xc6\\x93 + U+0194 Ɣ \\xc6\\x94 + U+0195 ƕ \\xc6\\x95 + U+0196 Ɩ \\xc6\\x96 + U+0197 Ɨ \\xc6\\x97 + U+0198 Ƙ \\xc6\\x98 + U+0199 ƙ \\xc6\\x99 + U+019A ƚ \\xc6\\x9a + U+019B ƛ \\xc6\\x9b + U+019C Ɯ \\xc6\\x9c + U+019D Ɲ \\xc6\\x9d + U+019E ƞ \\xc6\\x9e + U+019F Ɵ \\xc6\\x9f + U+01A0 Ơ \\xc6\\xa0 + U+01A1 ơ \\xc6\\xa1 + U+01A2 Ƣ \\xc6\\xa2 + U+01A3 ƣ \\xc6\\xa3 + U+01A4 Ƥ \\xc6\\xa4 + U+01A5 ƥ \\xc6\\xa5 + U+01A6 Ʀ \\xc6\\xa6 + U+01A7 Ƨ \\xc6\\xa7 + U+01A8 ƨ \\xc6\\xa8 + U+01A9 Ʃ \\xc6\\xa9 + U+01AA ƪ \\xc6\\xaa + U+01AB ƫ \\xc6\\xab + U+01AC Ƭ \\xc6\\xac + U+01AD ƭ \\xc6\\xad + U+01AE Ʈ \\xc6\\xae + U+01AF Ư \\xc6\\xaf + U+01B0 ư \\xc6\\xb0 + U+01B1 Ʊ \\xc6\\xb1 + U+01B2 Ʋ \\xc6\\xb2 + U+01B3 Ƴ \\xc6\\xb3 + U+01B4 ƴ \\xc6\\xb4 + U+01B5 Ƶ \\xc6\\xb5 + U+01B6 ƶ \\xc6\\xb6 + U+01B7 Ʒ \\xc6\\xb7 + U+01B8 Ƹ \\xc6\\xb8 + U+01B9 ƹ \\xc6\\xb9 + U+01BA ƺ \\xc6\\xba + U+01BB ƻ \\xc6\\xbb + U+01BC Ƽ \\xc6\\xbc + U+01BD ƽ \\xc6\\xbd + U+01BE ƾ \\xc6\\xbe + U+01BF ƿ \\xc6\\xbf + U+01C0 ǀ \\xc7\\x80 + U+01C1 ǁ \\xc7\\x81 + U+01C2 ǂ \\xc7\\x82 + U+01C3 ǃ \\xc7\\x83 + U+01C4 DŽ \\xc7\\x84 + U+01C5 Dž \\xc7\\x85 + U+01C6 dž \\xc7\\x86 + U+01C7 LJ \\xc7\\x87 + U+01C8 Lj \\xc7\\x88 + U+01C9 lj \\xc7\\x89 + U+01CA NJ \\xc7\\x8a + U+01CB Nj \\xc7\\x8b + U+01CC nj \\xc7\\x8c + U+01CD Ǎ \\xc7\\x8d + U+01CE ǎ \\xc7\\x8e + 
U+01CF Ǐ \\xc7\\x8f + U+01D0 ǐ \\xc7\\x90 + U+01D1 Ǒ \\xc7\\x91 + U+01D2 ǒ \\xc7\\x92 + U+01D3 Ǔ \\xc7\\x93 + U+01D4 ǔ \\xc7\\x94 + U+01D5 Ǖ \\xc7\\x95 + U+01D6 ǖ \\xc7\\x96 + U+01D7 Ǘ \\xc7\\x97 + U+01D8 ǘ \\xc7\\x98 + U+01D9 Ǚ \\xc7\\x99 + U+01DA ǚ \\xc7\\x9a + U+01DB Ǜ \\xc7\\x9b + U+01DC ǜ \\xc7\\x9c + U+01DD ǝ \\xc7\\x9d + U+01DE Ǟ \\xc7\\x9e + U+01DF ǟ \\xc7\\x9f + U+01E0 Ǡ \\xc7\\xa0 + U+01E1 ǡ \\xc7\\xa1 + U+01E2 Ǣ \\xc7\\xa2 + U+01E3 ǣ \\xc7\\xa3 + U+01E4 Ǥ \\xc7\\xa4 + U+01E5 ǥ \\xc7\\xa5 + U+01E6 Ǧ \\xc7\\xa6 + U+01E7 ǧ \\xc7\\xa7 + U+01E8 Ǩ \\xc7\\xa8 + U+01E9 ǩ \\xc7\\xa9 + U+01EA Ǫ \\xc7\\xaa + U+01EB ǫ \\xc7\\xab + U+01EC Ǭ \\xc7\\xac + U+01ED ǭ \\xc7\\xad + U+01EE Ǯ \\xc7\\xae + U+01EF ǯ \\xc7\\xaf + U+01F0 ǰ \\xc7\\xb0 + U+01F1 DZ \\xc7\\xb1 + U+01F2 Dz \\xc7\\xb2 + U+01F3 dz \\xc7\\xb3 + U+01F4 Ǵ \\xc7\\xb4 + U+01F5 ǵ \\xc7\\xb5 + U+01F6 Ƕ \\xc7\\xb6 + U+01F7 Ƿ \\xc7\\xb7 + U+01F8 Ǹ \\xc7\\xb8 + U+01F9 ǹ \\xc7\\xb9 + U+01FA Ǻ \\xc7\\xba + U+01FB ǻ \\xc7\\xbb + U+01FC Ǽ \\xc7\\xbc + U+01FD ǽ \\xc7\\xbd + U+01FE Ǿ \\xc7\\xbe + U+01FF ǿ \\xc7\\xbf + U+2000   \\xe2\\x80\\x80 EN QUAD + U+2001   \\xe2\\x80\\x81 EM QUAD + U+2002   \\xe2\\x80\\x82 EN SPACE + U+2003   \\xe2\\x80\\x83 EM SPACE + U+2004   \\xe2\\x80\\x84 THREE-PER-EM SPACE + U+2005   \\xe2\\x80\\x85 FOUR-PER-EM SPACE + U+2006   \\xe2\\x80\\x86 SIX-PER-EM SPACE + U+2007   \\xe2\\x80\\x87 FIGURE SPACE + U+2008   \\xe2\\x80\\x88 PUNCTUATION SPACE + U+2009   \\xe2\\x80\\x89 THIN SPACE + U+200A   \\xe2\\x80\\x8a HAIR SPACE + U+200B ​ \\xe2\\x80\\x8b ZERO WIDTH SPACE + U+200C ‌ \\xe2\\x80\\x8c ZERO WIDTH NON-JOINER + U+200D ‍ \\xe2\\x80\\x8d ZERO WIDTH JOINER + U+200E ‎ \\xe2\\x80\\x8e LEFT-TO-RIGHT MARK + U+200F ‏ \\xe2\\x80\\x8f RIGHT-TO-LEFT MARK + U+2010 ‐ \\xe2\\x80\\x90 HYPHEN + U+2011 ‑ \\xe2\\x80\\x91 NON-BREAKING HYPHEN + U+2012 ‒ \\xe2\\x80\\x92 FIGURE DASH + U+2013 – \\xe2\\x80\\x93 EN DASH + U+2014 — \\xe2\\x80\\x94 EM DASH + U+2015 ― \\xe2\\x80\\x95 HORIZONTAL BAR + U+2016 ‖ \\xe2\\x80\\x96 DOUBLE VERTICAL LINE + U+2017 ‗ \\xe2\\x80\\x97 DOUBLE LOW LINE + U+2018 ‘ \\xe2\\x80\\x98 LEFT SINGLE QUOTATION MARK + U+2019 ’ \\xe2\\x80\\x99 RIGHT SINGLE QUOTATION MARK + U+201A ‚ \\xe2\\x80\\x9a SINGLE LOW-9 QUOTATION MARK + U+201B ‛ \\xe2\\x80\\x9b SINGLE HIGH-REVERSED-9 QUOTATION MARK + U+201C “ \\xe2\\x80\\x9c LEFT DOUBLE QUOTATION MARK + U+201D ” \\xe2\\x80\\x9d RIGHT DOUBLE QUOTATION MARK + U+201E „ \\xe2\\x80\\x9e DOUBLE LOW-9 QUOTATION MARK + U+201F ‟ \\xe2\\x80\\x9f DOUBLE HIGH-REVERSED-9 QUOTATION MARK + U+2020 † \\xe2\\x80\\xa0 DAGGER + U+2021 ‡ \\xe2\\x80\\xa1 DOUBLE DAGGER + U+2022 • \\xe2\\x80\\xa2 BULLET + U+2023 ‣ \\xe2\\x80\\xa3 TRIANGULAR BULLET + U+2024 ․ \\xe2\\x80\\xa4 ONE DOT LEADER + U+2025 ‥ \\xe2\\x80\\xa5 TWO DOT LEADER + U+2026 … \\xe2\\x80\\xa6 HORIZONTAL ELLIPSIS + U+2027 ‧ \\xe2\\x80\\xa7 HYPHENATION POINT + U+2028 \\xe2\\x80\\xa8 LINE SEPARATOR + U+2029 \\xe2\\x80\\xa9 PARAGRAPH SEPARATOR + U+202A ‪ \\xe2\\x80\\xaa LEFT-TO-RIGHT EMBEDDING + U+202B ‫ \\xe2\\x80\\xab RIGHT-TO-LEFT EMBEDDING + U+202C ‬ \\xe2\\x80\\xac POP DIRECTIONAL FORMATTING + U+202D ‭ \\xe2\\x80\\xad LEFT-TO-RIGHT OVERRIDE + U+202E ‮ \\xe2\\x80\\xae RIGHT-TO-LEFT OVERRIDE + U+202F   \\xe2\\x80\\xaf NARROW NO-BREAK SPACE + U+2030 ‰ \\xe2\\x80\\xb0 PER MILLE SIGN + U+2031 ‱ \\xe2\\x80\\xb1 PER TEN THOUSAND SIGN + U+2032 ′ \\xe2\\x80\\xb2 PRIME + U+2033 ″ \\xe2\\x80\\xb3 DOUBLE PRIME + U+2034 ‴ \\xe2\\x80\\xb4 TRIPLE PRIME + U+2035 ‵ \\xe2\\x80\\xb5 REVERSED PRIME + U+2036 ‶ \\xe2\\x80\\xb6 REVERSED DOUBLE PRIME + U+2037 ‷ 
\\xe2\\x80\\xb7 REVERSED TRIPLE PRIME + U+2038 ‸ \\xe2\\x80\\xb8 CARET + U+2039 ‹ \\xe2\\x80\\xb9 SINGLE LEFT-POINTING ANGLE QUOTATION MARK + U+203A › \\xe2\\x80\\xba SINGLE RIGHT-POINTING ANGLE QUOTATION MARK + U+203B ※ \\xe2\\x80\\xbb REFERENCE MARK + U+203C ‼ \\xe2\\x80\\xbc DOUBLE EXCLAMATION MARK + U+203D ‽ \\xe2\\x80\\xbd INTERROBANG + U+203E ‾ \\xe2\\x80\\xbe OVERLINE + U+203F ‿ \\xe2\\x80\\xbf UNDERTIE + U+2040 ⁀ \\xe2\\x81\\x80 CHARACTER TIE + U+2041 ⁁ \\xe2\\x81\\x81 CARET INSERTION POINT + U+2042 ⁂ \\xe2\\x81\\x82 ASTERISM + U+2043 ⁃ \\xe2\\x81\\x83 HYPHEN BULLET + U+2044 ⁄ \\xe2\\x81\\x84 FRACTION SLASH + U+2045 ⁅ \\xe2\\x81\\x85 LEFT SQUARE BRACKET WITH QUILL + U+2046 ⁆ \\xe2\\x81\\x86 RIGHT SQUARE BRACKET WITH QUILL + U+2047 ⁇ \\xe2\\x81\\x87 DOUBLE QUESTION MARK + U+2048 ⁈ \\xe2\\x81\\x88 QUESTION EXCLAMATION MARK + U+2049 ⁉ \\xe2\\x81\\x89 EXCLAMATION QUESTION MARK + U+204A ⁊ \\xe2\\x81\\x8a TIRONIAN SIGN ET + U+204B ⁋ \\xe2\\x81\\x8b REVERSED PILCROW SIGN + U+204C ⁌ \\xe2\\x81\\x8c BLACK LEFTWARDS BULLET + U+204D ⁍ \\xe2\\x81\\x8d BLACK RIGHTWARDS BULLET + U+204E ⁎ \\xe2\\x81\\x8e LOW ASTERISK + U+204F ⁏ \\xe2\\x81\\x8f REVERSED SEMICOLON + U+2050 ⁐ \\xe2\\x81\\x90 CLOSE UP + U+2051 ⁑ \\xe2\\x81\\x91 TWO ASTERISKS ALIGNED VERTICALLY + U+2052 ⁒ \\xe2\\x81\\x92 COMMERCIAL MINUS SIGN + U+2053 ⁓ \\xe2\\x81\\x93 SWUNG DASH + U+2054 ⁔ \\xe2\\x81\\x94 INVERTED UNDERTIE + U+2055 ⁕ \\xe2\\x81\\x95 FLOWER PUNCTUATION MARK + U+2056 ⁖ \\xe2\\x81\\x96 THREE DOT PUNCTUATION + U+2057 ⁗ \\xe2\\x81\\x97 QUADRUPLE PRIME + U+2058 ⁘ \\xe2\\x81\\x98 FOUR DOT PUNCTUATION + U+2059 ⁙ \\xe2\\x81\\x99 FIVE DOT PUNCTUATION + U+205A ⁚ \\xe2\\x81\\x9a TWO DOT PUNCTUATION + U+205B ⁛ \\xe2\\x81\\x9b FOUR DOT MARK + U+205C ⁜ \\xe2\\x81\\x9c DOTTED CROSS + U+205D ⁝ \\xe2\\x81\\x9d TRICOLON + U+205E ⁞ \\xe2\\x81\\x9e VERTICAL FOUR DOTS + U+205F   \\xe2\\x81\\x9f MEDIUM MATHEMATICAL SPACE + U+2060 ⁠ \\xe2\\x81\\xa0 WORD JOINER + U+2061 ⁡ \\xe2\\x81\\xa1 FUNCTION APPLICATION + U+2062 ⁢ \\xe2\\x81\\xa2 INVISIBLE TIMES + U+2063 ⁣ \\xe2\\x81\\xa3 INVISIBLE SEPARATOR + U+2064 ⁤ \\xe2\\x81\\xa4 INVISIBLE PLUS + U+2065 ⁥ \\xe2\\x81\\xa5 + U+2066 ⁦ \\xe2\\x81\\xa6 LEFT-TO-RIGHT ISOLATE + U+2067 ⁧ \\xe2\\x81\\xa7 RIGHT-TO-LEFT ISOLATE + U+2068 ⁨ \\xe2\\x81\\xa8 FIRST STRONG ISOLATE + U+2069 ⁩ \\xe2\\x81\\xa9 POP DIRECTIONAL ISOLATE + U+206A  \\xe2\\x81\\xaa INHIBIT SYMMETRIC SWAPPING + U+206B  \\xe2\\x81\\xab ACTIVATE SYMMETRIC SWAPPING + U+206C  \\xe2\\x81\\xac INHIBIT ARABIC FORM SHAPING + U+206D  \\xe2\\x81\\xad ACTIVATE ARABIC FORM SHAPING + U+206E  \\xe2\\x81\\xae NATIONAL DIGIT SHAPES + U+206F  \\xe2\\x81\\xaf NOMINAL DIGIT SHAPES + U+2070 ⁰ \\xe2\\x81\\xb0 SUPERSCRIPT ZERO + U+2071 ⁱ \\xe2\\x81\\xb1 SUPERSCRIPT LATIN SMALL LETTER I + U+2072 ⁲ \\xe2\\x81\\xb2 + U+2073 ⁳ \\xe2\\x81\\xb3 + U+2074 ⁴ \\xe2\\x81\\xb4 SUPERSCRIPT FOUR + U+2075 ⁵ \\xe2\\x81\\xb5 SUPERSCRIPT FIVE + U+2076 ⁶ \\xe2\\x81\\xb6 SUPERSCRIPT SIX + U+2077 ⁷ \\xe2\\x81\\xb7 SUPERSCRIPT SEVEN + U+2078 ⁸ \\xe2\\x81\\xb8 SUPERSCRIPT EIGHT + U+2079 ⁹ \\xe2\\x81\\xb9 SUPERSCRIPT NINE + U+207A ⁺ \\xe2\\x81\\xba SUPERSCRIPT PLUS SIGN + U+207B ⁻ \\xe2\\x81\\xbb SUPERSCRIPT MINUS + U+207C ⁼ \\xe2\\x81\\xbc SUPERSCRIPT EQUALS SIGN + U+207D ⁽ \\xe2\\x81\\xbd SUPERSCRIPT LEFT PARENTHESIS + U+207E ⁾ \\xe2\\x81\\xbe SUPERSCRIPT RIGHT PARENTHESIS + U+207F ⁿ \\xe2\\x81\\xbf SUPERSCRIPT LATIN SMALL LETTER N + U+2580 ▀ \\xe2\\x96\\x80 + U+2581 ▁ \\xe2\\x96\\x81 + U+2582 ▂ \\xe2\\x96\\x82 + U+2583 ▃ \\xe2\\x96\\x83 + U+2584 ▄ \\xe2\\x96\\x84 + U+2585 ▅ 
\\xe2\\x96\\x85 + U+2586 ▆ \\xe2\\x96\\x86 + U+2587 ▇ \\xe2\\x96\\x87 + U+2588 █ \\xe2\\x96\\x88 + U+2589 ▉ \\xe2\\x96\\x89 + U+258A ▊ \\xe2\\x96\\x8a + U+258B ▋ \\xe2\\x96\\x8b + U+258C ▌ \\xe2\\x96\\x8c + U+258D ▍ \\xe2\\x96\\x8d + U+258E ▎ \\xe2\\x96\\x8e + U+258F ▏ \\xe2\\x96\\x8f + U+2590 ▐ \\xe2\\x96\\x90 + U+2591 ░ \\xe2\\x96\\x91 + U+2592 ▒ \\xe2\\x96\\x92 + U+2593 ▓ \\xe2\\x96\\x93 + U+2594 ▔ \\xe2\\x96\\x94 + U+2595 ▕ \\xe2\\x96\\x95 + U+2596 ▖ \\xe2\\x96\\x96 + U+2597 ▗ \\xe2\\x96\\x97 + U+2598 ▘ \\xe2\\x96\\x98 + U+2599 ▙ \\xe2\\x96\\x99 + U+259A ▚ \\xe2\\x96\\x9a + U+259B ▛ \\xe2\\x96\\x9b + U+259C ▜ \\xe2\\x96\\x9c + U+259D ▝ \\xe2\\x96\\x9d + U+259E ▞ \\xe2\\x96\\x9e + U+259F ▟ \\xe2\\x96\\x9f + U+25A0 ■ \\xe2\\x96\\xa0 + U+25A1 □ \\xe2\\x96\\xa1 + U+25A2 ▢ \\xe2\\x96\\xa2 + U+25A3 ▣ \\xe2\\x96\\xa3 + U+25A4 ▤ \\xe2\\x96\\xa4 + U+25A5 ▥ \\xe2\\x96\\xa5 + U+25A6 ▦ \\xe2\\x96\\xa6 + U+25A7 ▧ \\xe2\\x96\\xa7 + U+25A8 ▨ \\xe2\\x96\\xa8 + U+25A9 ▩ \\xe2\\x96\\xa9 + U+25AA ▪ \\xe2\\x96\\xaa + U+25AB ▫ \\xe2\\x96\\xab + U+25AC ▬ \\xe2\\x96\\xac + U+25AD ▭ \\xe2\\x96\\xad + U+25AE ▮ \\xe2\\x96\\xae + U+25AF ▯ \\xe2\\x96\\xaf + U+25B0 ▰ \\xe2\\x96\\xb0 + U+25B1 ▱ \\xe2\\x96\\xb1 + U+25B2 ▲ \\xe2\\x96\\xb2 + U+25B3 △ \\xe2\\x96\\xb3 + U+25B4 ▴ \\xe2\\x96\\xb4 + U+25B5 ▵ \\xe2\\x96\\xb5 + U+25B6 ▶ \\xe2\\x96\\xb6 + U+25B7 ▷ \\xe2\\x96\\xb7 + U+25B8 ▸ \\xe2\\x96\\xb8 + U+25B9 ▹ \\xe2\\x96\\xb9 + U+25BA ► \\xe2\\x96\\xba + U+25BB ▻ \\xe2\\x96\\xbb + U+25BC ▼ \\xe2\\x96\\xbc + U+25BD ▽ \\xe2\\x96\\xbd + U+25BE ▾ \\xe2\\x96\\xbe + U+25BF ▿ \\xe2\\x96\\xbf + U+25C0 ◀ \\xe2\\x97\\x80 + U+25C1 ◁ \\xe2\\x97\\x81 + U+25C2 ◂ \\xe2\\x97\\x82 + U+25C3 ◃ \\xe2\\x97\\x83 + U+25C4 ◄ \\xe2\\x97\\x84 + U+25C5 ◅ \\xe2\\x97\\x85 + U+25C6 ◆ \\xe2\\x97\\x86 + U+25C7 ◇ \\xe2\\x97\\x87 + U+25C8 ◈ \\xe2\\x97\\x88 + U+25C9 ◉ \\xe2\\x97\\x89 + U+25CA ◊ \\xe2\\x97\\x8a + U+25CB ○ \\xe2\\x97\\x8b + U+25CC ◌ \\xe2\\x97\\x8c + U+25CD ◍ \\xe2\\x97\\x8d + U+25CE ◎ \\xe2\\x97\\x8e + U+25CF ● \\xe2\\x97\\x8f + U+25D0 ◐ \\xe2\\x97\\x90 + U+25D1 ◑ \\xe2\\x97\\x91 + U+25D2 ◒ \\xe2\\x97\\x92 + U+25D3 ◓ \\xe2\\x97\\x93 + U+25D4 ◔ \\xe2\\x97\\x94 + U+25D5 ◕ \\xe2\\x97\\x95 + U+25D6 ◖ \\xe2\\x97\\x96 + U+25D7 ◗ \\xe2\\x97\\x97 + U+25D8 ◘ \\xe2\\x97\\x98 + U+25D9 ◙ \\xe2\\x97\\x99 + U+25DA ◚ \\xe2\\x97\\x9a + U+25DB ◛ \\xe2\\x97\\x9b + U+25DC ◜ \\xe2\\x97\\x9c + U+25DD ◝ \\xe2\\x97\\x9d + U+25DE ◞ \\xe2\\x97\\x9e + U+25DF ◟ \\xe2\\x97\\x9f + U+25E0 ◠ \\xe2\\x97\\xa0 + U+25E1 ◡ \\xe2\\x97\\xa1 + U+25E2 ◢ \\xe2\\x97\\xa2 + U+25E3 ◣ \\xe2\\x97\\xa3 + U+25E4 ◤ \\xe2\\x97\\xa4 + U+25E5 ◥ \\xe2\\x97\\xa5 + U+25E6 ◦ \\xe2\\x97\\xa6 + U+25E7 ◧ \\xe2\\x97\\xa7 + U+25E8 ◨ \\xe2\\x97\\xa8 + U+25E9 ◩ \\xe2\\x97\\xa9 + U+25EA ◪ \\xe2\\x97\\xaa + U+25EB ◫ \\xe2\\x97\\xab + U+25EC ◬ \\xe2\\x97\\xac + U+25ED ◭ \\xe2\\x97\\xad + U+25EE ◮ \\xe2\\x97\\xae + U+25EF ◯ \\xe2\\x97\\xaf + U+25F0 ◰ \\xe2\\x97\\xb0 + U+25F1 ◱ \\xe2\\x97\\xb1 + U+25F2 ◲ \\xe2\\x97\\xb2 + U+25F3 ◳ \\xe2\\x97\\xb3 + U+25F4 ◴ \\xe2\\x97\\xb4 + U+25F5 ◵ \\xe2\\x97\\xb5 + U+25F6 ◶ \\xe2\\x97\\xb6 + U+25F7 ◷ \\xe2\\x97\\xb7 + U+25F8 ◸ \\xe2\\x97\\xb8 + U+25F9 ◹ \\xe2\\x97\\xb9 + U+25FA ◺ \\xe2\\x97\\xba + U+25FB ◻ \\xe2\\x97\\xbb + U+25FC ◼ \\xe2\\x97\\xbc + U+25FD ◽ \\xe2\\x97\\xbd + U+25FE ◾ \\xe2\\x97\\xbe + U+25FF ◿ \\xe2\\x97\\xbf""" + trans_list = [] + for a in t1.split('\n'): + b = a.split("\t") + trans_list.append((b[2], b[1])) + return trans_list diff --git a/sdp/processors/datasets/commoncrawl/requirements.txt b/sdp/processors/datasets/commoncrawl/requirements.txt 
new file mode 100644 index 00000000..f0b24650 --- /dev/null +++ b/sdp/processors/datasets/commoncrawl/requirements.txt @@ -0,0 +1,9 @@ +sacrebleu +ffmpeg-python +webvtt-py +fastparquet +pysndfile # conda install -c conda-forge libsndfile==1.0.31 +sonar-space +fairseq2 +huggingsound +pyarrow==12.0.1 \ No newline at end of file diff --git a/sdp/processors/huggingface/speech_recognition.py b/sdp/processors/huggingface/speech_recognition.py index 8a65cc83..1707cde1 100644 --- a/sdp/processors/huggingface/speech_recognition.py +++ b/sdp/processors/huggingface/speech_recognition.py @@ -13,14 +13,194 @@ # limitations under the License. import json +from collections import Counter from pathlib import Path +from typing import Optional +import numpy as np +import soundfile as sf from tqdm import tqdm from sdp.logging import logger from sdp.processors.base_processor import BaseProcessor from sdp.utils.common import load_manifest -from typing import Optional + + +class LangIdWhisper(BaseProcessor): + """ + Processor to get Lang ID using ASR Whisper model from HuggingFace. + + Args: + pretrained_model (str): name of pretrained model on HuggingFace. + output_lang_key (str): field to save language ID result. + device (str): Inference device. + """ + + def __init__( + self, + pretrained_model: str, + output_lang_key: str, + device: str = None, + segment_duration: float = np.inf, + num_segments: int = 1, + random_seed: int = None, + **kwargs, + ): + super().__init__(**kwargs) + try: + import torch + import whisper + except: + raise ImportError("Need to install whisper: pip install -U openai-whisper") + + logger.warning("This is an example processor, for demonstration only. Do not use it for production purposes.") + self.whisper = whisper + self.pretrained_model = pretrained_model + self.device = device + self.output_lang_key = output_lang_key + self.segment_duration = segment_duration + self.num_segments = num_segments + self.random_seed = random_seed + + if self.device is None: + if torch.cuda.is_available(): + self.device = "cuda" + else: + self.device = "cpu" + self.model = whisper.load_model(self.pretrained_model) + + def process(self): + json_list = load_manifest(Path(self.input_manifest_file)) + + Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True) + + with Path(self.output_manifest_file).open('w') as f: + for item in tqdm(json_list): + pred_lang = self.get_label(item["audio_filepath"]) + item[self.output_lang_key] = pred_lang + f.write(json.dumps(item, ensure_ascii=False) + '\n') + + def get_label(self, path2audio_file): + audio, sample_rate = sf.read(path2audio_file) + audio = np.float32(audio) + + audio_length = audio.shape[0] + + audio_segment_samples = sample_rate * self.segment_duration + segments_in_audio = int(audio_length / audio_segment_samples) + + segment_starts = [] + segment_ends = [] + + np.random.seed(self.random_seed) + + if segments_in_audio <= 1: + segment_starts = [0] + segment_ends = [audio_length] + else: + if segments_in_audio > self.num_segments: + segments_in_audio = self.num_segments + + long_segment_duration = int(audio_length / segments_in_audio) + + for segment_no in range(segments_in_audio): + long_start_segment = long_segment_duration * segment_no + long_end_segment = long_segment_duration * (segment_no + 1) + segment_start = np.random.randint(long_start_segment, long_end_segment - audio_segment_samples) + segment_end = segment_start + audio_segment_samples + segment_starts.append(segment_start) + segment_ends.append(segment_end) + + label_id_list 
= [] + + n_mels = 80 + + if self.pretrained_model == "large-v3": + n_mels = 128 + + for segment_start, segment_end in zip(segment_starts, segment_ends): + audio_segement = audio[segment_start:segment_end] + audio_segement = self.whisper.pad_or_trim(audio_segement) + mel = self.whisper.log_mel_spectrogram(audio_segement, n_mels) + mel = mel.to(self.device) + _, probs = self.model.detect_language(mel) + lang = max(probs, key=probs.get) + label_id_list.append(lang) + + m_label_id = Counter(label_id_list).most_common(1)[0][0] + return m_label_id + + +class ASRWhisper(BaseProcessor): + """ + Simple example to transcribe using ASR Whisper model from HuggingFace. + There are many ways to improve it: make batch inference, split long files, return predicted language, etc. + + Args: + pretrained_model (str): name of pretrained model on HuggingFace. + output_text_field (str): field to save transcription result. + pad_or_trim_length (int): Audio duration to pad or trim (number of samples). Counted as sample_rate * n_seconds i.e.: 16000*30=480000 + device (str): Inference device. + """ + + def __init__( + self, + pretrained_model: str, + output_text_key: str, + pad_or_trim_length: int = None, + device: str = None, + output_lang_key: str = "lid", + **kwargs, + ): + super().__init__(**kwargs) + try: + import torch + import whisper + except: + raise ImportError("Need to install whisper: pip install -U openai-whisper") + + logger.warning("This is an example processor, for demonstration only. Do not use it for production purposes.") + self.whisper = whisper + self.pretrained_model = pretrained_model + self.output_text_key = output_text_key + self.device = device + self.output_lang_key = output_lang_key + self.pad_or_trim_length = pad_or_trim_length + if self.device is None: + if torch.cuda.is_available(): + self.device = "cuda" + else: + self.device = "cpu" + self.model = whisper.load_model(self.pretrained_model) + self.model.to(self.device) + + def process(self): + json_list = load_manifest(Path(self.input_manifest_file)) + + Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True) + + with Path(self.output_manifest_file).open('w') as f: + for item in tqdm(json_list): + pred_text, pred_lang = self.whisper_infer(item["audio_filepath"]) + + item[self.output_text_key] = pred_text + item[self.output_lang_key] = pred_lang + f.write(json.dumps(item, ensure_ascii=False) + '\n') + + def whisper_infer(self, audio_path): + audio = self.whisper.load_audio(audio_path) + + audio = self.whisper.pad_or_trim(audio, length=self.pad_or_trim_length) + mel = self.whisper.log_mel_spectrogram(audio) + mel = mel.to(self.device) + + _, probs = self.model.detect_language(mel) + lang = max(probs, key=probs.get) + + options = self.whisper.DecodingOptions(fp16=False) + result = self.whisper.decode(self.model, mel, options) + return result.text, lang + class ASRTransformers(BaseProcessor): """ diff --git a/sdp/processors/modify_manifest/data_to_data.py b/sdp/processors/modify_manifest/data_to_data.py index 9328b081..808d395c 100644 --- a/sdp/processors/modify_manifest/data_to_data.py +++ b/sdp/processors/modify_manifest/data_to_data.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,6 +16,14 @@ import os import re from typing import Dict, List +import jiwer +import editdistance +import itertools +from tqdm.contrib.concurrent import process_map +from tqdm import tqdm +import json + +import soundfile as sf import soundfile from sox import Transformer @@ -168,6 +176,49 @@ def process_dataset_entry(self, data_entry): return data_list +class SplitLineBySentence(BaseParallelProcessor): + """ + Processor for splitting lines of text into sentences based on a specified pattern. + One line containing N sentences will be transformed into N lines containing one sentence. + + Args: + text_key (str): The field containing the text lines in the dataset. + end_pattern (str): The regular expression pattern to identify sentence boundaries. + **kwargs: Additional keyword arguments to be passed to the base class `BaseParallelProcessor`. + """ + + def __init__( + self, + text_key: str, + end_pattern: str, + **kwargs, + ): + super().__init__(**kwargs) + self.text_key = text_key + self.pattern = re.compile(end_pattern) + + def process_dataset_entry(self, data_entry): + line = data_entry[self.text_key] + data_list = [] + start = 0 + ends = [m.start() for m in self.pattern.finditer(line)] + if ends: + for end in ends: + sent = line[start : end + 1].strip() + # if sent and sent[0].isupper(): + data = data_entry.copy() + data[self.text_key] = sent + data_list.append(DataEntry(data=data)) + start = end + 1 + if start < len(line): + pass + else: + data = data_entry.copy() + data[self.text_key] = line.strip() + data_list.append(DataEntry(data=data)) + return data_list + + class SoxConvert(BaseParallelProcessor): """ Processor for converting audio files from one format to another using Sox, @@ -592,7 +643,357 @@ def finalize(self, metrics): for word, count in total_counter_sorted.items(): logger.info(f"{word} {count}") super().finalize(metrics) + +class GetWER(BaseParallelProcessor): + """ + Processor that computes the Word Error Rate (WER) between reference text and hypothesis text. + The WER is computed as the Levenshtein distance between the two texts normalized by the + number of words in the reference text. + + Args: + reference_text_field (str): Key to get the reference text from the data. + hypothesis_text_field (str): Key to get the hypothesis text from the data. + output_metric_field (str): Key to put the computed WER value. + + Returns: + All the same fields as in the input manifest plus the output_metric_field containing + the computed WER value. 
+ """ + + def __init__( + self, + reference_text_field: str = "text", + hypothesis_text_field: str = "pred_text", + output_metric_field: str = "wer", + **kwargs, + ): + super().__init__(**kwargs) + self.reference_text_field = reference_text_field + self.hypothesis_text_field = hypothesis_text_field + self.output_metric_field = output_metric_field + self.word_dist = 0 + self.num_words = 0 + + def process(self): + self.prepare() + os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) + metrics = [] + + with open(self.output_manifest_file, "wt", encoding="utf8") as fout: + for manifest_chunk in self._chunk_manifest(): + # this will unroll all inner lists + data = itertools.chain( + *process_map( + self.process_dataset_entry, + manifest_chunk, + max_workers=self.max_workers, + chunksize=self.chunksize, + ) + ) + for data_entry in tqdm(data): + metrics.append(data_entry.metrics) + if data_entry.data is None: + continue + json.dump(data_entry.data, fout, ensure_ascii=False) + self.number_of_entries += 1 + self.total_duration += data_entry.data.get("duration", 0) + self.word_dist += data_entry.metrics.get("word_dist", 0) + self.num_words += data_entry.metrics.get("num_words", 0) + fout.write("\n") + + self.finalize(metrics) + + def process_dataset_entry(self, data_entry): + reference_text = data_entry[self.reference_text_field] + hypothesis_text = data_entry[self.hypothesis_text_field] + + ref_words_amount = len(reference_text.split()) + hyp_words_amount = len(hypothesis_text.split()) + + if ref_words_amount == 0 or hyp_words_amount == 0: + if ref_words_amount == hyp_words_amount: + word_dist = 0 + else: + word_dist = ref_words_amount + else: + word_dist_measures = jiwer.compute_measures(reference_text, hypothesis_text) + word_dist = word_dist_measures['substitutions'] + word_dist_measures['insertions'] + word_dist_measures['deletions'] + + wer_value = word_dist / ref_words_amount + data_entry[self.output_metric_field] = round(wer_value * 100, 2) + + return [DataEntry(data=data_entry, metrics = {'word_dist' : word_dist, 'num_words' : ref_words_amount})] + + def finalize(self, metrics: List): + logger.info("Total number of entries after processing: %d", self.number_of_entries) + if self.total_duration != 0: + logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600) + + logger.info("Overall Word Error Rate (WER): %.2f%%", self.word_dist / self.num_words * 100) + + +class GetCER(BaseParallelProcessor): + """ + Processor that computes the Character Error Rate (CER) between reference text and hypothesis text. + The CER is computed as the Levenshtein distance between the two texts normalized by the + number of characters in the reference text. + + Args: + reference_text_field (str): Key to get the reference text from the data. + hypothesis_text_field (str): Key to get the hypothesis text from the data. + output_metric_field (str): Key to put the computed CER value. + + Returns: + All the same fields as in the input manifest plus the output_metric_field containing + the computed CER value. 
+ """ + + def __init__( + self, + reference_text_field: str = "text", + hypothesis_text_field: str = "pred_text", + output_metric_field: str = "cer", + **kwargs, + ): + super().__init__(**kwargs) + self.reference_text_field = reference_text_field + self.hypothesis_text_field = hypothesis_text_field + self.output_metric_field = output_metric_field + self.char_dist = 0 + self.num_chars = 0 + + def process(self): + self.prepare() + os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) + metrics = [] + + with open(self.output_manifest_file, "wt", encoding="utf8") as fout: + for manifest_chunk in self._chunk_manifest(): + # this will unroll all inner lists + data = itertools.chain( + *process_map( + self.process_dataset_entry, + manifest_chunk, + max_workers=self.max_workers, + chunksize=self.chunksize, + ) + ) + for data_entry in tqdm(data): + metrics.append(data_entry.metrics) + if data_entry.data is None: + continue + json.dump(data_entry.data, fout, ensure_ascii=False) + self.number_of_entries += 1 + self.total_duration += data_entry.data.get("duration", 0) + self.char_dist += data_entry.metrics.get("char_dist", 0) + self.num_chars += data_entry.metrics.get("num_chars", 0) + fout.write("\n") + + self.finalize(metrics) + + def process_dataset_entry(self, data_entry): + reference_text = data_entry[self.reference_text_field] + hypothesis_text = data_entry[self.hypothesis_text_field] + + ref_chars_amount = len(reference_text) + hyp_chars_amount = len(hypothesis_text) + + if ref_chars_amount == 0 or hyp_chars_amount == 0: + if ref_chars_amount == hyp_chars_amount: + char_dist = 0 + else: + char_dist = ref_chars_amount + else: + char_dist = editdistance.eval(reference_text, hypothesis_text) + + cer_value = char_dist / ref_chars_amount + data_entry[self.output_metric_field] = round(cer_value * 100, 2) + + return [DataEntry(data=data_entry, metrics = {'char_dist' : char_dist, 'num_chars' : ref_chars_amount})] + + def finalize(self, metrics: List): + logger.info("Total number of entries after processing: %d", self.number_of_entries) + if self.total_duration != 0: + logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600) + + logger.info("Overall Character Error Rate (CER): %.2f%%", self.char_dist / self.num_chars * 100) + + +class GetEdgeCER(BaseParallelProcessor): + """ + Processor that computes the Character Error Rate (CER) for a specified edge of reference + and hypothesis texts. + + Args: + reference_text_field (str): Key to get the reference text from the data. + hypothesis_text_field (str): Key to get the hypothesis text from the data. + edge (str): Specifies whether to compute CER for the 'start' or 'end' edge of the texts. + edge_len (int): Length of the edge window. + output_metric_field (str): Key to put the computed edge CER value. + + Returns: + All the same fields as in the input manifest plus the output_metric_field containing + the computed edge CER value. 
+ """ + + def __init__( + self, + reference_text_field: str = "text", + hypothesis_text_field: str = "pred_text", + edge: str = "start", + edge_len: int = 10, + output_metric_field: str = "start_cer", + **kwargs, + ): + super().__init__(**kwargs) + self.reference_text_field = reference_text_field + self.hypothesis_text_field = hypothesis_text_field + self.edge = edge + self.edge_len = edge_len + self.output_metric_field = output_metric_field + self.edge_cer_sum = 0 + + def process(self): + self.prepare() + os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) + metrics = [] + + with open(self.output_manifest_file, "wt", encoding="utf8") as fout: + for manifest_chunk in self._chunk_manifest(): + # this will unroll all inner lists + data = itertools.chain( + *process_map( + self.process_dataset_entry, + manifest_chunk, + max_workers=self.max_workers, + chunksize=self.chunksize, + ) + ) + for data_entry in tqdm(data): + metrics.append(data_entry.metrics) + if data_entry.data is None: + continue + json.dump(data_entry.data, fout, ensure_ascii=False) + self.number_of_entries += 1 + self.total_duration += data_entry.data.get("duration", 0) + self.edge_cer_sum += data_entry.data.get(self.output_metric_field, 0) + fout.write("\n") + + self.finalize(metrics) + + def process_dataset_entry(self, data_entry): + if self.edge == "start": + start_idx = 0 + end_idx = self.edge_len + elif self.edge == "end": + start_idx = -self.edge_len + end_idx = -1 + else: + raise ValueError(f"Current `Edge` parameter value ({self.edge}) is incorrect. Please select `start` or `end` edge.") + + reference_text_edge = data_entry[self.reference_text_field][start_idx : end_idx] + hypothesis_text_edge = data_entry[self.hypothesis_text_field][start_idx : end_idx] + + ref_chars_amount = len(reference_text_edge) + hyp_chars_amount = len(hypothesis_text_edge) + + if ref_chars_amount == 0 or hyp_chars_amount == 0: + if ref_chars_amount == hyp_chars_amount: + char_dist = 0 + else: + char_dist = ref_chars_amount + else: + char_dist = editdistance.eval(reference_text_edge, hypothesis_text_edge) + + edge_cer_value = char_dist / ref_chars_amount + data_entry[self.output_metric_field] = round(edge_cer_value * 100, 2) + + return [DataEntry(data=data_entry)] + + def finalize(self, metrics: List): + logger.info("Total number of entries after processing: %d", self.number_of_entries) + if self.total_duration != 0: + logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600) + + logger.info(f"Mean {self.edge} Character Error Rate (CER): {round(self.edge_cer_sum / self.number_of_entries, 2)}%") + + +class GetLenDiffRatio(BaseParallelProcessor): + """ + Processor that computes the length difference ratio between reference and hypothesis texts. + + Args: + reference_text_field (str): Key to get the reference text from the data. + hypothesis_text_field (str): Key to get the hypothesis text from the data. + output_metric_field (str): Key to put the computed length difference ratio. + + Returns: + All the same fields as in the input manifest plus the output_metric_field containing + the computed length difference ratio. 
+ """ + + def __init__( + self, + reference_text_field: str = "text", + hypothesis_text_field: str = "pred_text", + output_metric_field: str = "len_diff_ratio", + **kwargs, + ): + super().__init__(**kwargs) + self.reference_text_field = reference_text_field + self.hypothesis_text_field = hypothesis_text_field + self.output_metric_field = output_metric_field + self.words_len_diff_ratio_sum = 0 + + def process(self): + self.prepare() + os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) + metrics = [] + + with open(self.output_manifest_file, "wt", encoding="utf8") as fout: + for manifest_chunk in self._chunk_manifest(): + # this will unroll all inner lists + data = itertools.chain( + *process_map( + self.process_dataset_entry, + manifest_chunk, + max_workers=self.max_workers, + chunksize=self.chunksize, + ) + ) + for data_entry in tqdm(data): + metrics.append(data_entry.metrics) + if data_entry.data is None: + continue + json.dump(data_entry.data, fout, ensure_ascii=False) + self.number_of_entries += 1 + self.total_duration += data_entry.data.get("duration", 0) + self.words_len_diff_ratio_sum += data_entry.data.get(self.output_metric_field, 0) + fout.write("\n") + + self.finalize(metrics) + + def process_dataset_entry(self, data_entry): + reference_text = data_entry[self.reference_text_field] + hypothesis_text = data_entry[self.hypothesis_text_field] + + ref_words_amount = len(reference_text.split()) + hyp_words_amount = len(hypothesis_text.split()) + + eps = 1e-9 + len_diff_ratio = 1.0 * abs(ref_words_amount - hyp_words_amount) / max(ref_words_amount, eps) + + data_entry[self.output_metric_field] = round(len_diff_ratio * 100, 2) + + return [DataEntry(data=data_entry)] + + def finalize(self, metrics: List): + logger.info("Total number of entries after processing: %d", self.number_of_entries) + if self.total_duration != 0: + logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600) + logger.info(f"Mean Text Length Difference Ratio (in words): {round(self.words_len_diff_ratio_sum / self.number_of_entries, 2)}%") + class NormalizeText(BaseParallelProcessor): """This processor applies text normalization (TN) to the text. I.e. converts text from written form into its verbalized form. diff --git a/sdp/processors/modify_manifest/data_to_dropbool.py b/sdp/processors/modify_manifest/data_to_dropbool.py index 3c91ba20..7ab333e6 100644 --- a/sdp/processors/modify_manifest/data_to_dropbool.py +++ b/sdp/processors/modify_manifest/data_to_dropbool.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
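
For reference, the per-entry values that GetWER and GetCER write into their output_metric_field can be reproduced with the same libraries the processors import (jiwer for word-level measures, editdistance for the character-level Levenshtein distance). A minimal standalone sketch, with made-up reference/hypothesis strings standing in for the manifest's "text" and "pred_text" fields:

    import editdistance
    import jiwer

    reference = "hallo welt wie geht es dir"   # stand-in for the "text" field
    hypothesis = "hallo welt wie geht es"      # stand-in for the "pred_text" field

    # GetWER: word distance = substitutions + insertions + deletions,
    # normalized by the number of reference words, reported as a percentage.
    measures = jiwer.compute_measures(reference, hypothesis)
    word_dist = measures["substitutions"] + measures["insertions"] + measures["deletions"]
    wer = round(word_dist / len(reference.split()) * 100, 2)   # -> 16.67

    # GetCER: character-level Levenshtein distance normalized by reference length.
    char_dist = editdistance.eval(reference, hypothesis)
    cer = round(char_dist / len(reference) * 100, 2)           # -> 15.38

Note that the processors also accumulate word_dist/num_words (and char_dist/num_chars) across all entries, so finalize reports a corpus-level WER/CER rather than a mean of per-utterance values.
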
diff --git a/sdp/processors/nemo/asr_inference.py b/sdp/processors/nemo/asr_inference.py index 8fc0a2e0..25266407 100644 --- a/sdp/processors/nemo/asr_inference.py +++ b/sdp/processors/nemo/asr_inference.py @@ -14,6 +14,7 @@ import os import subprocess +import shutil from pathlib import Path from sdp.processors.base_processor import BaseProcessor @@ -54,12 +55,54 @@ def __init__( def process(self): """This will add "pred_text" key into the output manifest.""" os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) + if self.pretrained_model[-5:] == ".nemo": + subprocess.run( + f"python {self.script_path} " + f"model_path={self.pretrained_model} " + f"dataset_manifest={self.input_manifest_file} " + f"output_filename={self.output_manifest_file} " + f"batch_size={self.batch_size} ", + shell=True, + check=True, + ) + else: + subprocess.run( + f"python {self.script_path} " + f"pretrained_name={self.pretrained_model} " + f"dataset_manifest={self.input_manifest_file} " + f"output_filename={self.output_manifest_file} " + f"batch_size={self.batch_size} ", + shell=True, + check=True, + ) + + +class ASRInferenceParallel(BaseProcessor): + def __init__( + self, + pretrained_model: str, + batch_size: int = 32, + devices: int = 2, + **kwargs, + ): + super().__init__(**kwargs) + self.script_path = Path(__file__).parents[1] / "nemo" / "transcribe_speech_parallel.py" + self.pretrained_model = pretrained_model + self.batch_size = batch_size + self.devices = devices + self.output_manifest_dir = self.output_manifest_file.replace(".json", "") + + def process(self): subprocess.run( f"python {self.script_path} " - f"pretrained_name={self.pretrained_model} " - f"dataset_manifest={self.input_manifest_file} " - f"output_filename={self.output_manifest_file} " - f"batch_size={self.batch_size} ", + f"model={self.pretrained_model} " + f"predict_ds.manifest_filepath={self.input_manifest_file} " + f"output_path={self.output_manifest_dir} " + f"predict_ds.batch_size={self.batch_size} " + f"trainer.devices={self.devices} ", shell=True, check=True, - ) \ No newline at end of file + ) + + os.rename(os.path.join(self.output_manifest_dir, "predictions_all.json"), self.output_manifest_file) + shutil.rmtree(self.output_manifest_dir) diff --git a/sdp/processors/nemo/beamsearch_inference.py b/sdp/processors/nemo/beamsearch_inference.py new file mode 100644 index 00000000..3eb5c5fa --- /dev/null +++ b/sdp/processors/nemo/beamsearch_inference.py @@ -0,0 +1,344 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import contextlib +import Levenshtein +import json +import os +import re +from dataclasses import dataclass, field, is_dataclass +from pathlib import Path +from typing import Dict, List, Optional, Union + +import editdistance +import numpy as np +import torch +from omegaconf import MISSING, OmegaConf +from sklearn.model_selection import ParameterGrid +from tqdm.auto import tqdm + +import nemo.collections.asr as nemo_asr +from nemo.collections.asr.models import EncDecHybridRNNTCTCModel +from nemo.collections.asr.parts.submodules import ctc_beam_decoding +from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig +from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization, TextProcessingConfig +from nemo.core.config import hydra_runner +from nemo.utils import logging + +from sdp.processors.base_processor import BaseProcessor, BaseParallelProcessor, DataEntry + + +def read_manifest(input_manifest_file, encoding): + """Reading the input manifest file. + + .. note:: + This function should be overridden in the "initial" class creating + manifest to read from the original source of data. + """ + if input_manifest_file is None: + raise NotImplementedError("Override this method if the processor creates initial manifest") + + with open(input_manifest_file, "rt", encoding=encoding) as fin: + for line in fin: + yield json.loads(line) + +@dataclass +class EvalBeamSearchNGramConfig: + """ + Evaluate an ASR model with beam search decoding and n-gram KenLM language model. + """ + # # The path of the '.nemo' file of the ASR model or the name of a pretrained model (ngc / huggingface) + model_path: str = MISSING + + # File paths + dataset_manifest: str = MISSING # The manifest file of the evaluation set + preds_output_folder: Optional[str] = None # The optional folder where the predictions are stored + cache_file: Optional[str] = None # The cache file for storing the logprobs of the model + + # Parameters for inference + batch_size: int = 16 # The batch size to calculate log probabilities + beam_batch_size: int = 1 # The batch size to be used for beam search decoding + + # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA + # device anyway, and do inference on CPU only if CUDA device is not found. + # If `cuda` is a negative number, inference will be on CPU only. + cuda: Optional[int] = None + allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU) + amp: bool = False + matmul_precision: str = "highest" # Literal["highest", "high", "medium"] + + # Beam Search hyperparameters + ctc_decoding: CTCDecodingConfig = field(default_factory=lambda: CTCDecodingConfig( + strategy="flashlight", # gready, beam = pyctcdecode, flashlight + beam = ctc_beam_decoding.BeamCTCInferConfig( + kenlm_path="/mnt/md1/YTDS/ES/lm/lm.kenlm", + beam_size=16, + beam_alpha=0.5, # LM weight + beam_beta=0.5, # length weight + return_best_hypothesis = False, + flashlight_cfg=ctc_beam_decoding.FlashlightConfig( + lexicon_path = "/mnt/md1/YTDS/ES/lm/lm.flashlight_lexicon"), + pyctcdecode_cfg=ctc_beam_decoding.PyCTCDecodeConfig(), + ), + )) + + text_processing: Optional[TextProcessingConfig] = field(default_factory=lambda: TextProcessingConfig( + punctuation_marks = ".,?", + separate_punctuation = False, + do_lowercase = False, + rm_punctuation = False, + )) + + +class BeamsearchTopNInference(BaseProcessor): + """Adds predictions of a text-based punctuation and capitalization (P&C) model. 
+ + Operates on the text in the ``input_text_field``, and saves predictions in + the ``output_text_field``. + + Args: + input_audio_key (str): the text field that will be the input to the P&C model. + output_text_key (str): the text field where the output of the PC model + will be saved. + batch_size (int): the batch sized used by the P&C model. + device (str): the device used by the P&C model. Can be skipped to auto-select. + pretrained_name (str): the pretrained_name of the P&C model. + model_path (str): the model path to the P&C model. + + .. note:: + Either ``pretrained_name`` or ``model_path`` have to be specified. + + Returns: + The same data as in the input manifest with an additional field + containing P&C model's predictions. + """ + + def __init__( + self, + input_audio_key: str, + output_text_key: str, + batch_size: int, + device: Optional[str] = None, + pretrained_name: Optional[str] = None, + model_path: Optional[str] = None, + in_memory_chunksize: int = 100000, + cfg: Optional[EvalBeamSearchNGramConfig] = EvalBeamSearchNGramConfig(), + **kwargs, + ): + super().__init__(**kwargs) + + self.pretrained_name = pretrained_name + self.model_path = model_path + self.input_audio_key = input_audio_key + self.output_text_key = output_text_key + self.device = device + self.batch_size = batch_size + self.in_memory_chunksize=in_memory_chunksize + self.cfg=cfg + + # verify self.pretrained_name/model_path + if self.pretrained_name is None and self.model_path is None: + raise ValueError("pretrained_name and model_path cannot both be None") + if self.pretrained_name is not None and self.model_path is not None: + raise ValueError("pretrained_name and model_path cannot both be specified") + + def _chunk_manifest(self): + """Splits the manifest into smaller chunks defined by ``in_memory_chunksize``. 
+ """ + manifest_chunk = [] + for idx, data_entry in enumerate(read_manifest(self.input_manifest_file, encoding="utf8"), 1): + manifest_chunk.append(data_entry) + if idx % self.in_memory_chunksize == 0: + yield manifest_chunk + manifest_chunk = [] + if len(manifest_chunk) > 0: + yield manifest_chunk + + def process(self): + if self.pretrained_name: + model = EncDecHybridRNNTCTCModel.from_pretrained(self.pretrained_name) + else: + model = EncDecHybridRNNTCTCModel.restore_from(self.model_path) + + if self.device is None: + if torch.cuda.is_available(): + model = model.cuda() + else: + model = model.cpu() + else: + model = model.to(self.device) + + Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True) + with open(self.output_manifest_file, "wt", encoding="utf8") as fout: + + for manifest in self._chunk_manifest(): + + audio_file_paths = [x[self.input_audio_key] for x in manifest] + + + if isinstance(model, EncDecHybridRNNTCTCModel): + model.change_decoding_strategy(decoding_cfg=None, decoder_type="ctc") + else: + model.change_decoding_strategy(None) + + # Override the beam search config with current search candidate configuration + model.cfg.decoding = CTCDecodingConfig( + strategy=self.cfg.ctc_decoding.strategy, + preserve_alignments=self.cfg.ctc_decoding.preserve_alignments, + compute_timestamps=self.cfg.ctc_decoding.compute_timestamps, + word_seperator=self.cfg.ctc_decoding.word_seperator, + ctc_timestamp_type=self.cfg.ctc_decoding.ctc_timestamp_type, + batch_dim_index=self.cfg.ctc_decoding.batch_dim_index, + greedy=self.cfg.ctc_decoding.greedy, + confidence_cfg=self.cfg.ctc_decoding.confidence_cfg, + temperature=self.cfg.ctc_decoding.temperature, + beam = ctc_beam_decoding.BeamCTCInferConfig(beam_size=self.cfg.ctc_decoding.beam.beam_size, + beam_alpha=self.cfg.ctc_decoding.beam.beam_alpha, + beam_beta=self.cfg.ctc_decoding.beam.beam_beta, + kenlm_path=self.cfg.ctc_decoding.beam.kenlm_path, + kenlm_type=self.cfg.ctc_decoding.beam.kenlm_type, + preserve_alignments=self.cfg.ctc_decoding.beam.preserve_alignments, + compute_timestamps=self.cfg.ctc_decoding.beam.compute_timestamps, + flashlight_cfg=self.cfg.ctc_decoding.beam.flashlight_cfg, + pyctcdecode_cfg=self.cfg.ctc_decoding.beam.pyctcdecode_cfg, + return_best_hypothesis=self.cfg.ctc_decoding.beam.return_best_hypothesis), + ) + # Update model's decoding strategy + if isinstance(model, EncDecHybridRNNTCTCModel): + model.change_decoding_strategy(model.cfg.decoding, decoder_type='ctc') + else: + model.change_decoding_strategy(model.cfg.decoding) + + + with torch.no_grad(): + if isinstance(model, EncDecHybridRNNTCTCModel): + model.cur_decoder = 'ctc' + + override_cfg = model.get_transcribe_config() + override_cfg.batch_size = self.batch_size + override_cfg.return_hypotheses = True + + all_hypotheses = model.transcribe(audio_file_paths, override_config=override_cfg) + if type(all_hypotheses) == tuple and len(all_hypotheses) == 2: # if transcriptions form a tuple of (best_hypotheses, all_hypotheses) + all_hypotheses = all_hypotheses[1] + + pred_texts = [] + for hypotheses in all_hypotheses: + pred_text = [hyp.text for hyp in hypotheses] + pred_texts.append(pred_text) + + + for item, t in zip(manifest, pred_texts): + item[self.output_text_key] = t + fout.write(json.dumps(item, ensure_ascii=False) + '\n') + +class RestorePCbyTopN(BaseParallelProcessor): + """ + Adds predictions of a audio-based punctuation and capitalization (P&C) model. + + Args: + text_without_pc_key (str): Key to get path to wav file. 
+ texts_with_pc_key (str): Key to put to audio duration. + output_text_key (str): Key to put to audio duration. + Returns: + All the same fields as in the input manifest plus duration_field + """ + + def __init__( + self, + text_without_pc_key: str, + texts_with_pc_key: str, + output_text_key: str, + punctuation: str, + do_lower: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.text_without_pc_key = text_without_pc_key + self.texts_with_pc_key = texts_with_pc_key + self.output_text_key = output_text_key + self.punctuation = punctuation + self.do_lower = do_lower + + def prepare(self): + if self.punctuation: + self.patterns = re.compile("["+self.punctuation+"]") + + def get_capitalisation_from_target(self, text_input, text_to_fix): + text_input = text_input.strip() + text_to_fix = text_to_fix.strip() + if text_input[0].isupper(): + text_to_fix = text_to_fix[0].upper()+text_to_fix[1:] + + return text_to_fix + + + def process_dataset_entry(self, data_entry): + text_without_pc = data_entry[self.text_without_pc_key] + texts_with_pc = data_entry[self.texts_with_pc_key] + texts = [] + ldists = [] + for text in texts_with_pc: + if self.do_lower: + text = text.lower() + if self.punctuation: + text = self.patterns.sub('', text) + ldist = Levenshtein.distance(text, text_without_pc) + if ldist == 0: + data_entry[self.output_text_key] = text + return [DataEntry(data=data_entry)] + + ldists.append(ldist) + texts.append(text) + + text_with_pc = self.get_capitalisation_from_target(text_without_pc, texts_with_pc[np.argmin(ldists)]) + data_entry[self.output_text_key] = text_with_pc + return [DataEntry(data=data_entry)] + +class ConcatManifests(BaseProcessor): + """Adds predictions of a text-based punctuation and capitalization (P&C) model. + + Operates on the text in the ``input_text_field``, and saves predictions in + the ``output_text_field``. + + Args: + input_audio_key (str): the text field that will be the input to the P&C model. + + .. note:: + Either ``pretrained_name`` or ``model_path`` have to be specified. + + Returns: + The same data as in the input manifest with an additional field + containing P&C model's predictions. + """ + + def __init__( + self, + input_manifest_files: List[str], + encoding: str = "utf8", + ensure_ascii: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.input_manifest_files = input_manifest_files + self.encoding = encoding + self.ensure_ascii = ensure_ascii + + def process(self): + Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True) + with open(self.output_manifest_file, "wt", encoding=self.encoding) as fout: + for input_manifest_file in self.input_manifest_files: + for idx, data_entry in enumerate(read_manifest(input_manifest_file, self.encoding)): + fout.write(json.dumps(data_entry, ensure_ascii=self.ensure_ascii) + '\n') diff --git a/sdp/processors/nemo/transcribe_speech_parallel.py b/sdp/processors/nemo/transcribe_speech_parallel.py new file mode 100644 index 00000000..c0af8f97 --- /dev/null +++ b/sdp/processors/nemo/transcribe_speech_parallel.py @@ -0,0 +1,208 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +# ASR transcribe/inference with multi-GPU/multi-node support for large datasets +# It supports both tarred and non-tarred datasets +# Arguments +# model: path to a nemo/PTL checkpoint file or name of a pretrained model +# predict_ds: config of the dataset/dataloader +# output_path: path to store the predictions +# return_predictions: whether to return the predictions as output other than writing into the files +# use_cer: whether to calculate the error in terms of CER or use the default WER +# +# Results of each GPU/worker is written into a file named 'predictions_{rank}.json, and aggregated results of all workers are written into 'predictions_all.json' + +Example for non-tarred datasets: + +python transcribe_speech_parallel.py \ + model=stt_en_conformer_ctc_large \ + predict_ds.manifest_filepath=/dataset/manifest_file.json \ + predict_ds.batch_size=16 \ + output_path=/tmp/ + +Example for Hybrid-CTC/RNNT models with non-tarred datasets: + +python transcribe_speech_parallel.py \ + model=stt_en_fastconformer_hybrid_large \ + decoder_type=ctc \ + predict_ds.manifest_filepath=/dataset/manifest_file.json \ + predict_ds.batch_size=16 \ + output_path=/tmp/ + +Example for tarred datasets: + +python transcribe_speech_parallel.py \ + predict_ds.is_tarred=true \ + predict_ds.manifest_filepath=/tarred_dataset/tarred_audio_manifest.json \ + predict_ds.tarred_audio_filepaths=/tarred_dataset/audio__OP_0..127_CL_.tar \ + ... + +By default the trainer uses all the GPUs available and default precision is FP32. +By setting the trainer config you may control these configs. For example to do the predictions with AMP on just two GPUs: + +python transcribe_speech_parallel.py \ + trainer.precision=16 \ + trainer.devices=2 \ + ... + +You may control the dataloader's config by setting the predict_ds: + +python transcribe_speech_parallel.py \ + predict_ds.num_workers=8 \ + predict_ds.min_duration=2.0 \ + predict_ds.sample_rate=16000 \ + model=stt_en_conformer_ctc_small \ + ... 
+ +""" + + +import itertools +import json +import os +from dataclasses import dataclass, is_dataclass +from typing import Optional + +import pytorch_lightning as ptl +import torch +from omegaconf import MISSING, OmegaConf + +from nemo.collections.asr.data.audio_to_text_dataset import ASRPredictionWriter +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel +from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig +from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig +from nemo.core.config import TrainerConfig, hydra_runner +from nemo.utils import logging +from nemo.utils.get_rank import is_global_rank_zero + + +@dataclass +class ParallelTranscriptionConfig: + model: Optional[str] = None # name + predict_ds: ASRDatasetConfig = ASRDatasetConfig(return_sample_id=True, num_workers=4) + output_path: str = MISSING + + # when return_predictions is enabled, the prediction call would keep all the predictions in memory and return them when prediction is done + return_predictions: bool = False + use_cer: bool = False + + # decoding strategy for RNNT models + rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig() + + # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models + decoder_type: Optional[str] = None + # att_context_size can be set for cache-aware streaming models with multiple look-aheads + att_context_size: Optional[list] = None + + trainer: TrainerConfig = TrainerConfig(devices=-1, accelerator="gpu", strategy="ddp") + + +def match_train_config(predict_ds, train_ds): + # It copies the important configurations from the train dataset of the model + # into the predict_ds to be used for prediction. It is needed to match the training configurations. 
+ if train_ds is None: + return + + predict_ds.sample_rate = train_ds.get("sample_rate", 16000) + cfg_name_list = [ + "int_values", + "use_start_end_token", + "blank_index", + "unk_index", + "normalize", + "parser", + "eos_id", + "bos_id", + "pad_id", + ] + + if is_dataclass(predict_ds): + predict_ds = OmegaConf.structured(predict_ds) + for cfg_name in cfg_name_list: + if hasattr(train_ds, cfg_name): + setattr(predict_ds, cfg_name, getattr(train_ds, cfg_name)) + + return predict_ds + + +@hydra_runner(config_name="TranscriptionConfig", schema=ParallelTranscriptionConfig) +def main(cfg: ParallelTranscriptionConfig): + if cfg.model.endswith(".nemo"): + logging.info("Attempting to initialize from .nemo file") + model = ASRModel.restore_from(restore_path=cfg.model, map_location="cpu") + elif cfg.model.endswith(".ckpt"): + logging.info("Attempting to initialize from .ckpt file") + model = ASRModel.load_from_checkpoint(checkpoint_path=cfg.model, map_location="cpu") + else: + logging.info( + "Attempting to initialize from a pretrained model as the model name does not have the extension of .nemo or .ckpt" + ) + model = ASRModel.from_pretrained(model_name=cfg.model, map_location="cpu") + + if isinstance(model, EncDecHybridRNNTCTCModel) and cfg.decoder_type is not None: + model.change_decoding_strategy(decoder_type=cfg.decoder_type) + + trainer = ptl.Trainer(**cfg.trainer) + + cfg.predict_ds.return_sample_id = True + cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds) + data_loader = model._setup_dataloader_from_config(cfg.predict_ds) + + os.makedirs(cfg.output_path, exist_ok=True) + # trainer.global_rank is not valid before predict() is called. Need this hack to find the correct global_rank. + global_rank = trainer.node_rank * trainer.num_devices + int(os.environ.get("LOCAL_RANK", 0)) + output_file = os.path.join(cfg.output_path, f"predictions_{global_rank}.json") + predictor_writer = ASRPredictionWriter(dataset=data_loader.dataset, output_file=output_file) + trainer.callbacks.extend([predictor_writer]) + + predictions = trainer.predict(model=model, dataloaders=data_loader, return_predictions=cfg.return_predictions) + if predictions is not None: + predictions = list(itertools.chain.from_iterable(predictions)) + samples_num = predictor_writer.close_output_file() + + logging.info( + f"Prediction on rank {global_rank} is done for {samples_num} samples and results are stored in {output_file}." + ) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + samples_num = 0 + pred_text_list = [] + text_list = [] + if is_global_rank_zero(): + output_file = os.path.join(cfg.output_path, f"predictions_all.json") + logging.info(f"Prediction files are being aggregated in {output_file}.") + with open(output_file, 'w') as outf: + for rank in range(trainer.world_size): + input_file = os.path.join(cfg.output_path, f"predictions_{rank}.json") + with open(input_file, 'r') as inpf: + lines = inpf.readlines() + for line in lines: + item = json.loads(line) + pred_text_list.append(item["pred_text"]) + text_list.append(item["text"]) + outf.write(json.dumps(item) + "\n") + samples_num += 1 + wer_cer = word_error_rate(hypotheses=pred_text_list, references=text_list, use_cer=cfg.use_cer) + logging.info( + f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}." 
+ ) + logging.info("{} for all predictions is {:.4f}.".format("CER" if cfg.use_cer else "WER", wer_cer)) + + +if __name__ == '__main__': + main() diff --git a/sdp/utils/common.py b/sdp/utils/common.py index 6d9c4fba..0483809b 100644 --- a/sdp/utils/common.py +++ b/sdp/utils/common.py @@ -36,6 +36,26 @@ def load_manifest(manifest: Path) -> List[Dict[str, Union[str, float]]]: return result +def ffmpeg_convert(input_file: str, output_wav: str, sample_rate: int = 0, num_channels: int = 1): + process_args = [ + "ffmpeg", + "-i", + input_file, + '-ac', + str(num_channels), + "-map", + "0:a", + "-c:a", + "pcm_s16le", + "-y", + output_wav, + ] + if sample_rate: + process_args = process_args[:-1] + process_args.extend(["-ar", str(sample_rate), output_wav]) + return subprocess.run(process_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + def download_file(source_url: str, target_directory: str, verbose=True): # make sure target_directory is an absolute path to avoid bugs when we change directories to download data later target_directory = os.path.abspath(target_directory)
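
Finally, a usage sketch for the new ffmpeg_convert helper, assuming an ffmpeg binary is available on PATH; the input and output paths below are placeholders. With the default sample_rate=0 the source sample rate is kept, otherwise the -ar flag is appended before the output file.

    # Hypothetical usage of the new helper (paths are placeholders); converts any
    # ffmpeg-readable input to 16 kHz mono 16-bit PCM WAV and checks the return code.
    from sdp.utils.common import ffmpeg_convert

    completed = ffmpeg_convert(
        input_file="/path/to/raw/audio.opus",   # placeholder
        output_wav="/path/to/wav/audio.wav",    # placeholder
        sample_rate=16000,
        num_channels=1,
    )
    if completed.returncode != 0:
        raise RuntimeError("ffmpeg conversion failed")
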