
Commit 8e9a986

Merge branch 'main' into vbataev/multi_biasing_models
# Conflicts:
#	nemo/collections/asr/inference/streaming/framing/request_options.py
2 parents fad3cad + af95c29 commit 8e9a986


9 files changed, +258 -45 lines changed


nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py

Lines changed: 5 additions & 0 deletions
@@ -83,6 +83,7 @@ def execute_step(
         keep_all_outputs: bool,
         drop_left_context: int | None = None,
         valid_out_len: int | None = None,
+        prompt_vectors: Tensor | None = None,
     ) -> tuple[list[Hypothesis], CacheAwareContext]:
         """
         Executes a single streaming step.
@@ -95,6 +96,7 @@ def execute_step(
             keep_all_outputs: (bool) whether to keep all outputs or not.
             drop_left_context: (int | None) number of left context frames to drop.
             valid_out_len: (int | None) number of valid output frames.
+            prompt_vectors: (Tensor | None) Optional prompt vectors of shape [B, num_prompts].
         Returns:
             (tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context.
         """
@@ -144,6 +146,7 @@ def stream_step(
         keep_all_outputs: bool = False,
         drop_left_context: int | None = None,
         valid_out_len: int | None = None,
+        prompt_vectors: Tensor | None = None,
     ) -> tuple[list[Hypothesis], CacheAwareContext]:
         """
         Executes a single streaming step.
@@ -156,6 +159,7 @@ def stream_step(
             keep_all_outputs: (bool) whether to keep all outputs or not.
             drop_left_context: (int | None) number of left context frames to drop.
             valid_out_len: (int | None) number of valid output frames.
+            prompt_vectors: (Tensor | None) Optional prompt vectors of shape [B, num_prompts].
         Returns:
             (tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context.
         """
@@ -185,6 +189,7 @@ def stream_step(
             keep_all_outputs,
             drop_left_context,
             valid_out_len,
+            prompt_vectors,
         )

         return best_hyp, new_context
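For reference, the prompt_vectors tensor threaded through stream_step and execute_step is one one-hot row per stream, matching the [B, num_prompts] shape documented above. A minimal, runnable construction (batch size, prompt count, and dtype are illustrative, not read from any model):

    import torch

    # One one-hot prompt row per stream, as expected by stream_step.
    batch_size, num_prompts = 3, 16           # illustrative sizes
    prompt_indices = torch.tensor([0, 4, 4])  # per-stream prompt ids
    prompt_vectors = torch.nn.functional.one_hot(
        prompt_indices, num_classes=num_prompts
    ).to(torch.float32)                       # shape: [3, 16]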

nemo/collections/asr/inference/model_wrappers/rnnt_inference_wrapper.py

Lines changed: 36 additions & 4 deletions
@@ -71,12 +71,16 @@ def get_subsampling_factor(self) -> int:
         """
         return self.asr_model.encoder.subsampling_factor

-    def encode(self, processed_signal: Tensor, processed_signal_length: Tensor) -> tuple[Tensor, Tensor]:
+    def encode(
+        self, processed_signal: Tensor, processed_signal_length: Tensor, prompt_vectors: Tensor | None = None
+    ) -> tuple[Tensor, Tensor]:
         """
         Get encoder output from the model. It is used for streaming inference.
         Args:
             processed_signal: (Tensor) processed signal. Shape is torch.Size([B, C, T]).
             processed_signal_length: (Tensor) processed signal length. Shape is torch.Size([B]).
+            prompt_vectors: (Tensor | None) Optional prompt vectors for multilingual models.
+                Shape can be torch.Size([B, num_prompts]) or torch.Size([B, T_enc, num_prompts]) if already expanded.
         Returns:
             (tuple[Tensor, Tensor]) encoder output and encoder output length of shape torch.Size([B, T, D]), torch.Size([B]).
         """
@@ -92,9 +96,15 @@ def encode(self, processed_signal: Tensor, processed_signal_length: Tensor) -> t
             torch.no_grad(),
         ):

-            forward_outs = self.asr_model(
-                processed_signal=processed_signal.to(self.cast_dtype), processed_signal_length=processed_signal_length
-            )
+            # Prepare model arguments
+            model_args = {
+                'processed_signal': processed_signal.to(self.cast_dtype),
+                'processed_signal_length': processed_signal_length,
+            }
+            if prompt_vectors is not None:
+                model_args['prompt'] = prompt_vectors
+
+            forward_outs = self.asr_model(**model_args)

         encoded, encoded_len = forward_outs
         return encoded, encoded_len
@@ -113,3 +123,25 @@ def decode(self, encoded: Tensor, encoded_len: Tensor, partial_hypotheses: list)
             encoded.to(self.cast_dtype), encoded_len, return_hypotheses=True, partial_hypotheses=partial_hypotheses
         )
         return best_hyp
+
+    def encode_with_prompts(
+        self, processed_signal: Tensor, processed_signal_length: Tensor, prompt_vectors: Tensor
+    ) -> tuple[Tensor, Tensor]:
+        """
+        Convenience wrapper for prompt-enabled encoding.
+        Expands prompt vectors across the time dimension before calling encode.
+        Args:
+            processed_signal: (Tensor) processed signal. Shape is torch.Size([B, C, T]).
+            processed_signal_length: (Tensor) processed signal length. Shape is torch.Size([B]).
+            prompt_vectors: (Tensor) prompt vectors. Shape is torch.Size([B, num_prompts]).
+        Returns:
+            (tuple[Tensor, Tensor]) encoder output and encoder output length.
+        """
+        encoder_time_steps = processed_signal.shape[2] // self.get_subsampling_factor()
+        # Expand prompts: [B, num_prompts] -> [B, T_enc, num_prompts]
+        prompt_vectors = prompt_vectors.unsqueeze(1).expand(-1, encoder_time_steps, -1)
+        return self.encode(
+            processed_signal=processed_signal,
+            processed_signal_length=processed_signal_length,
+            prompt_vectors=prompt_vectors,
+        )
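The expansion performed by encode_with_prompts can be checked in isolation; the snippet below reproduces the unsqueeze/expand step with made-up sizes (a subsampling factor of 8 is an assumption for illustration, not read from any checkpoint):

    import torch

    B, num_prompts, subsampling_factor, T = 2, 8, 8, 1600  # illustrative sizes
    prompt_vectors = torch.nn.functional.one_hot(
        torch.tensor([0, 3]), num_classes=num_prompts
    ).float()                                  # [B, num_prompts]
    encoder_time_steps = T // subsampling_factor
    # [B, num_prompts] -> [B, T_enc, num_prompts]; expand is a view, no data copy
    expanded = prompt_vectors.unsqueeze(1).expand(-1, encoder_time_steps, -1)
    print(expanded.shape)                      # torch.Size([2, 200, 8])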

nemo/collections/asr/inference/pipelines/base_pipeline.py

Lines changed: 86 additions & 0 deletions
@@ -21,7 +21,9 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Iterable

+import torch
 from omegaconf import DictConfig
+from torch import Tensor

 from nemo.collections.asr.inference.model_wrappers.asr_inference_wrapper import ASRInferenceWrapper
 from nemo.collections.asr.inference.pipelines.pipeline_interface import PipelineInterface
@@ -481,6 +483,90 @@ def init_context_manager(self) -> None:
             cache_aware_model=self.asr_model, num_slots=self.num_slots, use_cache=self.use_cache
         )

+    def init_prompt_support(self) -> None:
+        """Initialize prompt support for multilingual models."""
+        self.prompt_enabled = hasattr(self.asr_model.asr_model, 'concat') and self.asr_model.asr_model.concat
+
+        if self.prompt_enabled:
+            self._prompt_config = self._load_prompt_config()
+
+    def _load_prompt_config(self) -> dict:
+        """
+        Load prompt configuration from model.
+        Returns:
+            (dict) Prompt configuration containing num_prompts, prompt_dict, and compute_dtype.
+        """
+        cfg = self.asr_model.asr_model.cfg
+        if cfg and hasattr(cfg, 'model_defaults'):
+            model_defaults = cfg.model_defaults
+            num_prompts = model_defaults.get('num_prompts', None)
+            prompt_dict = model_defaults.get('prompt_dictionary', None)
+
+            # Validate and convert types once
+            num_prompts_int = int(num_prompts) if num_prompts is not None else 0
+
+            is_dict_like = isinstance(prompt_dict, dict) or (
+                hasattr(prompt_dict, 'get') and hasattr(prompt_dict, '__contains__')
+            )
+
+            if num_prompts_int > 0 and is_dict_like:
+                return {
+                    'num_prompts': num_prompts_int,
+                    'prompt_dict': prompt_dict,
+                    'compute_dtype': getattr(self.asr_model.asr_model, 'dtype', torch.float32),
+                }
+
+        return {}
+
+    def _resolve_prompt_index(self, language_code: str) -> int:
+        """
+        Resolve language_code to a strict prompt index; raise if invalid.
+        Args:
+            language_code: (str) Language code to resolve (e.g., "en-US", "es-ES").
+        Returns:
+            (int) Prompt index corresponding to the language code.
+        Raises:
+            RuntimeError: If prompt configuration is missing.
+            ValueError: If language_code is not found in prompt dictionary.
+        """
+        if not hasattr(self, '_prompt_config') or not self._prompt_config:
+            raise RuntimeError("Prompt configuration is missing for a prompt-enabled model.")
+        prompt_dict = self._prompt_config['prompt_dict']
+        lang_index = prompt_dict.get(language_code, None)
+        if lang_index is None:
+            raise ValueError(
+                f"Language code '{language_code}' not found in prompt dictionary. "
+                f"Available languages: {list(prompt_dict.keys())}"
+            )
+        return lang_index
+
+    def _create_one_hot_prompts(self, indices: Tensor) -> Tensor:
+        """
+        Create one-hot prompt vectors from indices.
+        Args:
+            indices: (Tensor) Prompt indices of shape [B].
+        Returns:
+            (Tensor) One-hot prompt vectors of shape [B, num_prompts].
+        """
+        num_prompts = self._prompt_config['num_prompts']
+        return torch.nn.functional.one_hot(indices, num_classes=num_prompts).to(self._prompt_config['compute_dtype'])
+
+    def _build_prompt_vectors(self, states: list) -> Tensor:
+        """
+        Build prompt vectors for a batch of states using one-hot encoding.
+        Args:
+            states: (list) List of streaming states.
+        Returns:
+            (Tensor) Prompt vectors of shape [B, num_prompts].
+        Raises:
+            ValueError: If any prompt index is out of range.
+        """
+        indices = torch.tensor([getattr(s, 'prompt_idx', 0) for s in states], device=self.device, dtype=torch.long)
+        num_prompts = self._prompt_config['num_prompts']
+        if torch.any((indices < 0) | (indices >= num_prompts)):
+            raise ValueError("Found out-of-range prompt index in batch.")
+        return self._create_one_hot_prompts(indices)
+
     def run(
         self,
         audio_filepaths: list[str],
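Taken together, _resolve_prompt_index, _create_one_hot_prompts, and _build_prompt_vectors implement a resolve, range-check, one-hot path. A self-contained sketch of that path, using an illustrative prompt dictionary (real models ship their own prompt_dictionary in model_defaults):

    import torch

    prompt_dict = {"en-US": 0, "es-ES": 1, "de-DE": 2}  # illustrative mapping
    num_prompts = 3

    def resolve(code: str) -> int:
        # Mirrors _resolve_prompt_index: strict lookup, raise on unknown codes.
        idx = prompt_dict.get(code)
        if idx is None:
            raise ValueError(f"Language code '{code}' not found. Available: {list(prompt_dict)}")
        return idx

    indices = torch.tensor([resolve("es-ES"), resolve("en-US")], dtype=torch.long)
    # Mirrors the guard in _build_prompt_vectors.
    if torch.any((indices < 0) | (indices >= num_prompts)):
        raise ValueError("Found out-of-range prompt index in batch.")
    one_hot = torch.nn.functional.one_hot(indices, num_classes=num_prompts).float()
    print(one_hot)  # tensor([[0., 1., 0.], [1., 0., 0.]])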

nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py

Lines changed: 59 additions & 9 deletions
@@ -67,6 +67,7 @@ def __init__(
         """

         self.copy_asr_model_attributes(asr_model)
+        self.init_prompt_support()
         self.init_parameters(cfg)
         self.init_bufferer_for_buffered_streaming()
         self.conf_func, self.confidence_aggregator = get_confidence_utils(cfg.confidence)
@@ -196,9 +197,24 @@ def init_zero_enc(self) -> Tensor:
             buffer_lens=torch.tensor([zero_buffer.shape[1]], device=self.device),
             expected_feature_buffer_len=self.expected_feature_buffer_len,
         )
-        zero_encoded, _ = self.asr_model.encode(
-            processed_signal=zero_features, processed_signal_length=zero_features_len
-        )
+
+        if self.prompt_enabled:
+            # Use "en-US" as the default prompt for zero encoding
+            # This region is sliced out before decoding, so language choice doesn't matter
+            default_prompt_idx = self._resolve_prompt_index("en-US")
+            prompt_indices = torch.tensor([default_prompt_idx], device=self.device, dtype=torch.long)
+            prompt_vector = self._create_one_hot_prompts(prompt_indices)  # [1, num_prompts]
+
+            zero_encoded, _ = self.asr_model.encode_with_prompts(
+                processed_signal=zero_features,
+                processed_signal_length=zero_features_len,
+                prompt_vectors=prompt_vector,
+            )
+        else:
+            zero_encoded, _ = self.asr_model.encode(
+                processed_signal=zero_features, processed_signal_length=zero_features_len
+            )
+
         return zero_encoded[0]

     def create_state(self, options: ASRRequestOptions) -> RNNTStreamingState:
@@ -219,8 +235,18 @@ def create_state(self, options: ASRRequestOptions) -> RNNTStreamingState:
             default_target_language=self.nmt_model.target_language if self.nmt_enabled else None,
             default_stop_history_eou=self.stop_history_eou_in_milliseconds,
             default_asr_output_granularity=self.asr_output_granularity,
+            default_language_code="en-US" if self.prompt_enabled else None,
         )
         state.set_options(new_options)
+
+        # Create per-stream prompt index for prompt-enabled models
+        if self.prompt_enabled:
+            lang_code = getattr(new_options, "language_code", None)
+            if not isinstance(lang_code, str) or len(lang_code) == 0:
+                raise ValueError("Prompt-enabled model requires a valid language_code in request options.")
+            prompt_idx = self._resolve_prompt_index(lang_code)
+            state.set_prompt_index(prompt_idx)
+
         return state

     def get_sep(self) -> str:
@@ -304,9 +330,21 @@ def encode_raw_signals(
             expected_feature_buffer_len=self.expected_feature_buffer_len,
         )

-        encoded, encoded_len = self.asr_model.encode(
-            processed_signal=feature_buffers, processed_signal_length=feature_buffer_lens
-        )
+        # Build prompt vectors if prompts are enabled
+        if self.prompt_enabled:
+            requests_states = [self.get_state(f.stream_id) for f in frames]
+            prompt_vectors = self._build_prompt_vectors(requests_states)
+
+            # Use encode_with_prompts which handles dimension expansion
+            encoded, encoded_len = self.asr_model.encode_with_prompts(
+                processed_signal=feature_buffers,
+                processed_signal_length=feature_buffer_lens,
+                prompt_vectors=prompt_vectors,
+            )
+        else:
+            encoded, encoded_len = self.asr_model.encode(
+                processed_signal=feature_buffers, processed_signal_length=feature_buffer_lens
+            )
         encoded = encoded.clone()
         encoded_len = encoded_len.clone()

@@ -340,9 +378,21 @@ def encode_processed_signals(
         processed_signals = normalize_features(processed_signals, processed_signal_lengths)
         processed_signal_lengths = processed_signal_lengths.clamp(max=processed_signals.shape[2])

-        encoded, encoded_len = self.asr_model.encode(
-            processed_signal=processed_signals, processed_signal_length=processed_signal_lengths
-        )
+        # Build prompt vectors if prompts are enabled
+        if self.prompt_enabled:
+            requests_states = [self.get_state(f.stream_id) for f in fbuffers]
+            prompt_vectors = self._build_prompt_vectors(requests_states)
+
+            # Use encode_with_prompts which handles dimension expansion
+            encoded, encoded_len = self.asr_model.encode_with_prompts(
+                processed_signal=processed_signals,
+                processed_signal_length=processed_signal_lengths,
+                prompt_vectors=prompt_vectors,
+            )
+        else:
+            encoded, encoded_len = self.asr_model.encode(
+                processed_signal=processed_signals, processed_signal_length=processed_signal_lengths
+            )
        encoded = encoded.clone()
         encoded_len = encoded_len.clone()

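End to end, a caller of the buffered pipeline selects the language per request. A hypothetical usage sketch; only language_code is shown, and the pipeline construction and other option fields are omitted:

    # Hypothetical usage; assumes a constructed, prompt-enabled pipeline.
    options = ASRRequestOptions(language_code="es-ES")
    state = pipeline.create_state(options)   # resolves "es-ES" to a prompt index
    # Later, encode_raw_signals()/encode_processed_signals() rebuild one-hot
    # prompt vectors from the per-stream indices on every call and route them
    # through encode_with_prompts(); requests without an explicit language
    # fall back to the "en-US" default injected above.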

nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py

Lines changed: 16 additions & 0 deletions
@@ -64,6 +64,7 @@ def __init__(
             nmt_model: (LLMTranslator | None) LLM based translation model.
         """
         self.copy_asr_model_attributes(asr_model)
+        self.init_prompt_support()
         self.init_parameters(cfg)
         self.init_context_manager()
         self.init_bufferer_for_cache_aware_streaming()
@@ -187,6 +188,7 @@ def create_state(self, options: ASRRequestOptions) -> CacheAwareRNNTStreamingSta
             default_target_language=self.nmt_model.target_language if self.nmt_enabled else None,
             default_stop_history_eou=self.stop_history_eou_in_milliseconds,
             default_asr_output_granularity=self.asr_output_granularity,
+            default_language_code="en-US" if self.prompt_enabled else None,
         )

         eou_label_buffer_size = 0
@@ -198,6 +200,15 @@ def create_state(self, options: ASRRequestOptions) -> CacheAwareRNNTStreamingSta
         state.setup_label_buffer(eou_label_buffer_size, self.blank_id)
         state.set_previous_hypothesis(None)
         state.set_options(new_options)
+
+        # Create per-stream prompt index for prompt-enabled models
+        if self.prompt_enabled:
+            lang_code = getattr(new_options, "language_code", None)
+            if not isinstance(lang_code, str) or len(lang_code) == 0:
+                raise ValueError("Prompt-enabled model requires a valid language_code in request options.")
+            prompt_idx = self._resolve_prompt_index(lang_code)
+            state.set_prompt_index(prompt_idx)
+
         return state

     def get_sep(self) -> str:
@@ -291,6 +302,10 @@ def cache_aware_transcribe_step(
         previous_hypotheses = [state.get_previous_hypothesis() for state in states]
         context, mapping = self.context_manager.get_context(stream_ids)

+        prompt_vectors = None
+        if self.prompt_enabled:
+            prompt_vectors = self._build_prompt_vectors(states)
+
         drop_extra_pre_encoded = 0 if not self.use_cache else self.asr_model.drop_extra_pre_encoded
         best_hyp, new_context = self.asr_model.stream_step(
             processed_signal=feature_buffers,
@@ -301,6 +316,7 @@ def cache_aware_transcribe_step(
             keep_all_outputs=keep_all_outputs,
             drop_left_context=self.drop_left_context,
             valid_out_len=self.valid_out_len,
+            prompt_vectors=prompt_vectors,
        )

         # update the cache and reset the cache slots for the streams that has ended
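Because prompt vectors are rebuilt from per-stream state on every cache-aware step, a single batch can mix languages; each stream contributes its own one-hot row. A runnable illustration (the language set and stream count are made up):

    import torch

    prompt_dict = {"en-US": 0, "es-ES": 1}      # illustrative
    stream_langs = ["en-US", "es-ES", "en-US"]  # one entry per active stream
    indices = torch.tensor([prompt_dict[c] for c in stream_langs])
    prompt_vectors = torch.nn.functional.one_hot(indices, num_classes=2).float()
    print(prompt_vectors)  # rows differ per stream: [[1,0],[0,1],[1,0]]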

nemo/collections/asr/inference/streaming/framing/request_options.py

Lines changed: 5 additions & 0 deletions
@@ -31,6 +31,7 @@ class ASRRequestOptions:
     enable_pnc: bool = None
     stop_history_eou: int = None
     asr_output_granularity: ASROutputGranularity | str = None
+    language_code: str | None = None
     enable_nmt: bool = None
     source_language: str = None
     target_language: str = None
@@ -82,6 +83,7 @@ def augment_with_defaults(
         default_target_language: str,
         default_stop_history_eou: int,
         default_asr_output_granularity: ASROutputGranularity | str,
+        default_language_code: str | None = None,
         biasing_cfg: BiasingRequestItemConfig | None = None,
     ) -> "ASRRequestOptions":
         """
@@ -94,6 +96,7 @@ def augment_with_defaults(
             default_target_language (str): Default target language.
             default_stop_history_eou (int): Default stop history EOU.
             default_asr_output_granularity (ASROutputGranularity | str): Default output granularity.
+            default_language_code (str | None): Default language code for prompt-enabled models.
             biasing_cfg: Default biasing config or None
         Returns:
             ASRRequestOptions: Augmented options.
@@ -113,6 +116,7 @@ def augment_with_defaults(

         stop_history_eou = self._with_default(self.stop_history_eou, default_stop_history_eou)
         granularity = self._with_default(self.asr_output_granularity, default_asr_output_granularity)
+        language_code = self._with_default(self.language_code, default_language_code)

         return ASRRequestOptions(
             enable_itn=enable_itn,
@@ -122,6 +126,7 @@ def augment_with_defaults(
             target_language=target_language,
             stop_history_eou=stop_history_eou,
             asr_output_granularity=granularity,
+            language_code=language_code,
             biasing_cfg=self.biasing_cfg or biasing_cfg,
         )

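The defaulting added here follows the existing _with_default pattern. A stand-in sketch of the assumed semantics (keep the per-request value when set, otherwise fall back to the pipeline default):

    # Stand-in for ASRRequestOptions._with_default; the None-check semantics
    # are an assumption based on how the other option fields are defaulted.
    def with_default(value, default):
        return value if value is not None else default

    print(with_default(None, "en-US"))     # request left language_code unset -> "en-US"
    print(with_default("es-ES", "en-US"))  # explicit request value wins -> "es-ES"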
