 @dataclass
 class BiasingRequestItemConfig:
     boosting_model_cfg: BoostingTreeModelConfig = field(default_factory=BoostingTreeModelConfig)
-    boosting_model_alpha: float = 1.0
+    boosting_model_alpha: float = 1.0  # boosting weight
     cache_key: str | None = None  # cache key for memory cache; NB: cache key should be unique for (tokenizer, phrases)
     multi_model_id: int | None = None  # compiled model id
-    auto_manage_multi_model: bool = True
+    auto_manage_multi_model: bool = True  # whether the model should be added to and removed from the decoder automatically
 
     def is_empty(self) -> bool:
         """Return True if the biasing request (or model) is empty"""
@@ -91,16 +91,16 @@ class GPUBiasingMultiModelBase(abc.ABC, nn.Module):
 
     @abstractmethod
     def add_model(self, model: NGramGPULanguageModel, alpha: float = 1.0) -> int:
-        raise NotImplementedError
+        pass
 
     @abstractmethod
     def remove_model(self, model_id: int):
-        raise NotImplementedError
+        pass
 
     @abstractmethod
     def has_models(self) -> bool:
         """Return True if the multi-model has at least one model"""
-        raise NotImplementedError
+        pass
 
     def compatible_with_cuda_graphs(self) -> bool:
         """True if the model can be compiled as part of a CUDA graph, False otherwise"""
@@ -140,7 +140,13 @@ def get_init_states(self, batch_size: int, bos=True) -> torch.Tensor:
 class GPUBiasingMultiModelReference(GPUBiasingMultiModelBase):
     """Reference implementation (incompatible with CUDA graphs)"""
 
-    def __init__(self, vocab_size: int):
+    def __init__(self, vocab_size: int, *args, **kwargs):
+        """
+
+        Args:
+            vocab_size: vocabulary size of the model
+            *args, **kwargs: accepted for ease of switching between this model and the efficient implementation
+        """
         super().__init__()
         self.models = nn.ModuleList([])
         self.buffer_for_device_handling = nn.Buffer(torch.zeros([1], dtype=torch.long))
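The extra *args/**kwargs swallow constructor arguments that only the efficient implementation understands, so a call site can switch implementations without changes (sketch; the efficient class is the one defined later in this file):

    make_multi_model = GPUBiasingMultiModelReference  # or the efficient class
    multi_model = make_multi_model(vocab_size=1024, reallocation_callback_fn=None, use_triton=None)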
@@ -251,9 +257,9 @@ def __init__(
         """
 
         Args:
-            vocab_size:
-            reallocation_callback_fn:
-            use_triton:
+            vocab_size: vocabulary size of the model
+            reallocation_callback_fn: function to call when a reallocation occurs (needed for decoders with CUDA graphs)
+            use_triton: whether to allow Triton; `None` means "auto" (use Triton if it is available)
         """
         super().__init__()
         self.vocab_size: int = vocab_size
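A sketch of the callback contract (the decoder method name here is hypothetical): after a buffer reallocation, CUDA graphs captured against the old storage are stale and must be re-captured, so a decoder might pass:

    def on_reallocation():
        decoder.invalidate_cuda_graphs()  # hypothetical API: force re-capture on next call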
@@ -319,28 +325,34 @@ def _check_model_compatibility(self, model: NGramGPULanguageModel):
         if not model._final_resolved:
             model._resolve_final()
 
+    @staticmethod
+    def _extend_buffer_or_param(buffer_or_param: nn.Buffer | nn.Parameter, add_len: int):
+        """Extend a buffer or parameter with `add_len` zero-initialized rows along dim 0"""
+        buffer_or_param.data = torch.cat(
+            (
+                buffer_or_param.data,
+                torch.zeros(
+                    [add_len] + list(buffer_or_param.shape)[1:],
+                    dtype=buffer_or_param.dtype,
+                    device=buffer_or_param.device,
+                ),
+            )
+        )
+
     def _maybe_extend_arcs_and_states(self, add_num_states: int, add_num_arcs_extended: int) -> bool:
         """Extend memory allocated for arcs and states; return True if any tensor was reallocated"""
         reallocated = False
 
-        def _extend_buffer_or_param(buffer: nn.Buffer | nn.Parameter, add_len: int):
-            buffer.data = torch.cat(
-                (
-                    buffer.data,
-                    torch.zeros([add_len] + list(buffer.shape)[1:], dtype=buffer.dtype, device=buffer.device),
-                )
-            )
-
         if self.num_arcs_extended_total + add_num_arcs_extended > self.num_arcs_extended_reserved:
             # min allocation: 2x (the reserve at least doubles)
             add_num_arcs = max(
                 self.num_arcs_extended_reserved,
                 self.num_arcs_extended_total + add_num_arcs_extended - self.num_arcs_extended_reserved,
             )
-            _extend_buffer_or_param(self.all_arcs_weights, add_len=add_num_arcs)
-            _extend_buffer_or_param(self.all_from_states, add_len=add_num_arcs)
-            _extend_buffer_or_param(self.all_to_states, add_len=add_num_arcs)
-            _extend_buffer_or_param(self.all_ilabels, add_len=add_num_arcs)
+            self._extend_buffer_or_param(self.all_arcs_weights, add_len=add_num_arcs)
+            self._extend_buffer_or_param(self.all_from_states, add_len=add_num_arcs)
+            self._extend_buffer_or_param(self.all_to_states, add_len=add_num_arcs)
+            self._extend_buffer_or_param(self.all_ilabels, add_len=add_num_arcs)
             self.num_arcs_extended_reserved += add_num_arcs
             reallocated = True
 
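A toy trace of the growth scheme (standalone tensors, values assumed): existing rows are kept, new rows are zeroed, and the max(...) above guarantees the reserve at least doubles, amortizing reallocations:

    import torch

    buf = torch.arange(4.0)                 # 4 reserved slots, all in use
    buf = torch.cat((buf, torch.zeros(4)))  # what _extend_buffer_or_param(buf, add_len=4) does
    print(buf)                              # tensor([0., 1., 2., 3., 0., 0., 0., 0.])

    # reserve arithmetic: reserved=100, used=90, request=30
    # add = max(100, 90 + 30 - 100) = 100 -> new reserve is 200 (at least 2x)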
@@ -349,31 +361,32 @@ def _extend_buffer_or_param(buffer: nn.Buffer | nn.Parameter, add_len: int):
             add_num_states = max(
                 self.num_states_reserved, self.num_states_total + add_num_states - self.num_states_reserved
             )
-            _extend_buffer_or_param(self.all_start_end_arcs, add_len=add_num_states)
-            _extend_buffer_or_param(self.all_state_order, add_len=add_num_states)
-            _extend_buffer_or_param(self.all_backoff_to_states, add_len=add_num_states)
-            _extend_buffer_or_param(self.all_backoff_weights, add_len=add_num_states)
-            _extend_buffer_or_param(self.all_final_weights, add_len=add_num_states)
+            self._extend_buffer_or_param(self.all_start_end_arcs, add_len=add_num_states)
+            self._extend_buffer_or_param(self.all_state_order, add_len=add_num_states)
+            self._extend_buffer_or_param(self.all_backoff_to_states, add_len=add_num_states)
+            self._extend_buffer_or_param(self.all_backoff_weights, add_len=add_num_states)
+            self._extend_buffer_or_param(self.all_final_weights, add_len=add_num_states)
             self.num_states_reserved += add_num_states
             reallocated = True
 
         return reallocated
 
+    @staticmethod
+    def _extend_buffer_2x(buffer: nn.Buffer):  # double capacity along the last dim, zero-filling the new half
+        buffer.data = torch.cat((buffer.data, torch.zeros_like(buffer.data)), dim=-1)
+
     def _extend_num_models(self):
         """Extend memory allocated for per-model properties"""
         assert self.num_models_reserved > 0
         self.num_models_reserved *= 2
 
-        def _extend_buffer_2x(buffer: nn.Buffer):
-            buffer.data = torch.cat((buffer.data, torch.zeros_like(buffer.data)), dim=-1)
-
-        _extend_buffer_2x(self.model2alpha)
-        _extend_buffer_2x(self.model2active)
-        _extend_buffer_2x(self.model2num_states)
-        _extend_buffer_2x(self.model2num_arcs)
-        _extend_buffer_2x(self.model2num_arcs_extended)
-        _extend_buffer_2x(self.model2states_offset)
-        _extend_buffer_2x(self.model2arcs_offset)
+        self._extend_buffer_2x(self.model2alpha)
+        self._extend_buffer_2x(self.model2active)
+        self._extend_buffer_2x(self.model2num_states)
+        self._extend_buffer_2x(self.model2num_arcs)
+        self._extend_buffer_2x(self.model2num_arcs_extended)
+        self._extend_buffer_2x(self.model2states_offset)
+        self._extend_buffer_2x(self.model2arcs_offset)
 
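The same doubling applied to a per-model property vector (toy values):

    import torch

    model2alpha = torch.tensor([0.5, 1.0])  # capacity 2
    model2alpha = torch.cat((model2alpha, torch.zeros_like(model2alpha)), dim=-1)
    print(model2alpha)                      # tensor([0.5000, 1.0000, 0.0000, 0.0000])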
     @torch.no_grad()
     def add_model(self, model: GPUBoostingTreeModel, alpha: float = 1.0) -> int:
@@ -457,6 +470,16 @@ def add_model(self, model: GPUBoostingTreeModel, alpha: float = 1.0) -> int:
             reallocation_callback_fn()
         return model_id
 
+    @staticmethod
+    def _clear_buffer_or_param_range(
+        buffer_or_param: nn.Buffer | nn.Parameter, start: int, end: int, buffer_len: int | None = None
+    ):  # remove [start, end) by shifting the tail left and zero-filling the freed slots
+        if buffer_len is None:
+            buffer_len = buffer_or_param.shape[0]
+        remove_len = end - start
+        buffer_or_param[start : buffer_len - remove_len].copy_(buffer_or_param[end:buffer_len].clone())
+        buffer_or_param[buffer_len - remove_len : buffer_len].fill_(0)
+
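A toy trace of the helper (values assumed): removing the slice [2, 4) from a buffer whose first 6 slots are in use shifts the tail left and zeroes the freed slots; slots beyond buffer_len are untouched:

    import torch

    buf = torch.arange(8)
    start, end, buffer_len = 2, 4, 6
    remove_len = end - start
    buf[start : buffer_len - remove_len].copy_(buf[end:buffer_len].clone())
    buf[buffer_len - remove_len : buffer_len].fill_(0)
    print(buf)  # tensor([0, 1, 4, 5, 0, 0, 6, 7])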
     @torch.no_grad()
     def remove_model(self, model_id: int):
         """
@@ -486,27 +509,18 @@ def remove_model(self, model_id: int):
 
         assert num_arcs > 0 and num_states > 0, "Unexpected zero-size model"
 
-        def _clear_buffer_or_param_range(
-            buffer: nn.Buffer | nn.Parameter, start: int, end: int, buffer_size: int | None = None
-        ):
-            if buffer_size is None:
-                buffer_size = buffer.shape[0]
-            remove_len = end - start
-            buffer[start : buffer_size - remove_len].copy_(buffer[end:buffer_size].clone())
-            buffer[buffer_size - remove_len : buffer_size].fill_(0)
-
-        # clean up arcs-related data
-        _clear_buffer_or_param_range(self.all_arcs_weights, start_arc, end_arc, self.num_arcs_extended_total)
-        _clear_buffer_or_param_range(self.all_from_states, start_arc, end_arc, self.num_arcs_extended_total)
-        _clear_buffer_or_param_range(self.all_to_states, start_arc, end_arc, self.num_arcs_extended_total)
-        _clear_buffer_or_param_range(self.all_ilabels, start_arc, end_arc, self.num_arcs_extended_total)
-
-        # clean up states-related data
-        _clear_buffer_or_param_range(self.all_start_end_arcs, start_state, end_state, self.num_states_total)
-        _clear_buffer_or_param_range(self.all_state_order, start_state, end_state, self.num_states_total)
-        _clear_buffer_or_param_range(self.all_backoff_to_states, start_state, end_state, self.num_states_total)
-        _clear_buffer_or_param_range(self.all_backoff_weights, start_state, end_state, self.num_states_total)
-        _clear_buffer_or_param_range(self.all_final_weights, start_state, end_state, self.num_states_total)
+        # clean up arcs-related data: cut [start_arc, end_arc) from the buffer (shifting the right part left)
+        self._clear_buffer_or_param_range(self.all_arcs_weights, start_arc, end_arc, self.num_arcs_extended_total)
+        self._clear_buffer_or_param_range(self.all_from_states, start_arc, end_arc, self.num_arcs_extended_total)
+        self._clear_buffer_or_param_range(self.all_to_states, start_arc, end_arc, self.num_arcs_extended_total)
+        self._clear_buffer_or_param_range(self.all_ilabels, start_arc, end_arc, self.num_arcs_extended_total)
+
+        # clean up states-related data: cut [start_state, end_state) from the buffer (shifting the right part left)
+        self._clear_buffer_or_param_range(self.all_start_end_arcs, start_state, end_state, self.num_states_total)
+        self._clear_buffer_or_param_range(self.all_state_order, start_state, end_state, self.num_states_total)
+        self._clear_buffer_or_param_range(self.all_backoff_to_states, start_state, end_state, self.num_states_total)
+        self._clear_buffer_or_param_range(self.all_backoff_weights, start_state, end_state, self.num_states_total)
+        self._clear_buffer_or_param_range(self.all_final_weights, start_state, end_state, self.num_states_total)
 
         # set num states/arcs to zero
         self.num_states_total -= num_states
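The hunk below then rebases the offsets of the surviving models. A toy trace (the torch.where branch cut off at the end of this diff is assumed to subtract the removed model's size):

    import torch

    model2states_offset = torch.tensor([0, 10, 25, 40])
    removed_id, start_state, num_states = 1, 10, 15   # removed model held states [10, 25)
    model2states_offset[removed_id] = 0
    model2states_offset = torch.where(
        model2states_offset < start_state,            # offsets before the gap stay put
        model2states_offset,
        model2states_offset - num_states,             # offsets after the gap shift left
    )
    print(model2states_offset)  # tensor([ 0,  0, 10, 25])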
@@ -518,6 +532,7 @@ def _clear_buffer_or_param_range(
         # shift model offsets
         self.model2states_offset[model_id] = 0
         self.model2arcs_offset[model_id] = 0
+        # shift states and arcs offsets
         torch.where(
             self.model2states_offset < start_state,
             self.model2states_offset,