Commit b944578

Merge remote-tracking branch 'nvidia/main' into tts_2512_removetorchaudio

2 parents: 0d7c14c + 527b8c4

64 files changed: +3457 −439 lines

.github/workflows/cicd-main-nemo2.yml

Lines changed: 2 additions & 2 deletions
@@ -201,8 +201,8 @@ jobs:
             runner: self-hosted-azure-gpus-1
           - script: L2_NeMo_2_Auto_Configurator_bert_TP1_PP1_MBS124
             runner: self-hosted-azure-gpus-1
-          - script: L2_NeMo_2_Auto_Configurator_t5_TP1_PP1_MBS124
-            runner: self-hosted-azure-gpus-1
+          # - script: L2_NeMo_2_Auto_Configurator_t5_TP1_PP1_MBS124 #skipping t5 hanging tests
+          #   runner: self-hosted-azure-gpus-1
           - script: L2_NeMo_2_Auto_Configurator_callbacks
             runner: self-hosted-azure-gpus-1
           - script: L2_NeMo_2_Conversion_Test_Baichuan2

.github/workflows/cicd-main-speech.yml

Lines changed: 4 additions & 0 deletions
@@ -131,6 +131,10 @@ jobs:
             script: L2_Speech_Transcription_Speech_to_Text_Cache_Aware_Infer
           - runner: self-hosted-azure
             script: L2_Speech_Transcription_Streaming_Inference
+          - runner: self-hosted-azure
+            script: L2_Speech_Transcription_Speech_to_Text_Inference_Boost_GT
+          - runner: self-hosted-azure
+            script: L2_Speech_Transcription_Speech_to_Text_Transcribe_Boost_GT
           - runner: self-hosted-azure
             script: L2_Speech_Transcription_Canary_Transcribe_Full_Manifest
           - runner: self-hosted-azure

examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py

Lines changed: 79 additions & 2 deletions
@@ -65,13 +65,15 @@
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = alloc_conf


+import librosa
 import lightning.pytorch as pl
 import torch
 from omegaconf import OmegaConf, open_dict
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
+from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig
 from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
 from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
     GreedyBatchedLabelLoopingComputerBase,

@@ -95,6 +97,7 @@
 )
 from nemo.core.config import hydra_runner
 from nemo.utils import logging
+from nemo.utils.timers import SimpleTimer


 def make_divisible_by(num, factor: int) -> int:

@@ -113,6 +116,7 @@ class TranscriptionConfig:
     pretrained_name: Optional[str] = None  # Name of a pretrained model
     audio_dir: Optional[str] = None  # Path to a directory which contains audio files
     dataset_manifest: Optional[str] = None  # Path to dataset's JSON manifest
+    sort_by_duration: bool = True  # sort manifest/audio files by duration (descending)

     # General configs
     output_filename: Optional[str] = None

@@ -145,6 +149,8 @@ class TranscriptionConfig:

     # Decoding strategy for RNNT models
     decoding: RNNTDecodingConfig = field(default_factory=RNNTDecodingConfig)
+    # Per-utterance biasing with biasing config in the manifest
+    use_per_stream_biasing: bool = False

     timestamps: bool = False  # output timestamps

@@ -154,6 +160,8 @@ class TranscriptionConfig:
     langid: str = "en"  # specify this for convert_num_to_words step in groundtruth cleaning
     use_cer: bool = False

+    calculate_rtfx: bool = False
+

 @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
 def main(cfg: TranscriptionConfig) -> TranscriptionConfig:

@@ -216,6 +224,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
     asr_model = asr_model.to(asr_model.device)
     asr_model.to(compute_dtype)

+    use_per_stream_biasing = cfg.use_per_stream_biasing
+
     # Change Decoding Config
     with open_dict(cfg.decoding):
         if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True:

@@ -226,6 +236,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
         cfg.decoding.greedy.preserve_alignments = False
         cfg.decoding.fused_batch_size = -1  # temporarily stop fused batch during inference.
         cfg.decoding.beam.return_best_hypothesis = True  # return and write the best hypothesis only
+        if use_per_stream_biasing:
+            cfg.decoding.greedy.enable_per_stream_biasing = use_per_stream_biasing

     # Setup decoding strategy
     if hasattr(asr_model, 'change_decoding_strategy'):

@@ -250,6 +262,14 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
         assert filepaths is not None
         records = [{"audio_filepath": audio_file} for audio_file in filepaths]

+    if cfg.sort_by_duration:
+        filepath2order = dict()
+        for i, record in enumerate(records):
+            if "duration" not in record:
+                record["duration"] = librosa.get_duration(path=record["audio_filepath"])
+            filepath2order[record["audio_filepath"]] = i
+        records.sort(key=lambda record: record["duration"], reverse=True)
+
     asr_model.preprocessor.featurizer.dither = 0.0
     asr_model.preprocessor.featurizer.pad_to = 0
     asr_model.eval()
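
Note: this hunk introduces the first half of a sort-then-restore pattern (the second half appears near the end of this file). A minimal self-contained sketch of the idea, with made-up file names and durations rather than NeMo APIs:

# Sketch of the sort-by-duration / restore-order pattern used above.
# Records are sorted longest-first so batches group similar lengths (less padding),
# and the original order is recovered afterwards via a filepath -> index map.
records = [
    {"audio_filepath": "a.wav", "duration": 1.2},  # hypothetical entries
    {"audio_filepath": "b.wav", "duration": 7.5},
    {"audio_filepath": "c.wav", "duration": 3.1},
]
filepath2order = {r["audio_filepath"]: i for i, r in enumerate(records)}
records.sort(key=lambda r: r["duration"], reverse=True)  # decode in this order
hyps = [r["audio_filepath"].upper() for r in records]    # stand-in for decoding output
restored = sorted(zip(records, hyps), key=lambda rh: filepath2order[rh[0]["audio_filepath"]])
records, hyps = map(list, zip(*restored))
assert [r["audio_filepath"] for r in records] == ["a.wav", "b.wav", "c.wav"]
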
@@ -289,8 +309,27 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
     latency_secs = (context_samples.chunk + context_samples.right) / audio_sample_rate
     logging.info(f"Theoretical latency: {latency_secs:.2f} seconds")

+    biasing_requests: list[BiasingRequestItemConfig | None] | None
+    if use_per_stream_biasing:
+        biasing_requests = [
+            (
+                BiasingRequestItemConfig(
+                    **OmegaConf.to_container(
+                        OmegaConf.merge(OmegaConf.structured(BiasingRequestItemConfig), record["biasing_request"])
+                    )
+                )
+                if "biasing_request" in record
+                else None
+            )
+            for record in records
+        ]
+    else:
+        biasing_requests = None
+
     audio_dataset = SimpleAudioDataset(
-        audio_filenames=[record["audio_filepath"] for record in records], sample_rate=audio_sample_rate
+        audio_filenames=[record["audio_filepath"] for record in records],
+        sample_rate=audio_sample_rate,
+        biasing_requests=biasing_requests,
     )
     audio_dataloader = DataLoader(
         dataset=audio_dataset,
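
Note: per-stream biasing is driven by an optional `biasing_request` field in each manifest record, merged onto the structured `BiasingRequestItemConfig` schema as above. A sketch of that merge for a single record; the `key_phrases_list` field name is a hypothetical placeholder, since the dataclass fields are not shown in this diff:

# Sketch of how one manifest record's "biasing_request" becomes a config object,
# mirroring the list comprehension above. "key_phrases_list" is a guessed field.
from omegaconf import OmegaConf

from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig

record = {
    "audio_filepath": "utt1.wav",
    "biasing_request": {"key_phrases_list": ["NVIDIA", "NeMo"]},  # hypothetical fields
}
merged = OmegaConf.merge(
    OmegaConf.structured(BiasingRequestItemConfig),  # schema with defaults
    record["biasing_request"],                       # per-utterance overrides
)
request = BiasingRequestItemConfig(**OmegaConf.to_container(merged))
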
@@ -302,9 +341,11 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
         in_order=True,
     )

+    timer = SimpleTimer()
     with torch.no_grad(), torch.inference_mode():
         all_hyps = []
         audio_data: AudioBatch
+        timer.start(device=map_location)
         for audio_data in tqdm(audio_dataloader):
             # get audio
             # NB: preprocessor runs on torch.float32, no need to cast dtype here

@@ -313,8 +354,21 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
             batch_size = audio_batch.shape[0]
             device = audio_batch.device

-            # decode audio by chunks
+            # add biasing requests to the decoder
+            if use_per_stream_biasing:
+                multi_biasing_ids = torch.full([batch_size], fill_value=-1, dtype=torch.long, device=map_location)
+                if audio_data.biasing_requests is not None:
+                    for batch_i, request in enumerate(audio_data.biasing_requests):
+                        if request is not None:
+                            biasing_model = request.get_model(tokenizer=asr_model.tokenizer)
+                            if biasing_model is not None:
+                                multi_model_id = decoding_computer.biasing_multi_model.add_model(biasing_model)
+                                request.multi_model_id = multi_model_id
+                                multi_biasing_ids[batch_i] = multi_model_id
+            else:
+                multi_biasing_ids = None

+            # decode audio by chunks
             current_batched_hyps: BatchedHyps | None = None
             state = None
             left_sample = 0

@@ -368,6 +422,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
                         encoder_context_batch.chunk,
                     ),
                     prev_batched_state=state,
+                    multi_biasing_ids=multi_biasing_ids,
                 )
                 # merge hyps with previous hyps
                 if current_batched_hyps is None:

@@ -380,7 +435,14 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
                 left_sample = right_sample
                 right_sample = min(right_sample + context_samples.chunk, audio_batch.shape[1])  # add next chunk

+            # remove biasing requests from the decoder
+            if use_per_stream_biasing and audio_data.biasing_requests is not None:
+                for request in audio_data.biasing_requests:
+                    if request is not None and request.multi_model_id is not None:
+                        decoding_computer.biasing_multi_model.remove_model(request.multi_model_id)
+                        request.multi_model_id = None
             all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size))
+        timer.stop(device=map_location)

     # convert text
     for i, hyp in enumerate(all_hyps):

@@ -394,11 +456,26 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
         )
         all_hyps[i] = hyp

+    if cfg.sort_by_duration:
+        # restore order for all_hyps and records (all_hyps are consistent with records)
+        order_restored = sorted(
+            zip(records, all_hyps), key=lambda records_hyps: filepath2order[records_hyps[0]["audio_filepath"]]
+        )
+        records, all_hyps = map(list, zip(*order_restored))
+
     output_filename, pred_text_attr_name = write_transcription(
         all_hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, timestamps=cfg.timestamps
     )
     logging.info(f"Finished writing predictions to {output_filename}!")

+    if cfg.calculate_rtfx:
+        durations = [
+            record["duration"] if "duration" in record else librosa.get_duration(path=record["audio_filepath"])
+            for record in records
+        ]
+        rtfx = sum(durations) / timer.total_sec()
+        logging.info(f"RTFx: {rtfx:.2f}")
+
     if cfg.calculate_wer:
         output_manifest_w_wer, total_res, _ = cal_write_wer(
             pred_manifest=output_filename,
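
Note: the new `calculate_rtfx` branch computes the inverse real-time factor, RTFx = total audio duration / wall-clock decoding time, so RTFx = 100 means one hour of audio is decoded in 36 seconds. A standalone sketch of the same arithmetic with made-up numbers:

# RTFx = seconds of audio processed per second of compute (higher is better).
durations = [120.0, 240.0, 3240.0]   # hypothetical per-file durations, 3600 s total
processing_time_sec = 36.0           # hypothetical wall-clock decoding time
rtfx = sum(durations) / processing_time_sec
print(f"RTFx: {rtfx:.2f}")           # -> RTFx: 100.00
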

examples/asr/asr_streaming_inference/asr_streaming_infer.py

Lines changed: 30 additions & 7 deletions
@@ -42,15 +42,17 @@
 """


-from time import time
-
 import hydra
+from omegaconf import OmegaConf

 from nemo.collections.asr.inference.factory.pipeline_builder import PipelineBuilder
+from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions
 from nemo.collections.asr.inference.utils.manifest_io import calculate_duration, dump_output, get_audio_filepaths
 from nemo.collections.asr.inference.utils.pipeline_eval import calculate_pipeline_laal, evaluate_pipeline
 from nemo.collections.asr.inference.utils.progressbar import TQDMProgressBar
+from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig
 from nemo.utils import logging
+from nemo.utils.timers import SimpleTimer

 # disable nemo_text_processing logging
 try:

@@ -80,15 +82,36 @@ def main(cfg):
     pipeline = PipelineBuilder.build_pipeline(cfg)
     progress_bar = TQDMProgressBar()

+    # Add biasing requests
+    if manifest:
+        options = [
+            ASRRequestOptions(
+                biasing_cfg=(
+                    BiasingRequestItemConfig(
+                        **OmegaConf.to_container(
+                            OmegaConf.merge(OmegaConf.structured(BiasingRequestItemConfig), record["biasing_request"])
+                        )
+                    )
+                    if "biasing_request" in record
+                    else None
+                )
+            )
+            for record in manifest
+        ]
+    else:
+        options = None
+
     # Run the pipeline
-    start = time()
-    output = pipeline.run(audio_filepaths, progress_bar=progress_bar)
-    exec_dur = time() - start
+    timer = SimpleTimer()
+    timer.start(pipeline.device)
+    output = pipeline.run(audio_filepaths, progress_bar=progress_bar, options=options)
+    timer.stop(pipeline.device)
+    exec_dur = timer.total_sec()

-    # Calculate RTFX
+    # Calculate RTFx
     data_dur, durations = calculate_duration(audio_filepaths)
     rtfx = data_dur / exec_dur if exec_dur > 0 else float('inf')
-    logging.info(f"RTFX: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)")
+    logging.info(f"RTFx: {rtfx:.2f} ({data_dur:.2f}s / {exec_dur:.2f}s)")

     # Calculate LAAL
     laal = calculate_pipeline_laal(output, durations, manifest, cfg)
examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml

Lines changed: 6 additions & 1 deletion
@@ -13,6 +13,7 @@ asr:
   fused_batch_size: -1
   greedy:
     use_cuda_graph_decoder: true
+    enable_per_stream_biasing: true # Per-stream biasing in decoder
     max_symbols: 10
   # n-gram LM
   ngram_lm_model: null # The path to built '.nemo' NGPU-LM model

@@ -22,7 +23,11 @@ asr:
     model_path: null # The path to built '.nemo' boosting tree model
     key_phrases_file: null # The path to the context-biasing list file (one phrase per line)
     key_phrases_list: null # The list of context-biasing phrases ['word1', 'word2', 'word3', ...]
-    source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer)
+    key_phrase_items_list: null # The list of context-biasing phrases with custom fields
+                                # in CLI: [{phrase:"word1",lang:en},{phrase:"frase dos",lang:es}]
+                                # in code: [PhraseItem(phrase="word1", lang="en"), PhraseItem(phrase="frase dos", lang="es")]
+    source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer),
+                      # used with `key_phrases_file` and `key_phrases_list`
     boosting_tree_alpha: 0.0

examples/asr/conf/asr_streaming_inference/cache_aware_ctc.yaml

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ streaming:
   use_cache: true # Whether to use cache for streaming
   use_feat_cache: true # Whether to cache mel-spec features, set false to re-calculate all mel-spec features in audio buffer
   chunk_size_in_secs: null # Amount of audio to load for each streaming step, e.g., 0.08s for FastConformer. Set to `null` for using default size equal to 1+lookahead frames.
-  request_type: frame # Type of request: frame, only frame is supported for cache-aware streaming
+  request_type: frame # Type of request: frame or feature_buffer
   num_slots: 1024 # Number of slots in the context manager: must be >= batch_size

examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml

Lines changed: 10 additions & 5 deletions
@@ -2,7 +2,7 @@
 # ASR Configuration
 # ================================
 asr:
-  model_name: stt_en_fastconformer_hybrid_large_streaming_multi # Pre-trained CTC/hybrid model from NGC/HuggingFace or local .nemo file path
+  model_name: nvidia/nemotron-speech-streaming-en-0.6b # Pre-trained CTC/hybrid model from NGC/HuggingFace or local .nemo file path
   device: cuda # Device for inference: 'cuda' or 'cpu'
   device_id: 0 # GPU device ID
   compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32'

@@ -13,6 +13,7 @@ asr:
   fused_batch_size: -1
   greedy:
     use_cuda_graph_decoder: false # Disabled due to issues with decoding
+    enable_per_stream_biasing: false # Per-stream biasing in decoder
     max_symbols: 10
   # n-gram LM
   ngram_lm_model: null # The path to built '.nemo' NGPU-LM model

@@ -22,7 +23,11 @@ asr:
     model_path: null # The path to built '.nemo' boosting tree model
     key_phrases_file: null # The path to the context-biasing list file (one phrase per line)
     key_phrases_list: null # The list of context-biasing phrases ['word1', 'word2', 'word3', ...]
-    source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer)
+    key_phrase_items_list: null # The list of context-biasing phrases with custom fields
+                                # in CLI: [{phrase:"word1",lang:en},{phrase:"frase dos",lang:es}]
+                                # in code: [PhraseItem(phrase="word1", lang="en"), PhraseItem(phrase="frase dos", lang="es")]
+    source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer),
+                      # used with `key_phrases_file` and `key_phrases_list`
     boosting_tree_alpha: 0.0 # Weight of the boosting tree

@@ -85,14 +90,14 @@ endpointing:
 # ========================
 streaming:
   sample_rate: 16000 # Audio sample rate in Hz
-  batch_size: 256 # Number of audio frames per batch
+  batch_size: 64 # Number of audio frames per batch
   word_boundary_tolerance: 4 # Tolerance for word boundaries
   att_context_size: [70,13] # Attention context size: [70,13],[70,6],[70,1],[70,0]
   use_cache: true # Whether to use cache for streaming
   use_feat_cache: true # Whether to cache mel-spec features, set false to re-calculate all mel-spec features in audio buffer
   chunk_size_in_secs: null # Amount of audio to load for each streaming step, e.g., 0.08s for FastConformer. Set to `null` for using default size equal to 1+lookahead frames.
-  request_type: frame # Type of request: frame, only frame is supported for cache-aware streaming
-  num_slots: 1024 # Number of slots in the context manager: must be >= batch_size
+  request_type: frame # Type of request: frame or feature_buffer
+  num_slots: 256 # Number of slots in the context manager: must be >= batch_size


 # ========================
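
Note: per the comments above, `key_phrase_items_list` can be passed inline on the CLI or built in code from `PhraseItem` objects. A sketch of the in-code form; the import path is a guess, since the diff only references the `PhraseItem` name:

# Building a key_phrase_items_list in code, following the YAML comment.
# The import path below is hypothetical; the diff does not show where PhraseItem lives.
from nemo.collections.asr.parts.context_biasing import PhraseItem  # hypothetical path

key_phrase_items_list = [
    PhraseItem(phrase="word1", lang="en"),
    PhraseItem(phrase="frase dos", lang="es"),
]
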

examples/tts/evalset_config.json

Lines changed: 5 additions & 0 deletions
@@ -3,6 +3,11 @@
         "manifest_path": "/home/TestData/an4_dataset/an4_val_context_v1.json",
         "audio_dir": "/",
         "feature_dir": null
+    },
+    "an4_val_tiny_ci": {
+        "manifest_path": "/home/TestData/an4_dataset/an4_val_context_v1_tiny.json",
+        "audio_dir": "/",
+        "feature_dir": null
     }
 }
examples/tts/magpietts_inference.py

Lines changed: 7 additions & 3 deletions
@@ -504,10 +504,14 @@ def create_argument_parser() -> argparse.ArgumentParser:
     return parser


-def main():
-    """Main entry point."""
+def main(argv=None):
+    """Entry point for MagpieTTS inference and evaluation.
+
+    Args:
+        argv: Command-line arguments. If None, uses sys.argv.
+    """
     parser = create_argument_parser()
-    args = parser.parse_args()
+    args = parser.parse_args(argv)

     dataset_meta_info = load_evalset_config(args.datasets_json_path)
     datasets = filter_datasets(dataset_meta_info, args.datasets)
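
Note: threading `argv` through to `parse_args` makes the entry point callable from tests without patching `sys.argv`. A usage sketch; the flag spellings are assumptions inferred from `args.datasets_json_path` and `args.datasets`, and the parser in `create_argument_parser()` is authoritative:

# Invoke main() programmatically, e.g. from a test. Flag names are guesses
# inferred from the attribute names used above.
main(argv=["--datasets_json_path", "examples/tts/evalset_config.json",
           "--datasets", "an4_val_tiny_ci"])
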
