Skip to content

Commit cc0c3a1

Browse files
committed
Clean up
Signed-off-by: Vladimir Bataev <[email protected]>
1 parent a4e1582 commit cc0c3a1

File tree

1 file changed

+72
-57
lines changed

1 file changed

+72
-57
lines changed

nemo/collections/asr/parts/context_biasing/biasing_multi_model.py

Lines changed: 72 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@
3939
@dataclass
4040
class BiasingRequestItemConfig:
4141
boosting_model_cfg: BoostingTreeModelConfig = field(default_factory=BoostingTreeModelConfig)
42-
boosting_model_alpha: float = 1.0
42+
boosting_model_alpha: float = 1.0 # boosting weight
4343
cache_key: str | None = None # cache key for memory cache; NB: cache key should be unique for (tokenizer, phrases)
4444
multi_model_id: int | None = None # compiled model id
45-
auto_manage_multi_model: bool = True
45+
auto_manage_multi_model: bool = True # if model should be added to the decoder and removed automatically
4646

4747
def is_empty(self) -> bool:
4848
"""Return True if biasing request (or model) is empty"""
@@ -91,16 +91,16 @@ class GPUBiasingMultiModelBase(abc.ABC, nn.Module):
9191

9292
@abstractmethod
9393
def add_model(self, model: NGramGPULanguageModel, alpha: float = 1.0) -> int:
94-
raise NotImplementedError
94+
pass
9595

9696
@abstractmethod
9797
def remove_model(self, model_id: int):
98-
raise NotImplementedError
98+
pass
9999

100100
@abstractmethod
101101
def has_models(self) -> bool:
102102
"""Return True if the multi-model has at least one model"""
103-
raise NotImplementedError
103+
pass
104104

105105
def compatible_with_cuda_graphs(self) -> bool:
106106
"""True if model can be compiled as a part of CUDA graph, False otherwise"""
@@ -140,7 +140,13 @@ def get_init_states(self, batch_size: int, bos=True) -> torch.Tensor:
140140
class GPUBiasingMultiModelReference(GPUBiasingMultiModelBase):
141141
"""Reference implementation (incompatible with CUDA graphs)"""
142142

