 import logging
 import os
 from datetime import timedelta
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -41,6 +41,7 @@
     BitsAndBytesConfig,
     PretrainedConfig,
 )
+from transformers.generation.configuration_utils import GenerationConfig
 from transformers.generation.utils import GenerateOutput
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

@@ -108,6 +109,8 @@ class TransformersModelConfig(ModelConfig):
             True forces adding space, False removes leading space if present.
         pairwise_tokenization (bool):
             Whether to tokenize context and continuation separately or together. Defaults to False.
+        continuous_batching (bool):
+            Whether to use continuous batching for generation. Defaults to False.

     Example:
         ```python
@@ -143,6 +146,7 @@ class TransformersModelConfig(ModelConfig):
     compile: bool = False
     multichoice_continuations_start_space: bool | None = None
     pairwise_tokenization: bool = False
+    continuous_batching: bool = False

     def model_post_init(self, __context):
         if self.multichoice_continuations_start_space is True:
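The new `continuous_batching` flag is a plain config field, so switching generation paths is purely a configuration choice. A minimal usage sketch (the import path and the `model_name` value are assumptions for illustration; `continuous_batching` itself comes from this diff):

```python
# Hypothetical example: build a config that routes generation through
# the continuous-batching path added in this PR.
from lighteval.models.transformers.transformers_model import TransformersModelConfig  # assumed path

config = TransformersModelConfig(
    model_name="gpt2",         # any causal LM checkpoint; illustrative value
    continuous_batching=True,  # defaults to False, i.e. the padded generation path
)
```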
@@ -185,7 +189,9 @@ def __init__(
         self._add_special_tokens = config.add_special_tokens or False
         self.pairwise_tokenization = config.pairwise_tokenization
         self.batch_size = config.batch_size
+        self.continuous_batching = config.continuous_batching
         self.transformers_config = config.get_transformers_config()
+        self.generation_config_dict = config.generation_parameters.to_transformers_dict()

         self.model_sha = config.get_model_sha()
         self._max_length = self._init_max_length()
@@ -206,8 +212,6 @@ def __init__(

         self.model_name = _simplify_name(config.model_name)

-        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
-
         if is_accelerate_available():
             model_size, _ = calculate_maximum_sizes(self.model)
             model_size = convert_bytes(model_size)
@@ -252,14 +256,15 @@ def from_model(

         # Instantiate the object without using __init__
         self = cls.__new__(cls)
-        self.config = config
         self.transformers_config = model.config
-        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
+        self.config = config if config is not None else TransformersModelConfig(model_name=model.config.name_or_path)
+        if config is not None:
+            self.generation_config_dict = config.generation_parameters.to_transformers_dict()
         self._max_length = self._init_max_length()
         self._tokenizer = self._create_auto_tokenizer()
-        self.batch_size = config.batch_size
+        self.batch_size = getattr(config, "batch_size", None)
         self.model_name = _simplify_name(model.name_or_path)
-        self.model_sha = config.get_model_sha()
+        self.model_sha = self.config.get_model_sha()

         # If model_parallel is not set we compare the number of processes with the number of GPUs
         self.model = model
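The reworked `from_model` no longer assumes a config was passed in: a default `TransformersModelConfig` is derived from the wrapped model's `name_or_path`, and `batch_size` degrades to `None` instead of raising `AttributeError`. A small sketch of that fallback in isolation, mirroring the lines above (`AutoModelForCausalLM`, the checkpoint name, and the import path are illustrative):

```python
# Sketch of the None-config fallback now applied inside from_model.
from transformers import AutoModelForCausalLM

from lighteval.models.transformers.transformers_model import TransformersModelConfig  # assumed path

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
config = None  # caller did not provide a TransformersModelConfig

effective_config = config if config is not None else TransformersModelConfig(model_name=model.config.name_or_path)
batch_size = getattr(config, "batch_size", None)  # None rather than AttributeError
model_sha = effective_config.get_model_sha()      # always resolved through the effective config
```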
@@ -398,6 +403,11 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
         # model.to(self.device)
         model.eval()
         torch.set_grad_enabled(False)
+        if self.continuous_batching:
+            generation_config = GenerationConfig(
+                **self.generation_config_dict,
+            )
+            model.generation_config = generation_config

         if self.config.compile:
             try:
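With continuous batching enabled, `_create_auto_model` now bakes the generation parameters into a `GenerationConfig` attached to the model, since the batched generation path reads `model.generation_config` rather than per-call kwargs. A sketch of what that construction amounts to (the dict values are illustrative; the real dict comes from `config.generation_parameters.to_transformers_dict()`):

```python
from transformers import GenerationConfig

# Illustrative stand-in for config.generation_parameters.to_transformers_dict()
generation_config_dict = {"do_sample": False, "max_new_tokens": 256}

generation_config = GenerationConfig(**generation_config_dict)
# model.generation_config = generation_config  # as done above when continuous_batching is set
```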
@@ -500,7 +510,110 @@ def forward_batch(batch_size):
         logger.info(f"Determined largest batch size: {batch_size}")
         return batch_size

-    def greedy_until(
+    def _continuous_greedy_until(
+        self,
+        docs: list[Doc],
+    ) -> list[ModelResponse]:
+        """
+        Generates responses using a greedy decoding strategy until certain ending conditions are met.
+
+        Args:
+            docs (list[Doc]): list of documents containing the context and ending conditions.
+
+        Returns:
+            list[ModelResponse]: list of generated responses.
+        """
+        dataset = GenerativeTaskDataset(requests=docs, num_dataset_splits=self.DATASET_SPLITS)
+        results = []
+
+        for split in tqdm(
+            dataset.splits_iterator(),
+            total=dataset.num_dataset_splits,
+            desc="Splits",
+            position=0,
+            disable=False,  # self.disable_tqdm,
+        ):
+            # For chat models, generation stops with EOS token, so we don't need to specify stop tokens
+            if self.use_chat_template:
+                stop_tokens = []
+            else:
+                # NOTE: we are assuming all items in a batch behave similarly (same
+                # stop_tokens and max_tokens generated), which is not necessarily
+                # the case! Because of that we only use a batch size of 1
+                stop_tokens = split[0].stop_sequence
+
+            max_new_tokens = self.config.generation_parameters.max_new_tokens or split[0].generation_size
+            returns_logits = split[0].use_logits
+            num_samples = split[0].num_samples
+            contexts = [self.prompt_manager.prepare_prompt(doc) for doc in split]
+            tokenized = self.tokenizer(contexts, add_special_tokens=self.add_special_tokens)
+
+            # The main question for this step is the following:
+            # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
+            # of losing some meaning, or have some generations that are exceedingly short?
+            # The choice we go for here is to avoid truncating the prompt if we can, since it
+            # should have been managed by the prompt creator/few shot manager if requested by the user.
+            inputs = tokenized["input_ids"]
+            context_size = len(inputs[0])
+
+            # left truncate the inputs to the maximum length
+            if max_new_tokens is not None:
+                if context_size + max_new_tokens > self.max_length:
+                    logger.warning(
+                        f"{context_size + max_new_tokens=} which is greater than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens."
+                    )
+                    context_size = self.max_length - max_new_tokens
+                    if context_size < 0:
+                        logger.critical(
+                            f"{context_size=} is less than 0, either reduce the max_new_tokens or increase model max length."
+                        )
+                        raise ValueError("Context size is less than 0.")
+                    inputs = [input[-context_size:] for input in inputs]
+            else:
+                if context_size > self.max_length:
+                    logger.warning(
+                        f"{context_size=} which is greater than {self.max_length=}. Truncating context to {self.max_length} tokens."
+                    )
+                    context_size = self.max_length
+                    inputs = [input[-context_size:] for input in inputs]
+
+            _outputs = self._generate(
+                inputs=inputs,
+                max_new_tokens=max_new_tokens,
+                stop_tokens=stop_tokens,
+                returns_logits=returns_logits,
+                num_samples=num_samples,
+                continuous_batching=True,
+            )
+
+            for req_id, _output in _outputs.items():
+                output_token_ids = []
+                logprobs_raw = []
+                result = []
+
+                # for output in _output.outputs:
+                output_token_ids.append(_output.generated_tokens)
+                # logprobs_raw.append(output.logprobs)
+                result.append(self.tokenizer.decode(_output.generated_tokens))
+
+                # Logprob extraction is disabled for now (note the `and False`);
+                # logprobs_raw stays empty because output.logprobs is not collected above.
+                if logprobs_raw and output_token_ids and False:
+                    logprobs = [logprobs_raw[0][token_id].logprob for token_id in output_token_ids[0]]
+                else:
+                    logprobs = []
+
+                input_token_ids = _output.prompt_ids
+                cur_response = ModelResponse(
+                    text=result,
+                    logprobs=logprobs,
+                    output_tokens=output_token_ids,
+                    input_tokens=input_token_ids,
+                )
+                results.append(cur_response)
+
+        return dataset.get_original_order(results)
+
+    def _padded_greedy_until(
         self,
         docs: list[Doc],
     ) -> list[ModelResponse]:
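The truncation block in `_continuous_greedy_until` keeps each prompt within `max_length - max_new_tokens` tokens when a generation budget is known, and within `max_length` otherwise, always dropping tokens from the left. A standalone sketch of that budget arithmetic with made-up numbers (same slicing as in the method):

```python
# Hypothetical numbers: a 4096-token context window and a 512-token generation budget.
max_length = 4096
max_new_tokens = 512
inputs = [list(range(4000))]  # one prompt of 4000 token ids

context_size = len(inputs[0])
if context_size + max_new_tokens > max_length:
    context_size = max_length - max_new_tokens        # 3584
    inputs = [ids[-context_size:] for ids in inputs]  # keep only the rightmost tokens

assert len(inputs[0]) == 3584
```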
@@ -613,12 +726,43 @@ def greedy_until(
                 stop_tokens=stop_tokens,
                 returns_logits=False,
                 num_samples=num_samples,
+                continuous_batching=False,
             )
             results.extend(cur_reponses)

         return dataset.get_original_order(results)

-    def _generate(
+    def greedy_until(
+        self,
+        docs: list[Doc],
+    ) -> list[ModelResponse]:
+        if self.continuous_batching:
+            return self._continuous_greedy_until(docs)
+        else:
+            return self._padded_greedy_until(docs)
+
+    def _generate_continuous(
+        self,
+        inputs: list[list[int]],
+        max_new_tokens: Optional[int] = None,
+        stop_tokens: Optional[list[str]] = None,
+        returns_logits: Optional[bool] = False,
+        num_samples: int = 1,
+        generate: bool = True,
+    ) -> Dict[str, ModelResponse]:
+        # Compute model generation
+        self.model.generation_config.use_cuda_graph = False  # Disable CUDA graph for batch generation
+        self.model.generation_config.max_batch_tokens = 256  # Cap the number of tokens processed per batch
+        # self.model.generation_config.do_sample = False  # (would force greedy decoding)
+        batch_outputs = self.model.generate_batch(
+            inputs=inputs,
+            generation_config=self.model.generation_config,
+            # You can pass request-specific overrides here, e.g., max_new_tokens=100
+        )
+
+        return batch_outputs
+
+    def _generate_padded(
         self,
         batch: Batch,
         max_new_tokens: int,
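`_generate_continuous` hands the whole split to `generate_batch` and returns its dict of per-request outputs; `_continuous_greedy_until` then reads `generated_tokens` and `prompt_ids` off each entry. A hedged, self-contained sketch of that post-processing (the `_FakeOutput` stand-in, the request id, and the token ids are invented for illustration; the attribute names mirror the diff):

```python
from dataclasses import dataclass

from transformers import AutoTokenizer


@dataclass
class _FakeOutput:  # stand-in for one generate_batch result entry
    generated_tokens: list
    prompt_ids: list


tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer
batch_outputs = {"req-0": _FakeOutput(generated_tokens=[318, 257, 1332], prompt_ids=[1212])}

for req_id, output in batch_outputs.items():
    text = tokenizer.decode(output.generated_tokens)
    # In the diff these fields feed ModelResponse(text=[...], output_tokens=[...], input_tokens=...)
    print(req_id, text, output.generated_tokens, output.prompt_ids)
```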
@@ -704,6 +848,16 @@ def _generate(

         return all_responses

+    def _generate(
+        self,
+        continuous_batching: bool,
+        **kwargs,
+    ) -> list[ModelResponse]:
+        if continuous_batching:
+            return self._generate_continuous(**kwargs)
+        else:
+            return self._generate_padded(**kwargs)
+
     def loglikelihood(
         self,
         docs: list[Doc],