import logging
import os
from datetime import timedelta
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
    BitsAndBytesConfig,
    PretrainedConfig,
)
+from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.utils import GenerateOutput
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

@@ -108,6 +109,8 @@ class TransformersModelConfig(ModelConfig):
            True forces adding space, False removes leading space if present.
        pairwise_tokenization (bool):
            Whether to tokenize context and continuation separately or together. Defaults to False.
+        continuous_batching (bool):
+            Whether to use continuous batching for generation. Defaults to False.

    Example:
    ```python
@@ -143,6 +146,7 @@ class TransformersModelConfig(ModelConfig):
    compile: bool = False
    multichoice_continuations_start_space: bool | None = None
    pairwise_tokenization: bool = False
+    continuous_batching: bool = False

    def model_post_init(self, __context):
        if self.multichoice_continuations_start_space is True:
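For reference, a minimal sketch of how the new flag could be set on the config; the import path, model name, and batch size are illustrative assumptions, not taken from this diff:

```python
# Illustrative sketch only: enabling the new continuous_batching flag.
# Import path and all values below are assumptions for demonstration purposes.
from lighteval.models.transformers.transformers_model import TransformersModelConfig

config = TransformersModelConfig(
    model_name="meta-llama/Llama-3.2-1B-Instruct",  # placeholder checkpoint
    batch_size=8,                                   # placeholder batch size
    continuous_batching=True,                       # route generation through the continuous-batching path
)
```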
@@ -185,7 +189,9 @@ def __init__(
        self._add_special_tokens = config.add_special_tokens or False
        self.pairwise_tokenization = config.pairwise_tokenization
        self.batch_size = config.batch_size
+        self.continuous_batching = config.continuous_batching
        self.transformers_config = config.get_transformers_config()
+        self.generation_config_dict = config.generation_parameters.to_transformers_dict()

        self.model_sha = config.get_model_sha()
        self._max_length = self._init_max_length()
@@ -206,8 +212,6 @@ def __init__(

        self.model_name = _simplify_name(config.model_name)

-        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
-
        if is_accelerate_available():
            model_size, _ = calculate_maximum_sizes(self.model)
            model_size = convert_bytes(model_size)
@@ -252,14 +256,15 @@ def from_model(

        # Instantiate the object without using __init__
        self = cls.__new__(cls)
-        self.config = config
        self.transformers_config = model.config
-        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
+        self.config = config if config is not None else TransformersModelConfig(model_name=model.config.name_or_path)
+        if config is not None:
+            self.generation_config_dict = config.generation_parameters.to_transformers_dict()
        self._max_length = self._init_max_length()
        self._tokenizer = self._create_auto_tokenizer()
-        self.batch_size = config.batch_size
+        self.batch_size = getattr(config, "batch_size", None)
        self.model_name = _simplify_name(model.name_or_path)
-        self.model_sha = config.get_model_sha()
+        self.model_sha = self.config.get_model_sha()

        # If model_parallel is not set we compare the number of processes with the number of GPUs
        self.model = model
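The `from_model` change means a pre-loaded model can now be wrapped without an explicit config, since a fallback `TransformersModelConfig` is built from `model.config.name_or_path`. A hedged sketch, assuming the import path below and that the remaining `from_model` arguments keep their defaults:

```python
# Sketch under assumptions: the checkpoint name is a placeholder and the
# lighteval import path is assumed, not taken from this diff.
from transformers import AutoModelForCausalLM

from lighteval.models.transformers.transformers_model import TransformersModel  # assumed path

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
# config=None now falls back to TransformersModelConfig(model_name=model.config.name_or_path)
wrapped = TransformersModel.from_model(model)
```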
@@ -398,6 +403,11 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
        # model.to(self.device)
        model.eval()
        torch.set_grad_enabled(False)
+        if self.continuous_batching:
+            generation_config = GenerationConfig(
+                **self.generation_config_dict,
+            )
+            model.generation_config = generation_config

        if self.config.compile:
            try:
@@ -500,7 +510,110 @@ def forward_batch(batch_size):
        logger.info(f"Determined largest batch size: {batch_size}")
        return batch_size

-    def greedy_until(
+    def _continuous_greedy_until(
+        self,
+        docs: list[Doc],
+    ) -> list[ModelResponse]:
+        """
+        Generates responses using a greedy decoding strategy until certain ending conditions are met.
+
+        Args:
+            docs (list[Doc]): list of documents containing the context and ending conditions.
+
+        Returns:
+            list[ModelResponse]: list of generated responses.
+        """
+        dataset = GenerativeTaskDataset(requests=docs, num_dataset_splits=self.DATASET_SPLITS)
+        results = []
+
+        for split in tqdm(
+            dataset.splits_iterator(),
+            total=dataset.num_dataset_splits,
+            desc="Splits",
+            position=0,
+            disable=False,  # self.disable_tqdm,
+        ):
+            # For chat models, generation stops with EOS token, so we don't need to specify stop tokens
+            if self.use_chat_template:
+                stop_tokens = []
+            else:
+                # NOTE: we are assuming all items in a batch behave similarly (same
+                # stop_tokens and max_tokens generated), which is not necessarily
+                # the case! Because of that we only use a batch size of 1
+                stop_tokens = split[0].stop_sequence
+
+            max_new_tokens = self.config.generation_parameters.max_new_tokens or split[0].generation_size
+            returns_logits = split[0].use_logits
+            num_samples = split[0].num_samples
+            contexts = [self.prompt_manager.prepare_prompt(doc) for doc in split]
+            tokenized = self.tokenizer(contexts, add_special_tokens=self.add_special_tokens)
+
+            # The main question for this step is the following:
+            # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
+            # of losing some meaning, or have some generations that are exceedingly short?
+            # The choice we go for here is to avoid truncating the prompt if we can, since it
+            # should have been managed by the prompt creator/few shot manager if requested by the user.
+            inputs = tokenized["input_ids"]
+            context_size = len(inputs[0])
+
+            # left truncate the inputs to the maximum length
+            if max_new_tokens is not None:
+                if context_size + max_new_tokens > self.max_length:
+                    logger.warning(
+                        f"{context_size + max_new_tokens=} which is greater than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens."
+                    )
+                    context_size = self.max_length - max_new_tokens
+                    if context_size < 0:
+                        logger.critical(
+                            f"{context_size=} is less than 0, either reduce the max_new_tokens or increase model max length."
+                        )
+                        raise ValueError("Context size is less than 0.")
+                    inputs = [input[-context_size:] for input in inputs]
+            else:
+                if context_size > self.max_length:
+                    logger.warning(
+                        f"{context_size=} which is greater than {self.max_length=}. Truncating context to {self.max_length} tokens."
+                    )
+                    context_size = self.max_length
+                    inputs = [input[-context_size:] for input in inputs]
+
+            _outputs = self._generate(
+                inputs=inputs,
+                max_new_tokens=max_new_tokens,
+                stop_tokens=stop_tokens,
+                returns_logits=returns_logits,
+                num_samples=num_samples,
+                continuous_batching=True,
+            )
+
+            for req_id, _output in _outputs.items():
+                output_token_ids = []
+                logprobs_raw = []
+                result = []
+
+                # for output in _output.outputs:
+                output_token_ids.append(_output.generated_tokens)
+                # logprobs_raw.append(output.logprobs)
+                result.append(self.tokenizer.decode(_output.generated_tokens))
+
+                if logprobs_raw and output_token_ids and False:  # logprobs are not collected yet, so this branch stays disabled
+                    logprobs = [logprobs_raw[0][token_id].logprob for token_id in output_token_ids[0]]
+                else:
+                    logprobs = []
+
+                input_token_ids = _output.prompt_ids
+                cur_response = ModelResponse(
+                    text=result,
+                    logprobs=logprobs,
+                    output_tokens=output_token_ids,
+                    input_tokens=input_token_ids,
+                )
+                results.append(cur_response)
+
+        return dataset.get_original_order(results)
+
+    def _padded_greedy_until(
        self,
        docs: list[Doc],
    ) -> list[ModelResponse]:
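For intuition, a minimal worked example of the left-truncation arithmetic used above, with made-up sizes:

```python
# Made-up numbers: keep the prompt tail so that prompt + generation fits max_length.
max_length = 4096                # assumed model maximum sequence length
max_new_tokens = 512             # assumed generation budget
prompt_ids = list(range(4000))   # stand-in for 4000 prompt token ids

context_size = len(prompt_ids)
if context_size + max_new_tokens > max_length:
    context_size = max_length - max_new_tokens  # 3584
    prompt_ids = prompt_ids[-context_size:]     # left-truncate: keep the last 3584 tokens

assert len(prompt_ids) + max_new_tokens == max_length
```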
@@ -613,12 +726,43 @@ def greedy_until(
                stop_tokens=stop_tokens,
                returns_logits=False,
                num_samples=num_samples,
+                continuous_batching=False,
            )
            results.extend(cur_reponses)

        return dataset.get_original_order(results)

-    def _generate(
+    def greedy_until(
+        self,
+        docs: list[Doc],
+    ) -> list[ModelResponse]:
+        if self.continuous_batching:
+            return self._continuous_greedy_until(docs)
+        else:
+            return self._padded_greedy_until(docs)
+
+    def _generate_continuous(
+        self,
+        inputs: list[list[int]],
+        max_new_tokens: Optional[int] = None,
+        stop_tokens: Optional[list[str]] = None,
+        returns_logits: Optional[bool] = False,
+        num_samples: int = 1,
+        generate: bool = True,
+    ) -> Dict[str, ModelResponse]:
+        # Compute model generation
+        self.model.generation_config.use_cuda_graph = False  # Disable CUDA graphs for batch generation
+        self.model.generation_config.max_batch_tokens = 256  # Cap the number of tokens scheduled per batch
+        # self.model.generation_config.do_sample = False  # (left disabled) would force greedy decoding
+        batch_outputs = self.model.generate_batch(
+            inputs=inputs,
+            generation_config=self.model.generation_config,
+            # You can pass request-specific overrides here, e.g., max_new_tokens=100
+        )
+
+        return batch_outputs
+
+    def _generate_padded(
        self,
        batch: Batch,
        max_new_tokens: int,
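The continuous path leans on the `generate_batch` call shown above, which returns outputs keyed by request id. A standalone sketch of that call pattern, assuming a recent transformers release that ships the experimental continuous-batching API, with placeholder inputs:

```python
# Placeholder token ids; model and tokenizer are assumed to be already loaded.
# generate_batch is experimental and only available in recent transformers versions.
inputs = [
    [101, 2023, 2003, 1037, 3231],  # pre-tokenized prompt 1
    [101, 7592, 2088],              # pre-tokenized prompt 2
]

batch_outputs = model.generate_batch(
    inputs=inputs,
    generation_config=model.generation_config,
)

# Each entry exposes the generated token ids (and prompt ids), as consumed above.
for req_id, output in batch_outputs.items():
    print(req_id, tokenizer.decode(output.generated_tokens))
```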
@@ -704,6 +848,16 @@ def _generate(

        return all_responses

+    def _generate(
+        self,
+        continuous_batching: bool,
+        **kwargs,
+    ) -> list[ModelResponse]:
+        if continuous_batching:
+            return self._generate_continuous(**kwargs)
+        else:
+            return self._generate_padded(**kwargs)
+
    def loglikelihood(
        self,
        docs: list[Doc],