 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = alloc_conf


+import librosa
 import lightning.pytorch as pl
 import torch
 from omegaconf import OmegaConf, open_dict
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
+from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig
 from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
 from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
     GreedyBatchedLabelLoopingComputerBase,
@@ -95,6 +97,7 @@
 )
 from nemo.core.config import hydra_runner
 from nemo.utils import logging
+from nemo.utils.timers import SimpleTimer


 def make_divisible_by(num, factor: int) -> int:
@@ -113,6 +116,7 @@ class TranscriptionConfig:
     pretrained_name: Optional[str] = None  # Name of a pretrained model
     audio_dir: Optional[str] = None  # Path to a directory which contains audio files
     dataset_manifest: Optional[str] = None  # Path to dataset's JSON manifest
+    sort_by_duration: bool = True  # sort manifest/audio files by duration (descending)

     # General configs
     output_filename: Optional[str] = None
@@ -145,6 +149,8 @@ class TranscriptionConfig:

     # Decoding strategy for RNNT models
     decoding: RNNTDecodingConfig = field(default_factory=RNNTDecodingConfig)
+    # Per-utterance biasing with biasing config in the manifest
+    use_per_stream_biasing: bool = False
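+    # Example manifest record when this flag is set (illustrative only; the exact
+    # "biasing_request" fields are defined by BiasingRequestItemConfig, and the key
+    # names below are assumed, not part of this diff):
+    # {"audio_filepath": "utt1.wav", "biasing_request": {"phrases": ["NVIDIA", "NeMo"]}}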

     timestamps: bool = False  # output timestamps

@@ -154,6 +160,8 @@ class TranscriptionConfig:
     langid: str = "en"  # specify this for convert_num_to_words step in groundtruth cleaning
     use_cer: bool = False

+    calculate_rtfx: bool = False
+

 @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
 def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
@@ -216,6 +224,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
     asr_model = asr_model.to(asr_model.device)
     asr_model.to(compute_dtype)

+    use_per_stream_biasing = cfg.use_per_stream_biasing
+
     # Change Decoding Config
     with open_dict(cfg.decoding):
         if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True:
@@ -226,6 +236,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
         cfg.decoding.greedy.preserve_alignments = False
         cfg.decoding.fused_batch_size = -1  # temporarily stop fused batch during inference.
         cfg.decoding.beam.return_best_hypothesis = True  # return and write the best hypothesis only
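+        # Forward the flag into the greedy decoding config so the label-looping
+        # decoding computer is built with per-stream biasing support.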
+        if use_per_stream_biasing:
+            cfg.decoding.greedy.enable_per_stream_biasing = use_per_stream_biasing

     # Setup decoding strategy
     if hasattr(asr_model, 'change_decoding_strategy'):
@@ -250,6 +262,14 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
         assert filepaths is not None
         records = [{"audio_filepath": audio_file} for audio_file in filepaths]

+    if cfg.sort_by_duration:
+        filepath2order = dict()
+        for i, record in enumerate(records):
+            if "duration" not in record:
+                record["duration"] = librosa.get_duration(path=record["audio_filepath"])
+            filepath2order[record["audio_filepath"]] = i
+        records.sort(key=lambda record: record["duration"], reverse=True)
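+        # Descending-duration order groups utterances of similar length into the same
+        # batch (less padding) and puts the most memory-hungry batch first, so an OOM
+        # surfaces immediately instead of mid-run.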
+
     asr_model.preprocessor.featurizer.dither = 0.0
     asr_model.preprocessor.featurizer.pad_to = 0
     asr_model.eval()
@@ -289,8 +309,27 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
     latency_secs = (context_samples.chunk + context_samples.right) / audio_sample_rate
     logging.info(f"Theoretical latency: {latency_secs:.2f} seconds")

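+    # Build one optional biasing request per record. Merging the manifest-provided
+    # dict into OmegaConf.structured(BiasingRequestItemConfig) validates the keys
+    # against the dataclass schema and fills in defaults before construction.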
+    biasing_requests: list[BiasingRequestItemConfig | None] | None
+    if use_per_stream_biasing:
+        biasing_requests = [
+            (
+                BiasingRequestItemConfig(
+                    **OmegaConf.to_container(
+                        OmegaConf.merge(OmegaConf.structured(BiasingRequestItemConfig), record["biasing_request"])
+                    )
+                )
+                if "biasing_request" in record
+                else None
+            )
+            for record in records
+        ]
+    else:
+        biasing_requests = None
+
     audio_dataset = SimpleAudioDataset(
-        audio_filenames=[record["audio_filepath"] for record in records], sample_rate=audio_sample_rate
+        audio_filenames=[record["audio_filepath"] for record in records],
+        sample_rate=audio_sample_rate,
+        biasing_requests=biasing_requests,
     )
     audio_dataloader = DataLoader(
         dataset=audio_dataset,
@@ -302,9 +341,11 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
         in_order=True,
     )

+    timer = SimpleTimer()
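+    # start/stop take the device so the timer can synchronize outstanding CUDA work
+    # before reading the clock; without a sync, asynchronous kernel launches would
+    # make the measured wall-clock time (and hence RTFx) look better than it is.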
     with torch.no_grad(), torch.inference_mode():
         all_hyps = []
         audio_data: AudioBatch
+        timer.start(device=map_location)
         for audio_data in tqdm(audio_dataloader):
             # get audio
             # NB: preprocessor runs on torch.float32, no need to cast dtype here
@@ -313,8 +354,21 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
             batch_size = audio_batch.shape[0]
             device = audio_batch.device

-            # decode audio by chunks
+            # add biasing requests to the decoder
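+            # Each stream is mapped to an id in the decoder's multi-model biasing
+            # registry; -1 marks streams with no biasing request in this batch.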
+            if use_per_stream_biasing:
+                multi_biasing_ids = torch.full([batch_size], fill_value=-1, dtype=torch.long, device=map_location)
+                if audio_data.biasing_requests is not None:
+                    for batch_i, request in enumerate(audio_data.biasing_requests):
+                        if request is not None:
+                            biasing_model = request.get_model(tokenizer=asr_model.tokenizer)
+                            if biasing_model is not None:
+                                multi_model_id = decoding_computer.biasing_multi_model.add_model(biasing_model)
+                                request.multi_model_id = multi_model_id
+                                multi_biasing_ids[batch_i] = multi_model_id
+            else:
+                multi_biasing_ids = None

+            # decode audio by chunks
             current_batched_hyps: BatchedHyps | None = None
             state = None
             left_sample = 0
@@ -368,6 +422,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
                         encoder_context_batch.chunk,
                     ),
                     prev_batched_state=state,
+                    multi_biasing_ids=multi_biasing_ids,
                 )
                 # merge hyps with previous hyps
                 if current_batched_hyps is None:
@@ -380,7 +435,14 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
                 left_sample = right_sample
                 right_sample = min(right_sample + context_samples.chunk, audio_batch.shape[1])  # add next chunk

+            # remove biasing requests from the decoder
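+            # Registered models are per batch: drop them here so the registry does
+            # not accumulate entries (and memory) across the whole dataset.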
+            if use_per_stream_biasing and audio_data.biasing_requests is not None:
+                for request in audio_data.biasing_requests:
+                    if request is not None and request.multi_model_id is not None:
+                        decoding_computer.biasing_multi_model.remove_model(request.multi_model_id)
+                        request.multi_model_id = None
             all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size))
+        timer.stop(device=map_location)

     # convert text
     for i, hyp in enumerate(all_hyps):
@@ -394,11 +456,26 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
         )
         all_hyps[i] = hyp

+    if cfg.sort_by_duration:
+        # restore order for all_hyps and records (all_hyps are consistent with records)
+        order_restored = sorted(
+            zip(records, all_hyps), key=lambda records_hyps: filepath2order[records_hyps[0]["audio_filepath"]]
+        )
+        records, all_hyps = map(list, zip(*order_restored))
+
     output_filename, pred_text_attr_name = write_transcription(
         all_hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, timestamps=cfg.timestamps
     )
     logging.info(f"Finished writing predictions to {output_filename}!")

+    if cfg.calculate_rtfx:
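+        # RTFx = total audio duration / wall-clock decoding time (higher is faster);
+        # e.g. 7200 s of audio decoded in 72 s gives RTFx = 100. Only the dataloader
+        # loop is timed, so model loading and manifest I/O are excluded.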
+        durations = [
+            record["duration"] if "duration" in record else librosa.get_duration(path=record["audio_filepath"])
+            for record in records
+        ]
+        rtfx = sum(durations) / timer.total_sec()
+        logging.info(f"RTFx: {rtfx:.2f}")
+
     if cfg.calculate_wer:
         output_manifest_w_wer, total_res, _ = cal_write_wer(
             pred_manifest=output_filename,