Debug continuous batching #900

Draft · wants to merge 4 commits into main
41 changes: 32 additions & 9 deletions src/lighteval/models/transformers/transformers_model.py
@@ -92,6 +92,8 @@ class TransformersModelConfig(ModelConfig):
Additional keyword arguments passed to `from_pretrained`. Defaults to empty dict.
add_special_tokens (bool):
Whether to add special tokens during tokenization. Defaults to True.
skip_special_tokens (bool):
Whether the tokenizer should skip special tokens when decoding generated tokens. Keeping them (False) is needed for reasoning models. Defaults to False.
model_parallel (bool | None):
Whether to use model parallelism across multiple GPUs. If None, automatically
determined based on available GPUs and model size.
@@ -139,6 +141,7 @@ class TransformersModelConfig(ModelConfig):
max_length: PositiveInt | None = None
model_loading_kwargs: dict = Field(default_factory=dict)
add_special_tokens: bool = True
skip_special_tokens: bool = False
model_parallel: bool | None = None
dtype: str | None = None
device: Union[int, str] = "cuda"
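As a usage note, a minimal sketch (not part of the diff) of how the new `skip_special_tokens` field would be set alongside the existing config fields; the model name is a placeholder:

```python
from lighteval.models.transformers.transformers_model import TransformersModelConfig

# Sketch only: keep special tokens (e.g. thinking delimiters) in decoded output
# for a reasoning model by leaving skip_special_tokens at its default of False.
config = TransformersModelConfig(
    model_name="my-org/my-reasoning-model",  # placeholder model name
    add_special_tokens=True,
    skip_special_tokens=False,  # flag introduced in this PR
    continuous_batching=True,
)
```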
@@ -187,6 +190,7 @@ def __init__(
self._device = self.accelerator.device
self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
self._add_special_tokens = config.add_special_tokens or False
self.skip_special_tokens = config.skip_special_tokens or False
self.pairwise_tokenization = config.pairwise_tokenization
self.batch_size = config.batch_size
self.continuous_batching = config.continuous_batching
@@ -244,6 +248,7 @@ def from_model(
tokenizer_name: str = None, # custom tokenizer
trust_remote_code: bool = False,
add_special_tokens: bool = True,
skip_special_tokens: bool = False,
pairwise_tokenization: bool = False,
multichoice_continuations_start_space: bool = None,
):
@@ -280,6 +285,7 @@ def from_model(

self.use_chat_template = uses_chat_template(self._tokenizer)
self._add_special_tokens = add_special_tokens if add_special_tokens is not None else False
self.skip_special_tokens = skip_special_tokens if skip_special_tokens is not None else False
self.pairwise_tokenization = pairwise_tokenization
self.multichoice_continuations_start_space = multichoice_continuations_start_space

@@ -395,7 +401,8 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
self.config.model_name,
revision=revision,
max_memory=max_memory,
device_map=device_map,
tp_plan="auto",
# device_map=device_map,
torch_dtype=torch_dtype,
trust_remote_code=self.config.trust_remote_code,
**kwargs,
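The hunk above swaps `device_map` for `tp_plan="auto"`, which lets transformers shard the weights with its tensor-parallel plan. A standalone sketch of the same loading call under that assumption (placeholder model name; tensor parallelism expects a distributed launch such as torchrun):

```python
import torch
from transformers import AutoModelForCausalLM

# Sketch only: tensor-parallel loading instead of device_map placement.
# Run under a distributed launcher, e.g. `torchrun --nproc-per-node=4 script.py`.
model = AutoModelForCausalLM.from_pretrained(
    "my-org/my-model",          # placeholder model name
    tp_plan="auto",             # shard weights across the visible GPUs
    torch_dtype=torch.bfloat16,
)
```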
@@ -595,7 +602,9 @@ def _continuous_greedy_until(
# for output in _output.outputs:
output_token_ids.append(_output.generated_tokens)
# logprobs_raw.append(output.logprobs)
result.append(self.tokenizer.decode(_output.generated_tokens))
result.append(
self.tokenizer.decode(_output.generated_tokens, skip_special_tokens=self.skip_special_tokens)
)

if logprobs_raw and output_token_ids and False:
logprobs = [logprobs_raw[0][token_id].logprob for token_id in output_token_ids[0]]
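The decode call above now honours the flag. A small standalone illustration of what `skip_special_tokens` changes (placeholder tokenizer name):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("my-org/my-model")  # placeholder name

# Assume the model wraps generations in special tokens (EOS, thinking delimiters, ...).
# skip_special_tokens=False keeps them in the text, which is what reasoning models
# need; True strips them from the decoded string.
token_ids = tokenizer("Hello world", add_special_tokens=True)["input_ids"]
kept = tokenizer.decode(token_ids, skip_special_tokens=False)
stripped = tokenizer.decode(token_ids, skip_special_tokens=True)
```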
@@ -646,7 +655,7 @@ def _padded_greedy_until(
tokenized_context = self.tokenizer(context)

# Longest context in the current split is the first item (since we sort reversed)
longest_context_continuation_size_in_split = len(tokenized_context) + split[0].generation_size
longest_context_continuation_size_in_split = len(tokenized_context["input_ids"]) + split[0].generation_size
max_context_continuation_size_allowed = min(
longest_context_continuation_size_in_split, self.max_length
)
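The `len(...)` change above matters because `tokenizer(context)` returns a dict-like `BatchEncoding`, so `len()` on it counts its keys (e.g. `input_ids`, `attention_mask`) rather than tokens. A quick illustration (placeholder tokenizer name):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("my-org/my-model")  # placeholder name

encoded = tokenizer("The quick brown fox")
n_fields = len(encoded)                # number of keys, e.g. input_ids + attention_mask
n_tokens = len(encoded["input_ids"])   # actual context length in tokens
```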
@@ -669,12 +678,12 @@

# For chat models, generation stops at the EOS token, so it is the only stop token we need
if self.use_chat_template:
stop_tokens = []
stop_tokens = [self.tokenizer.eos_token]
else:
# NOTE: we are assuming all items in a batch behave similarly (same
# stop_tokens and max_tokens generated), which is not necessarily
# the case! Because of that we only use a batch size of 1
stop_tokens = batch[0].stop_sequences
stop_tokens = [self.tokenizer.eos_token] + batch[0].stop_sequences

max_new_tokens = batch[0].generation_size
num_samples = batch[0].num_samples
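With the EOS token now always included, the two branches above differ only in whether the request's own stop sequences are appended. A sketch of that selection logic in isolation (the helper name is hypothetical):

```python
# Hypothetical helper mirroring the stop-token selection above.
def build_stop_tokens(tokenizer, use_chat_template: bool, request_stop_sequences: list[str]) -> list[str]:
    if use_chat_template:
        # Chat models terminate on EOS, so it is the only stop token needed.
        return [tokenizer.eos_token]
    # Otherwise also honour the request-specific stop sequences (batch size 1 assumed).
    return [tokenizer.eos_token] + request_stop_sequences
```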
@@ -750,16 +759,27 @@ def _generate_continuous(
num_samples: int = 1,
generate: bool = True,
) -> Dict[str, ModelResponse]:
if num_samples > 1 and self.generation_config_dict["temperature"] == 0:
raise ValueError(
"You cannot generate multiple samples with temperature=0. Please set temperature > 0. Or use a non sampling metric."
)

# Compute model generation
self.model.generation_config.use_cuda_graph = False  # Disable CUDA graph for batch generation
self.model.generation_config.max_batch_tokens = 256
# self.model.generation_config.do_sample = False
self.model.generation_config.max_new_tokens = max_new_tokens
self.model.generation_config.eos_token_id = self.tokenizer.eos_token_id
self.model.generation_config.num_return_sequences = num_samples
self.model.generation_config.output_logits = returns_logits
self.model.generation_config.renormalize_logits = True
self.model.generation_config.num_blocks = 4096  # KV-cache blocks for the continuous-batching scheduler
self.model.generation_config.block_size = 256  # Tokens per KV-cache block
self.model.generation_config.max_batch_tokens = 16  # Token budget per scheduling step
self.model.generation_config.do_sample = False  # Greedy decoding
self.model.generation_config.use_cuda_graph = False  # Disable CUDA graph for batch generation
batch_outputs = self.model.generate_batch(
inputs=inputs,
generation_config=self.model.generation_config,
# You can pass request-specific overrides here, e.g., max_new_tokens=100
)

return batch_outputs

def _generate_padded(
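A condensed sketch of the continuous-batching flow set up above, assuming a transformers version that exposes `generate_batch` and the generation-config fields used in the diff, and that the returned outputs map request ids to objects with a `generated_tokens` attribute, as the decode loop earlier in the diff suggests. The model name and prompts are placeholders:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "my-org/my-model"  # placeholder model name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Inputs assumed to be token-id lists, matching the `inputs` passed in the diff.
input_id_lists = [tokenizer(p)["input_ids"] for p in ["Prompt one", "Prompt two"]]

# Field names mirror the diff; values are illustrative.
gen_cfg = model.generation_config
gen_cfg.max_new_tokens = 256
gen_cfg.eos_token_id = tokenizer.eos_token_id
gen_cfg.num_blocks = 4096        # KV-cache blocks available to the scheduler
gen_cfg.block_size = 256         # tokens per KV-cache block
gen_cfg.max_batch_tokens = 16    # token budget per scheduling step
gen_cfg.do_sample = False        # greedy decoding
gen_cfg.use_cuda_graph = False   # keep CUDA graphs off while debugging

batch_outputs = model.generate_batch(inputs=input_id_lists, generation_config=gen_cfg)
for request_id, output in batch_outputs.items():
    text = tokenizer.decode(output.generated_tokens, skip_special_tokens=False)
```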
@@ -1189,6 +1209,9 @@ def pad_and_gather(
output_tensor = self.accelerator.gather(output_tensor)
return output_tensor, length_tensor

def tok_decode(self, tokens: torch.LongTensor) -> list[str]:
return self.tokenizer.batch_decode(tokens, skip_special_tokens=self.skip_special_tokens)


class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence."""