143-
def __init__(self, vocab_size: int):
143+
def __init__(self, vocab_size: int, *args, **kwargs):
144+
"""
145+
146+
Args:
147+
vocab_size: vocabulary size of the model
148+
*args, **kwargs: added for easiness of switching between this model and efficient implementation
149+
"""
144150
super().__init__()
145151
self.models = nn.ModuleList([])
146152
self.buffer_for_device_handling = nn.Buffer(torch.zeros([1], dtype=torch.long))
@@ -251,9 +257,9 @@ def __init__(
251257
"""
252258
253259
Args:
254-
vocab_size:
255-
reallocation_callback_fn:
256-
use_triton:
260+
vocab_size: vocabulary size of the model
261+
reallocation_callback_fn: function to call when reallocation occurred (needed for decoders with CUDA graphs)
262+
use_triton: allow using Triton, `None` means "auto" (used if available)
257263
"""
258264
super().__init__()
259265
self.vocab_size: int = vocab_size
@@ -319,28 +325,34 @@ def _check_model_compatibility(self, model: NGramGPULanguageModel):
319325
if not model._final_resolved:
320326
model._resolve_final()
321327

328+
@staticmethod
329+
def _extend_buffer_or_param(buffer_or_param: nn.Buffer | nn.Parameter, add_len: int):
330+
"""Extend buffer or parameter"""
331+
buffer_or_param.data = torch.cat(
332+
(
333+
buffer_or_param.data,
334+
torch.zeros(
335+
[add_len] + list(buffer_or_param.shape)[1:],
336+
dtype=buffer_or_param.dtype,
337+
device=buffer_or_param.device,
338+
),
339+
)
340+
)
341+
322342
def _maybe_extend_arcs_and_states(self, add_num_states: int, add_num_arcs_extended: int) -> bool:
323343
"""Extend memory allocated for arcs and states, return True if any tensor is reallocated"""
324344
reallocated = False
325345

326-
def _extend_buffer_or_param(buffer: nn.Buffer | nn.Parameter, add_len: int):
327-
buffer.data = torch.cat(
328-
(
329-
buffer.data,
330-
torch.zeros([add_len] + list(buffer.shape)[1:], dtype=buffer.dtype, device=buffer.device),
331-
)
332-
)
333-
334346
if self.num_arcs_extended_total + add_num_arcs_extended > self.num_arcs_extended_reserved:
335347
# min allocation: 2x
336348
add_num_arcs = max(
337349
self.num_arcs_extended_reserved,
338350
self.num_arcs_extended_total + add_num_arcs_extended - self.num_arcs_extended_reserved,
339351
)
340-
_extend_buffer_or_param(self.all_arcs_weights, add_len=add_num_arcs)
341-
_extend_buffer_or_param(self.all_from_states, add_len=add_num_arcs)
342-
_extend_buffer_or_param(self.all_to_states, add_len=add_num_arcs)
343-
_extend_buffer_or_param(self.all_ilabels, add_len=add_num_arcs)
352+
self._extend_buffer_or_param(self.all_arcs_weights, add_len=add_num_arcs)
353+
self._extend_buffer_or_param(self.all_from_states, add_len=add_num_arcs)
354+
self._extend_buffer_or_param(self.all_to_states, add_len=add_num_arcs)
355+
self._extend_buffer_or_param(self.all_ilabels, add_len=add_num_arcs)
344356
self.num_arcs_extended_reserved += add_num_arcs
345357
reallocated = True
346358

@@ -349,31 +361,32 @@ def _extend_buffer_or_param(buffer: nn.Buffer | nn.Parameter, add_len: int):
349361
add_num_states = max(
350362
self.num_states_reserved, self.num_states_total + add_num_states - self.num_states_reserved
351363
)
352-
_extend_buffer_or_param(self.all_start_end_arcs, add_len=add_num_states)
353-
_extend_buffer_or_param(self.all_state_order, add_len=add_num_states)
354-
_extend_buffer_or_param(self.all_backoff_to_states, add_len=add_num_states)
355-
_extend_buffer_or_param(self.all_backoff_weights, add_len=add_num_states)
356-
_extend_buffer_or_param(self.all_final_weights, add_len=add_num_states)
364+
self._extend_buffer_or_param(self.all_start_end_arcs, add_len=add_num_states)
365+
self._extend_buffer_or_param(self.all_state_order, add_len=add_num_states)
366+
self._extend_buffer_or_param(self.all_backoff_to_states, add_len=add_num_states)
367+
self._extend_buffer_or_param(self.all_backoff_weights, add_len=add_num_states)
368+
self._extend_buffer_or_param(self.all_final_weights, add_len=add_num_states)
357369
self.num_states_reserved += add_num_states
358370
reallocated = True
359371

360372
return reallocated
361373

374+
@staticmethod
375+
def _extend_buffer_2x(buffer: nn.Buffer):
376+
buffer.data = torch.cat((buffer.data, torch.zeros_like(buffer.data)), dim=-1)
377+
362378
def _extend_num_models(self):
363379
"""Extend memory allocated for models with properties"""
364380
assert self.num_models_reserved > 0
365381
self.num_models_reserved *= 2
366382

367-
def _extend_buffer_2x(buffer: nn.Buffer):
368-
buffer.data = torch.cat((buffer.data, torch.zeros_like(buffer.data)), dim=-1)
369-
370-
_extend_buffer_2x(self.model2alpha)
371-
_extend_buffer_2x(self.model2active)
372-
_extend_buffer_2x(self.model2num_states)
373-
_extend_buffer_2x(self.model2num_arcs)
374-
_extend_buffer_2x(self.model2num_arcs_extended)
375-
_extend_buffer_2x(self.model2states_offset)
376-
_extend_buffer_2x(self.model2arcs_offset)
383+
self._extend_buffer_2x(self.model2alpha)
384+
self._extend_buffer_2x(self.model2active)
385+
self._extend_buffer_2x(self.model2num_states)
386+
self._extend_buffer_2x(self.model2num_arcs)
387+
self._extend_buffer_2x(self.model2num_arcs_extended)
388+
self._extend_buffer_2x(self.model2states_offset)
389+
self._extend_buffer_2x(self.model2arcs_offset)
377390

378391
@torch.no_grad()
379392
def add_model(self, model: GPUBoostingTreeModel, alpha: float = 1.0) -> int:
@@ -457,6 +470,16 @@ def add_model(self, model: GPUBoostingTreeModel, alpha: float = 1.0) -> int:
457470
reallocation_callback_fn()
458471
return model_id
459472

473+
@staticmethod
474+
def _clear_buffer_or_param_range(
475+
buffer_or_param: nn.Buffer | nn.Parameter, start: int, end: int, buffer_len: int | None = None
476+
):
477+
if buffer_len is None:
478+
buffer_len = buffer_or_param.shape[0]
479+
remove_len = end - start
480+
buffer_or_param[start : buffer_len - remove_len].copy_(buffer_or_param[end:buffer_len].clone())
481+
buffer_or_param[buffer_len - remove_len : buffer_len].fill_(0)
482+
460483
@torch.no_grad()
461484
def remove_model(self, model_id: int):
462485
"""
@@ -486,27 +509,18 @@ def remove_model(self, model_id: int):
486509

487510
assert num_arcs > 0 and num_states > 0, "Unexpected zero-size model"
488511

489-
def _clear_buffer_or_param_range(
490-
buffer: nn.Buffer | nn.Parameter, start: int, end: int, buffer_size: int | None = None
491-
):
492-
if buffer_size is None:
493-
buffer_size = buffer.shape[0]
494-
remove_len = end - start
495-
buffer[start : buffer_size - remove_len].copy_(buffer[end:buffer_size].clone())
496-
buffer[buffer_size - remove_len : buffer_size].fill_(0)
497-
498-
# clean up arcs-related data
499-
_clear_buffer_or_param_range(self.all_arcs_weights, start_arc, end_arc, self.num_arcs_extended_total)
500-
_clear_buffer_or_param_range(self.all_from_states, start_arc, end_arc, self.num_arcs_extended_total)
501-
_clear_buffer_or_param_range(self.all_to_states, start_arc, end_arc, self.num_arcs_extended_total)
502-
_clear_buffer_or_param_range(self.all_ilabels, start_arc, end_arc, self.num_arcs_extended_total)
503-
504-
# clean up states-related data
505-
_clear_buffer_or_param_range(self.all_start_end_arcs, start_state, end_state, self.num_states_total)
506-
_clear_buffer_or_param_range(self.all_state_order, start_state, end_state, self.num_states_total)
507-
_clear_buffer_or_param_range(self.all_backoff_to_states, start_state, end_state, self.num_states_total)
508-
_clear_buffer_or_param_range(self.all_backoff_weights, start_state, end_state, self.num_states_total)
509-
_clear_buffer_or_param_range(self.all_final_weights, start_state, end_state, self.num_states_total)
512+
# clean up arcs-related data: cut [start_arc, end_arc) from the buffer (shifting right part to the left)
513+
self._clear_buffer_or_param_range(self.all_arcs_weights, start_arc, end_arc, self.num_arcs_extended_total)
514+
self._clear_buffer_or_param_range(self.all_from_states, start_arc, end_arc, self.num_arcs_extended_total)
515+
self._clear_buffer_or_param_range(self.all_to_states, start_arc, end_arc, self.num_arcs_extended_total)
516+
self._clear_buffer_or_param_range(self.all_ilabels, start_arc, end_arc, self.num_arcs_extended_total)
517+
518+
# clean up states-related data: cut [start_state, end_state) from the buffer (shifting right part to the left)
519+
self._clear_buffer_or_param_range(self.all_start_end_arcs, start_state, end_state, self.num_states_total)
520+
self._clear_buffer_or_param_range(self.all_state_order, start_state, end_state, self.num_states_total)
521+
self._clear_buffer_or_param_range(self.all_backoff_to_states, start_state, end_state, self.num_states_total)
522+
self._clear_buffer_or_param_range(self.all_backoff_weights, start_state, end_state, self.num_states_total)
523+
self._clear_buffer_or_param_range(self.all_final_weights, start_state, end_state, self.num_states_total)
510524

511525
# set num states/arcs to zero
512526
self.num_states_total -= num_states
@@ -518,6 +532,7 @@ def _clear_buffer_or_param_range(
518532
# shift model offsets
519533
self.model2states_offset[model_id] = 0
520534
self.model2arcs_offset[model_id] = 0
535+
# shift states and arcs offsets
521536
torch.where(
522537
self.model2states_offset < start_state,
523538
self.model2states_offset,

0 commit comments

Comments
 (0)