Commit 412bdc3

naymaraq authored
Add support for streaming speech translation (#15132)
* draft version of cascaded streaming speech translation feature
* Extract the first line of LLM response
* add device, device_id, batch_size for llm translator
* improved translation performance by passing extra context
* better prompt for translation
* fix docstrings
* Apply isort and black reformatting
* fix typos
* Apply isort and black reformatting
* more explanation for wait-k parameter
* fix return type
* minor changes, rm device arg
* correct explanation of waitk param
* add Dockerfile for streaming speech translation with vLLM
* disable tqdm for llm generation
* add evaluation metrics
* added LAAL calculation for LLM-based streaming ST
* bugfix in delay calculation
* Apply isort and black reformatting
* cleanup

---------

Signed-off-by: naymaraq <[email protected]>
Signed-off-by: naymaraq <[email protected]>
Co-authored-by: naymaraq <[email protected]>
Co-authored-by: naymaraq <[email protected]>
1 parent 676a368 commit 412bdc3

26 files changed, +1127 -65 lines changed

examples/asr/asr_streaming_inference/README.md

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ Beyond streaming ASR, the script also supports:
 
 * **Inverse Text Normalization (ITN)**
 * **End-of-Utterance (EoU) Detection**
+* **Streaming Speech Translation (requires vLLM installation)**
 * **Word-level and Segment-level Output**
 
 All related configurations can be found in the `../conf/asr_streaming_inference/` directory.

examples/asr/asr_streaming_inference/asr_streaming_infer.py

Lines changed: 17 additions & 6 deletions
@@ -30,7 +30,8 @@
 output_filename=<path to output jsonfile> \
 lang=en \
 enable_pnc=False \
-enable_itn=True \
+enable_itn=False \
+enable_nmt=False \
 asr_output_granularity=segment \
 ...
 # See ../conf/asr_streaming_inference/*.yaml for all available options
@@ -45,9 +46,9 @@
 
 import hydra
 
-
 from nemo.collections.asr.inference.factory.pipeline_builder import PipelineBuilder
 from nemo.collections.asr.inference.utils.manifest_io import calculate_duration, dump_output, get_audio_filepaths
+from nemo.collections.asr.inference.utils.pipeline_eval import calculate_pipeline_laal, evaluate_pipeline
 from nemo.collections.asr.inference.utils.progressbar import TQDMProgressBar
 from nemo.utils import logging
 
@@ -69,8 +70,11 @@ def main(cfg):
     logging.setLevel(cfg.log_level)
 
     # Reading audio filepaths
-    audio_filepaths = get_audio_filepaths(cfg.audio_file, sort_by_duration=True)
+    audio_filepaths, manifest = get_audio_filepaths(cfg.audio_file, sort_by_duration=True)
     logging.info(f"Found {len(audio_filepaths)} audio files")
+    if manifest:
+        keys = list(manifest[0].keys())
+        logging.info(f"Found {len(keys)} keys in the input manifest: {keys}")
 
     # Build the pipeline
     pipeline = PipelineBuilder.build_pipeline(cfg)
@@ -82,13 +86,20 @@ def main(cfg):
     exec_dur = time() - start
 
     # Calculate RTFX
-    data_dur = calculate_duration(audio_filepaths)
+    data_dur, durations = calculate_duration(audio_filepaths)
     rtfx = data_dur / exec_dur if exec_dur > 0 else float('inf')
     logging.info(f"RTFX: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)")
 
+    # Calculate LAAL
+    laal = calculate_pipeline_laal(output, durations, manifest, cfg)
+    if laal is not None:
+        logging.info(f"LAAL: {laal:.2f}ms")
+
     # Dump the transcriptions to a output file
-    dump_output(output, cfg.output_filename, cfg.output_dir)
-    logging.info(f"Transcriptions written to {cfg.output_filename}")
+    dump_output(output, cfg.output_filename, cfg.output_dir, manifest)
+
+    # Evaluate the pipeline
+    evaluate_pipeline(cfg.output_filename, cfg)
     logging.info("Done!")
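The script now logs a LAAL value in milliseconds, but the body of `calculate_pipeline_laal` is not part of the hunks shown here. As a rough illustration only, the sketch below follows the usual definition of Length-Adaptive Average Lagging: average the per-word lag up to the first word emitted after the whole source was consumed, with the ideal emission step based on the longer of hypothesis and reference. The function name `laal_ms` and its arguments are hypothetical and are not taken from the commit.

```python
from typing import List, Optional


def laal_ms(delays_ms: List[float], source_duration_ms: float, reference_len: int) -> Optional[float]:
    """Hypothetical helper: Length-Adaptive Average Lagging in milliseconds.

    delays_ms[i] is the amount of source audio (ms) that had been consumed
    when the i-th target word was emitted.
    """
    if not delays_ms or source_duration_ms <= 0:
        return None

    hypothesis_len = len(delays_ms)
    # An ideal system emits words evenly over the source. Using the longer of
    # hypothesis and reference avoids rewarding over-generation.
    ideal_step = source_duration_ms / max(hypothesis_len, reference_len)

    total_lag = 0.0
    counted = 0
    for i, delay in enumerate(delays_ms):
        total_lag += delay - i * ideal_step
        counted = i + 1
        # Stop at the first word emitted after the full source was read.
        if delay >= source_duration_ms:
            break
    return total_lag / counted


if __name__ == "__main__":
    print(laal_ms([1000.0, 2000.0, 3000.0], 3000.0, reference_len=3))
```

Returning `None` for empty input mirrors the `if laal is not None` guard in the script, although the conditions under which the real function returns `None` are an assumption.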
examples/asr/conf/asr_streaming_inference/buffered_ctc.yaml

Lines changed: 44 additions & 1 deletion
@@ -22,6 +22,27 @@ itn:
   n_jobs: 16 # Number of parallel jobs for ITN processing
 
 
+# ================================
+# Neural Machine Translation Configuration
+# ================================
+nmt:
+  model_name: "utter-project/EuroLLM-1.7B-Instruct" # vLLM-supported model name
+  source_language: "English" # Source language code
+  target_language: "Russian" # Target language code
+  waitk: -1 # Max allowed lag (in words) between ASR transcript and translation; -1 disables it and uses only the longest common prefix between current and previous translations.
+  device: cuda # Device for translation: 'cuda'. 'cpu' is not supported.
+  device_id: 1 # GPU device ID for translation
+  batch_size: 16 # Batch size for translation, if -1, the batch size is equal to the ASR batch size
+  llm_params: # See https://docs.vllm.ai/en/v0.8.1/api/offline_inference/llm.html for more details
+    dtype: "auto" # Compute precision
+    seed: 42 # The seed to initialize the random number generator for sampling
+  sampling_params: # See https://docs.vllm.ai/en/v0.6.4/dev/sampling_params.html for more details
+    max_tokens: 100 # Maximum number of tokens to generate with LLM
+    temperature: 0.0 # LLM sampling temperature, default for translation is 0 (greedy)
+    top_p: 0.9 # The cumulative probability threshold for nucleus sampling
+    seed: 42 # The seed to initialize the random number generator for sampling
+
+
 # ========================
 # Confidence estimation
 # ========================
@@ -67,14 +88,36 @@ asr_decoding_type: ctc # Decoding method: ctc or rnnt
 
 
 # ========================
-# Runtime arguments defined at runtime via command line
+# Runtime arguments defined at runtime via command line
 # ========================
 audio_file: null # Path to audio file, directory, or manifest JSON
 output_filename: null # Path to output transcription JSON file
 output_dir: null # Directory to save time-aligned output
 enable_pnc: false # Whether to apply punctuation & capitalization
 enable_itn: false # Whether to apply inverse text normalization
+enable_nmt: false # Whether to apply neural machine translation
 asr_output_granularity: segment # Output granularity: word or segment
 cache_dir: null # Directory to store cache (e.g., .far files)
 lang: null # Language code for ASR model
 return_tail_result: false # Whether to return the tail labels left in the right padded side of the buffer
+calculate_wer: true # Whether to calculate WER
+calculate_bleu: true # Whether to calculate BLEU score
+
+
+# ========================
+# Metrics
+# ========================
+metrics:
+  asr:
+    gt_text_attr_name: text # Attribute name for ground truth text
+    clean_groundtruth_text: false # Whether to clean ground truth text
+    langid: en # Language code for text normalization; only "en" is supported
+    use_cer: false # Whether to use character error rate
+    ignore_capitalization: true # Whether to ignore capitalization
+    ignore_punctuation: true # Whether to ignore punctuation
+    strip_punc_space: false # Whether to strip punctuation and space
+  nmt:
+    gt_text_attr_name: answer # Attribute name for ground truth text
+    ignore_capitalization: false # Whether to ignore capitalization
+    ignore_punctuation: false # Whether to ignore punctuation
+    strip_punc_space: false # Whether to strip punctuation and space
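The `waitk` comment says that with `-1` only the longest common prefix between the current and previous translations is used. How the pipeline applies this internally is not visible in this diff, so the following is only a minimal sketch of the prefix idea at word level. The function name `longest_common_prefix` is hypothetical, and the wait-k lag limit itself is deliberately left out.

```python
from typing import List


def longest_common_prefix(previous: List[str], current: List[str]) -> List[str]:
    """Hypothetical helper: words on which two successive translations agree.

    Words in the shared prefix can be treated as stable; everything after the
    first disagreement may still change once more audio has been transcribed.
    """
    stable: List[str] = []
    for prev_word, curr_word in zip(previous, current):
        if prev_word != curr_word:
            break
        stable.append(prev_word)
    return stable


if __name__ == "__main__":
    previous = "the cat sat on".split()
    current = "the cat sits on the mat".split()
    print(longest_common_prefix(previous, current))
```

In this example only `the cat` would be considered stable, because the third word differs between the two partial translations.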

examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml

Lines changed: 43 additions & 1 deletion
@@ -39,6 +39,27 @@ itn:
   n_jobs: 16 # Number of parallel jobs for ITN processing
 
 
+# ================================
+# Neural Machine Translation Configuration
+# ================================
+nmt:
+  model_name: "utter-project/EuroLLM-1.7B-Instruct" # vLLM-supported model name
+  source_language: "English" # Source language code
+  target_language: "Russian" # Target language code
+  waitk: -1 # Max allowed lag (in words) between ASR transcript and translation; -1 disables it and uses only the longest common prefix between current and previous translations.
+  device: cuda # Device for translation: 'cuda'. 'cpu' is not supported.
+  device_id: 1 # GPU device ID for translation
+  batch_size: 16 # Batch size for translation, if -1, the batch size is equal to the ASR batch size
+  llm_params: # See https://docs.vllm.ai/en/v0.8.1/api/offline_inference/llm.html for more details
+    dtype: "auto" # Compute precision
+    seed: 42 # The seed to initialize the random number generator for sampling
+  sampling_params: # See https://docs.vllm.ai/en/v0.6.4/dev/sampling_params.html for more details
+    max_tokens: 100 # Maximum number of tokens to generate with LLM
+    temperature: 0.0 # LLM sampling temperature, default for translation is 0 (greedy)
+    top_p: 0.9 # The cumulative probability threshold for nucleus sampling
+    seed: 42 # The seed to initialize the random number generator for sampling
+
+
 # ========================
 # Confidence estimation
 # ========================
@@ -85,14 +106,35 @@ asr_decoding_type: rnnt # Decoding method: ctc or rnnt
 
 
 # ========================
-# Runtime arguments defined at runtime via command line
+# Runtime arguments defined at runtime via command line
 # ========================
 audio_file: null # Path to audio file, directory, or manifest JSON
 output_filename: null # Path to output transcription JSON file
 output_dir: null # Directory to save time-aligned output
 enable_pnc: false # Whether to apply punctuation & capitalization
 enable_itn: false # Whether to apply inverse text normalization
+enable_nmt: false # Whether to apply neural machine translation
 asr_output_granularity: segment # Output granularity: word or segment
 cache_dir: null # Directory to store cache (e.g., .far files)
 lang: null # Language code for ASR model
 return_tail_result: false # Whether to return the tail labels left in the right padded side of the buffer
+calculate_wer: true # Whether to calculate WER
+calculate_bleu: true # Whether to calculate BLEU score
+
+# ========================
+# Metrics
+# ========================
+metrics:
+  asr:
+    gt_text_attr_name: text # Attribute name for ground truth text
+    clean_groundtruth_text: false # Whether to clean ground truth text
+    langid: en # Language code for text normalization; only "en" is supported
+    use_cer: false # Whether to use character error rate
+    ignore_capitalization: true # Whether to ignore capitalization
+    ignore_punctuation: true # Whether to ignore punctuation
+    strip_punc_space: false # Whether to strip punctuation and space
+  nmt:
+    gt_text_attr_name: answer # Attribute name for ground truth text
+    ignore_capitalization: false # Whether to ignore capitalization
+    ignore_punctuation: false # Whether to ignore punctuation
+    strip_punc_space: false # Whether to strip punctuation and space
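The `metrics.asr` block toggles capitalization and punctuation handling before WER is computed. The evaluation code itself is not shown in this diff, so the sketch below is only an illustration of how such flags could feed a plain word-level edit distance. The names `normalize` and `word_error_rate` are hypothetical, and `clean_groundtruth_text`, `use_cer`, and `strip_punc_space` are not modelled.

```python
import string
from typing import List


def normalize(text: str, ignore_capitalization: bool = True, ignore_punctuation: bool = True) -> List[str]:
    """Hypothetical normalizer driven by the two boolean flags from the config."""
    if ignore_capitalization:
        text = text.lower()
    if ignore_punctuation:
        # Only ASCII punctuation is removed in this sketch.
        text = text.translate(str.maketrans("", "", string.punctuation))
    return text.split()


def word_error_rate(reference: str, hypothesis: str, **flags: bool) -> float:
    """Word-level Levenshtein distance divided by the reference length."""
    ref = normalize(reference, **flags)
    hyp = normalize(hypothesis, **flags)

    # Single-row dynamic programming over the edit-distance table.
    previous = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        current = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = previous[j - 1] + (0 if ref_word == hyp_word else 1)
            current.append(min(previous[j] + 1, current[j - 1] + 1, substitution))
        previous = current
    return previous[-1] / max(len(ref), 1)


if __name__ == "__main__":
    print(word_error_rate("Hello, world!", "hello world"))
    print(word_error_rate("Hello, world!", "hello world",
                          ignore_capitalization=False, ignore_punctuation=False))
```

With the ASR defaults from the config (both flags `true`) the two strings match exactly, while the NMT defaults (both `false`) would count both words as errors.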

examples/asr/conf/asr_streaming_inference/cache_aware_ctc.yaml

Lines changed: 43 additions & 0 deletions
@@ -22,6 +22,27 @@ itn:
   n_jobs: 16 # Number of parallel jobs for ITN processing
 
 
+# ================================
+# Neural Machine Translation Configuration
+# ================================
+nmt:
+  model_name: "utter-project/EuroLLM-1.7B-Instruct" # vLLM-supported model name
+  source_language: "English" # Source language code
+  target_language: "Russian" # Target language code
+  waitk: -1 # Max allowed lag (in words) between ASR transcript and translation; -1 disables it and uses only the longest common prefix between current and previous translations.
+  device: cuda # Device for translation: 'cuda'. 'cpu' is not supported.
+  device_id: 1 # GPU device ID for translation
+  batch_size: 16 # Batch size for translation, if -1, the batch size is equal to the ASR batch size
+  llm_params: # See https://docs.vllm.ai/en/v0.8.1/api/offline_inference/llm.html for more details
+    dtype: "auto" # Compute precision
+    seed: 42 # The seed to initialize the random number generator for sampling
+  sampling_params: # See https://docs.vllm.ai/en/v0.6.4/dev/sampling_params.html for more details
+    max_tokens: 100 # Maximum number of tokens to generate with LLM
+    temperature: 0.0 # LLM sampling temperature, default for translation is 0 (greedy)
+    top_p: 0.9 # The cumulative probability threshold for nucleus sampling
+    seed: 42 # The seed to initialize the random number generator for sampling
+
+
 # ========================
 # Confidence estimation
 # ========================
@@ -74,7 +95,29 @@ output_filename: null # Path to output transcription JSO
 output_dir: null # Directory to save time-aligned output
 enable_pnc: false # Whether to apply punctuation & capitalization
 enable_itn: false # Whether to apply inverse text normalization
+enable_nmt: false # Whether to apply neural machine translation
 asr_output_granularity: segment # Output granularity: word or segment
 cache_dir: null # Directory to store cache (e.g., .far files)
 lang: null # Language code for ASR model
 return_tail_result: false # Whether to return the tail labels left in the right padded side of the buffer
+calculate_wer: true # Whether to calculate WER
+calculate_bleu: true # Whether to calculate BLEU score
+
+
+# ========================
+# Metrics
+# ========================
+metrics:
+  asr:
+    gt_text_attr_name: text # Attribute name for ground truth text
+    clean_groundtruth_text: false # Whether to clean ground truth text
+    langid: en # Language code for text normalization; only "en" is supported
+    use_cer: false # Whether to use character error rate
+    ignore_capitalization: true # Whether to ignore capitalization
+    ignore_punctuation: true # Whether to ignore punctuation
+    strip_punc_space: false # Whether to strip punctuation and space
+  nmt:
+    gt_text_attr_name: answer # Attribute name for ground truth text
+    ignore_capitalization: false # Whether to ignore capitalization
+    ignore_punctuation: false # Whether to ignore punctuation
+    strip_punc_space: false # Whether to strip punctuation and space
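The `nmt` block passes `llm_params` and `sampling_params` through to vLLM, and the commit message mentions extracting the first line of the LLM response and disabling tqdm during generation. The translator class added by the commit is not shown here, so the sketch below is only an assumption of how these settings could map onto vLLM's offline `LLM` and `SamplingParams` classes. The prompt wording and the helper names `build_prompt`, `first_line`, and `translate` are hypothetical, and `device_id` and `batch_size` handling is omitted.

```python
from typing import Dict, List

# Values copied from the nmt section of the YAML config.
NMT_CFG: Dict = {
    "model_name": "utter-project/EuroLLM-1.7B-Instruct",
    "source_language": "English",
    "target_language": "Russian",
    "llm_params": {"dtype": "auto", "seed": 42},
    "sampling_params": {"max_tokens": 100, "temperature": 0.0, "top_p": 0.9, "seed": 42},
}


def build_prompt(text: str, cfg: Dict) -> str:
    """Hypothetical prompt; the prompt used by the commit is not shown in this diff."""
    src, tgt = cfg["source_language"], cfg["target_language"]
    return f"Translate the following {src} text into {tgt}.\n{src}: {text}\n{tgt}:"


def first_line(response: str) -> str:
    """Keep only the first non-empty line of a model response."""
    stripped = response.strip()
    return stripped.splitlines()[0] if stripped else ""


def translate(texts: List[str], cfg: Dict = NMT_CFG) -> List[str]:
    """Requires a vLLM installation and a CUDA device; imported lazily for that reason."""
    from vllm import LLM, SamplingParams

    llm = LLM(model=cfg["model_name"], **cfg["llm_params"])
    sampling = SamplingParams(**cfg["sampling_params"])
    prompts = [build_prompt(text, cfg) for text in texts]
    # use_tqdm=False suppresses the per-call progress bar.
    outputs = llm.generate(prompts, sampling, use_tqdm=False)
    return [first_line(out.outputs[0].text) for out in outputs]
```

With `temperature: 0.0` decoding is greedy, so `top_p` and the sampling `seed` should have little practical effect on the output.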

examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml

Lines changed: 44 additions & 1 deletion
@@ -38,6 +38,27 @@ itn:
   n_jobs: 16 # Number of parallel jobs for ITN processing
 
 
+# ================================
+# Neural Machine Translation Configuration
+# ================================
+nmt:
+  model_name: "utter-project/EuroLLM-1.7B-Instruct" # vLLM-supported model name
+  source_language: "English" # Source language code
+  target_language: "Russian" # Target language code
+  waitk: -1 # Max allowed lag (in words) between ASR transcript and translation; -1 disables it and uses only the longest common prefix between current and previous translations.
+  device: cuda # Device for translation: 'cuda'. 'cpu' is not supported.
+  device_id: 1 # GPU device ID for translation
+  batch_size: 16 # Batch size for translation, if -1, the batch size is equal to the ASR batch size
+  llm_params: # See https://docs.vllm.ai/en/v0.8.1/api/offline_inference/llm.html for more details
+    dtype: "auto" # Compute precision
+    seed: 42 # The seed to initialize the random number generator for sampling
+  sampling_params: # See https://docs.vllm.ai/en/v0.6.4/dev/sampling_params.html for more details
+    max_tokens: 100 # Maximum number of tokens to generate with LLM
+    temperature: 0.0 # LLM sampling temperature, default for translation is 0 (greedy)
+    top_p: 0.9 # The cumulative probability threshold for nucleus sampling
+    seed: 42 # The seed to initialize the random number generator for sampling
+
+
 # ========================
 # Confidence estimation
 # ========================
@@ -84,14 +105,36 @@ asr_decoding_type: rnnt # Decoding method: ctc or rnnt
 
 
 # ========================
-# Runtime arguments defined at runtime via command line
+# Runtime arguments defined at runtime via command line
 # ========================
 audio_file: null # Path to audio file, directory, or manifest JSON
 output_filename: null # Path to output transcription JSON file
 output_dir: null # Directory to save time-aligned output
 enable_pnc: false # Whether to apply punctuation & capitalization
 enable_itn: false # Whether to apply inverse text normalization
+enable_nmt: false # Whether to apply neural machine translation
 asr_output_granularity: segment # Output granularity: word or segment
 cache_dir: null # Directory to store cache (e.g., .far files)
 lang: null # Language code for ASR model
 return_tail_result: false # Whether to return the tail labels left in the right padded side of the buffer
+calculate_wer: true # Whether to calculate WER
+calculate_bleu: true # Whether to calculate BLEU score
+
+
+# ========================
+# Metrics
+# ========================
+metrics:
+  asr:
+    gt_text_attr_name: text # Attribute name for ground truth text
+    clean_groundtruth_text: false # Whether to clean ground truth text
+    langid: en # Language code for text normalization; only "en" is supported
+    use_cer: false # Whether to use character error rate
+    ignore_capitalization: true # Whether to ignore capitalization
+    ignore_punctuation: true # Whether to ignore punctuation
+    strip_punc_space: false # Whether to strip punctuation and space
+  nmt:
+    gt_text_attr_name: answer # Attribute name for ground truth text
+    ignore_capitalization: false # Whether to ignore capitalization
+    ignore_punctuation: false # Whether to ignore punctuation
+    strip_punc_space: false # Whether to strip punctuation and space
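The two `gt_text_attr_name` settings name the manifest attributes that hold the ASR reference (`text`) and the translation reference (`answer`). The manifest reader in this commit is not shown, so the sketch below simply assumes a JSON-lines manifest with one object per line, which is an assumption rather than something the diff confirms. The `audio_filepath` key and the function name `read_references` are likewise hypothetical.

```python
import json
from typing import Iterable, List, Optional, Tuple


def read_references(
    manifest_lines: Iterable[str],
    asr_attr: str = "text",
    nmt_attr: str = "answer",
) -> List[Tuple[Optional[str], Optional[str]]]:
    """Hypothetical helper: collect (ASR reference, translation reference) pairs.

    A missing attribute yields None, so a caller could skip WER or BLEU for
    entries that lack the corresponding ground truth.
    """
    references = []
    for line in manifest_lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines between entries
        entry = json.loads(line)
        references.append((entry.get(asr_attr), entry.get(nmt_attr)))
    return references


if __name__ == "__main__":
    sample = [
        '{"audio_filepath": "a.wav", "text": "hello", "answer": "привет"}',
        "",
        '{"audio_filepath": "b.wav", "text": "bye"}',
    ]
    print(read_references(sample))
```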
