
Commit 99bfd9f

Authored by NathanHB, ArthurZucker, and clefourrier
Adds continuous batching (#850)
Add necessary changes to call generate with CB.

Linked PR: huggingface/transformers#38085

This works:

```python
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.pipeline import Pipeline, PipelineParameters, ParallelismManager
from lighteval.models.endpoints.inference_providers_model import (
    InferenceProvidersModelConfig,
)
from lighteval.models.transformers.transformers_model import TransformersModel
import torch
from transformers import AutoModelForCausalLM, GenerationConfig

MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
PROVIDER = "hf-inference"
BENCHMARKS = "lighteval|gsm8k|0|0"

evaluation_tracker = EvaluationTracker(output_dir="./results")
pipeline_params = PipelineParameters(
    use_chat_template=True, launcher_type=ParallelismManager.NONE, max_samples=None
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3b-Instruct",
    attn_implementation="sdpa_paged",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Configure generation parameters
generation_config = GenerationConfig(
    max_new_tokens=10,
    eos_token_id=model.config.eos_token_id,
    pad_token_id=model.config.pad_token_id,
    num_blocks=2048,
    block_size=256,
)
model.generation_config = generation_config

model = TransformersModel.from_model(model)

pipeline = Pipeline(
    model=model,
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    tasks=BENCHMARKS,
)

pipeline.evaluate()
results = pipeline.get_results()["results"]
print(results)
```

---------

Co-authored-by: Arthur Zucker <[email protected]>
Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent 1404ba1 commit 99bfd9f

File tree: 5 files changed, +179 −10 lines


examples/model_configs/transformers_model.yaml

Lines changed: 7 additions & 1 deletion

@@ -5,7 +5,13 @@ model_parameters:
   compile: false
   model_parallel: false
   batch_size: 1
-  multichoice_continuations_start_space: null # If true/false, will force multiple choice continuations to start/not start with a space. If none, will do nothing
+  continuous_batching: false
+  model_loading_kwargs:
+    attn_implementation: "eager"
+    #tp_plan: "auto"
   generation_parameters:
+    #num_blocks: 4096
+    #block_size: 64
+    #max_new_tokens: 256
     temperature: 0.0
     top_p: 0.9
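For readers who prefer Python over YAML, the same settings can presumably be expressed through the config classes touched by this commit. The sketch below is a rough, unverified illustration: the field names (`continuous_batching`, `model_loading_kwargs`, `num_blocks`, `block_size`) come straight from the diffs, but the exact constructor behaviour and any other keyword names should be checked against your lighteval version.

```python
# Hedged sketch: Python equivalent of the YAML example above, with continuous
# batching switched on. Field names mirror this commit's diffs; everything else
# is an assumption about the lighteval API rather than a documented recipe.
from lighteval.models.model_input import GenerationParameters
from lighteval.models.transformers.transformers_model import TransformersModelConfig

config = TransformersModelConfig(
    model_name="HuggingFaceTB/SmolLM2-1.7B-Instruct",  # arbitrary example model
    batch_size=1,
    continuous_batching=True,  # flag added in this commit
    model_loading_kwargs={"attn_implementation": "eager"},  # mirrors the YAML key above
    generation_parameters=GenerationParameters(
        num_blocks=4096,   # paged-attention knob added in this commit
        block_size=64,     # paged-attention knob added in this commit
        max_new_tokens=256,
        temperature=0.0,
        top_p=0.9,
    ),
)
```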

src/lighteval/models/model_input.py

Lines changed: 5 additions & 0 deletions

@@ -25,6 +25,9 @@


 class GenerationParameters(BaseModel, extra="forbid"):
+    num_blocks: NonNegativeInt | None = None  # transformers
+    block_size: NonNegativeInt | None = None  # transformers
+
     early_stopping: bool | None = None  # transformers
     repetition_penalty: NonNegativeFloat | None = None  # vllm, transformers, tgi, sglang
     frequency_penalty: NonNegativeFloat | None = None  # vllm, tgi, sglang
@@ -186,6 +189,8 @@ def to_transformers_dict(self) -> dict:
             "repetition_penalty": self.repetition_penalty,
             "length_penalty": self.length_penalty,
             "output_scores": True,
+            "num_blocks": self.num_blocks,
+            "block_size": self.block_size,
             "return_dict_in_generate": True,
         }
         return {k: v for k, v in args.items() if v is not None}
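To see how the two new fields flow through, here is a small hedged sketch from the caller's side. The exact set of returned keys depends on which other `GenerationParameters` fields are set, but the final dict comprehension shown in the diff drops every `None` value.

```python
# Hedged usage sketch of the new GenerationParameters fields.
from lighteval.models.model_input import GenerationParameters

params = GenerationParameters(num_blocks=2048, block_size=256)
gen_kwargs = params.to_transformers_dict()

# The two new keys survive because they are not None; fields left at None
# (early_stopping, repetition_penalty, ...) are stripped by the comprehension,
# while the always-set flags output_scores=True and return_dict_in_generate=True remain.
print(gen_kwargs)
```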

src/lighteval/models/transformers/transformers_model.py

Lines changed: 163 additions & 9 deletions

@@ -23,7 +23,7 @@
 import logging
 import os
 from datetime import timedelta
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -41,6 +41,7 @@
     BitsAndBytesConfig,
     PretrainedConfig,
 )
+from transformers.generation.configuration_utils import GenerationConfig
 from transformers.generation.utils import GenerateOutput
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

@@ -108,6 +109,8 @@ class TransformersModelConfig(ModelConfig):
             True forces adding space, False removes leading space if present.
         pairwise_tokenization (bool):
             Whether to tokenize context and continuation separately or together. Defaults to False.
+        continuous_batching (bool):
+            Whether to use continuous batching for generation. Defaults to False.

     Example:
     ```python
@@ -143,6 +146,7 @@ class TransformersModelConfig(ModelConfig):
     compile: bool = False
     multichoice_continuations_start_space: bool | None = None
     pairwise_tokenization: bool = False
+    continuous_batching: bool = False

     def model_post_init(self, __context):
         if self.multichoice_continuations_start_space is True:
@@ -185,7 +189,9 @@ def __init__(
         self._add_special_tokens = config.add_special_tokens or False
         self.pairwise_tokenization = config.pairwise_tokenization
         self.batch_size = config.batch_size
+        self.continuous_batching = config.continuous_batching
         self.transformers_config = config.get_transformers_config()
+        self.generation_config_dict = config.generation_parameters.to_transformers_dict()

         self.model_sha = config.get_model_sha()
         self._max_length = self._init_max_length()
@@ -206,8 +212,6 @@ def __init__(

         self.model_name = _simplify_name(config.model_name)

-        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
-
         if is_accelerate_available():
             model_size, _ = calculate_maximum_sizes(self.model)
             model_size = convert_bytes(model_size)
@@ -252,14 +256,15 @@ def from_model(

         # Instanciate the object without using __init__
         self = cls.__new__(cls)
-        self.config = config
         self.transformers_config = model.config
-        self.generation_config_dict = config.generation_parameters.to_transformers_dict()
+        self.config = config if config is not None else TransformersModelConfig(model_name=model.config.name_or_path)
+        if config is not None:
+            self.generation_config_dict = config.generation_parameters.to_transformers_dict()
         self._max_length = self._init_max_length()
         self._tokenizer = self._create_auto_tokenizer()
-        self.batch_size = config.batch_size
+        self.batch_size = getattr(config, "batch_size", None)
         self.model_name = _simplify_name(model.name_or_path)
-        self.model_sha = config.get_model_sha()
+        self.model_sha = self.config.get_model_sha()

         # If model_parallel is not set we compare the number of processes with the number of GPUs
         self.model = model
@@ -398,6 +403,11 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
         # model.to(self.device)
         model.eval()
         torch.set_grad_enabled(False)
+        if self.continuous_batching:
+            generation_config = GenerationConfig(
+                **self.generation_config_dict,
+            )
+            model.generation_config = generation_config

         if self.config.compile:
             try:
@@ -500,7 +510,110 @@ def forward_batch(batch_size):
             logger.info(f"Determined largest batch size: {batch_size}")
             return batch_size

-    def greedy_until(
+    def _continuous_greedy_until(
+        self,
+        docs: list[Doc],
+    ) -> list[ModelResponse]:
+        """
+        Generates responses using a greedy decoding strategy until certain ending conditions are met.
+
+        Args:
+            requests (list[Request]): list of requests containing the context and ending conditions.
+            override_bs (int, optional): Override the batch size for generation. Defaults to None.
+
+        Returns:
+            list[GenerateReturn]: list of generated responses.
+        """
+        dataset = GenerativeTaskDataset(requests=docs, num_dataset_splits=self.DATASET_SPLITS)
+        results = []
+
+        for split in tqdm(
+            dataset.splits_iterator(),
+            total=dataset.num_dataset_splits,
+            desc="Splits",
+            position=0,
+            disable=False,  # self.disable_tqdm,
+        ):
+            # For chat models, generation stops with EOS token, so we don't need to specify stop tokens
+            if self.use_chat_template:
+                stop_tokens = []
+            else:
+                # NOTE: we are assuming all items in a batch behave similarly (same
+                # stop_tokens and max_tokens genrated) which is not necessarily
+                # the case! Because of that we only use batch size of 1
+                stop_tokens = split[0].stop_sequence

+            max_new_tokens = self.config.generation_parameters.max_new_tokens or split[0].generation_size
+            returns_logits = split[0].use_logits
+            num_samples = split[0].num_samples
+            contexts = [self.prompt_manager.prepare_prompt(doc) for doc in split]
+            tokenized = self.tokenizer(contexts, add_special_tokens=self.add_special_tokens)
+
+            # The main question for this step is the following:
+            # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
+            # of losing some meaning, or have some generations that are exceedingly short?
+            # The choice we go for here is to avoid truncating the prompt if we can, since it
+            # should have been managed by the prompt creator/few shot manager if requested by the user.
+            inputs = tokenized["input_ids"]
+            context_size = len(inputs[0])
+
+            # left truncate the inputs to the maximum length
+            if max_new_tokens is not None:
+                if context_size + max_new_tokens > self.max_length:
+                    logger.warning(
+                        f"{context_size + max_new_tokens=} which is greater than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens."
+                    )
+                    context_size = self.max_length - max_new_tokens
+                    if context_size < 0:
+                        logger.critical(
+                            f"{context_size=} is less than 0, either reduce the max_new_tokens or increase model max length."
+                        )
+                        raise ValueError("Context size is less than 0.")
+                    inputs = [input[-context_size:] for input in inputs]
+            else:
+                if context_size > self.max_length:
+                    logger.warning(
+                        f"{context_size=} which is greater than {self.max_length=}. Truncating context to {self.max_length} tokens."
+                    )
+                    context_size = self.max_length
+                    inputs = [input[-context_size:] for input in inputs]
+
+            _outputs = self._generate(
+                inputs=inputs,
+                max_new_tokens=max_new_tokens,
+                stop_tokens=stop_tokens,
+                returns_logits=returns_logits,
+                num_samples=num_samples,
+                continuous_batching=True,
+            )
+
+            for req_id, _output in _outputs.items():
+                output_token_ids = []
+                logprobs_raw = []
+                result = []
+
+                # for output in _output.outputs:
+                output_token_ids.append(_output.generated_tokens)
+                # logprobs_raw.append(output.logprobs)
+                result.append(self.tokenizer.decode(_output.generated_tokens))
+
+                if logprobs_raw and output_token_ids and False:
+                    logprobs = [logprobs_raw[0][token_id].logprob for token_id in output_token_ids[0]]
+                else:
+                    logprobs = []
+
+                input_token_ids = _output.prompt_ids
+                cur_response = ModelResponse(
+                    text=result,
+                    logprobs=logprobs,
+                    output_tokens=output_token_ids,
+                    input_tokens=input_token_ids,
+                )
+                results.append(cur_response)
+
+        return dataset.get_original_order(results)
+
+    def _padded_greedy_until(
         self,
         docs: list[Doc],
     ) -> list[ModelResponse]:
@@ -613,12 +726,43 @@ def greedy_until(
                 stop_tokens=stop_tokens,
                 returns_logits=False,
                 num_samples=num_samples,
+                continuous_batching=False,
             )
             results.extend(cur_reponses)

         return dataset.get_original_order(results)

-    def _generate(
+    def greedy_until(
+        self,
+        docs: list[Doc],
+    ) -> list[ModelResponse]:
+        if self.continuous_batching:
+            return self._continuous_greedy_until(docs)
+        else:
+            return self._padded_greedy_until(docs)
+
+    def _generate_continuous(
+        self,
+        inputs: list[list[int]],
+        max_new_tokens: Optional[int] = None,
+        stop_tokens: Optional[list[str]] = None,
+        returns_logits: Optional[bool] = False,
+        num_samples: int = 1,
+        generate: bool = True,
+    ) -> Dict[str, ModelResponse]:
+        # Compute model generation
+        self.model.generation_config.use_cuda_graph = False  # Disable CUDA graph for batch generation
+        self.model.generation_config.max_batch_tokens = 256  # Disable CUDA graph for batch generation
+        # self.model.generation_config.do_sample = False  # Disable CUDA graph for batch generation
+        batch_outputs = self.model.generate_batch(
+            inputs=inputs,
+            generation_config=self.model.generation_config,
+            # You can pass request-specific overrides here, e.g., max_new_tokens=100
+        )
+
+        return batch_outputs
+
+    def _generate_padded(
         self,
         batch: Batch,
         max_new_tokens: int,
@@ -704,6 +848,16 @@ def _generate(

         return all_responses

+    def _generate(
+        self,
+        continuous_batching: bool,
+        **kwargs,
+    ) -> list[ModelResponse]:
+        if continuous_batching:
+            return self._generate_continuous(**kwargs)
+        else:
+            return self._generate_padded(**kwargs)
+
     def loglikelihood(
         self,
         docs: list[Doc],

tests/models/endpoints/test_endpoint_model.py

Lines changed: 2 additions & 0 deletions

@@ -52,6 +52,8 @@ class TestInferenceEndpointModelConfig:
             "add_special_tokens": True,
             "system_prompt": None,
             "generation_parameters": {
+                "num_blocks": None,
+                "block_size": None,
                 "early_stopping": None,
                 "frequency_penalty": None,
                 "length_penalty": None,

tests/models/endpoints/test_tgi_model.py

Lines changed: 2 additions & 0 deletions

@@ -38,6 +38,8 @@ class TestTGIModelConfig:
             "model_name": None,
             "system_prompt": None,
             "generation_parameters": {
+                "block_size": None,
+                "num_blocks": None,
                 "early_stopping": None,
                 "frequency_penalty": None,
                 "length_penalty": None,
