diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py
index fef8b0b5b..5d1132c61 100644
--- a/src/lighteval/models/transformers/transformers_model.py
+++ b/src/lighteval/models/transformers/transformers_model.py
@@ -92,6 +92,8 @@ class TransformersModelConfig(ModelConfig):
             Additional keyword arguments passed to `from_pretrained`. Defaults to empty dict.
         add_special_tokens (bool):
             Whether to add special tokens during tokenization. Defaults to True.
+        skip_special_tokens (bool):
+            Whether to skip special tokens when decoding generations. Keeping them (False) is needed for reasoning models. Defaults to False.
         model_parallel (bool | None):
             Whether to use model parallelism across multiple GPUs. If None, automatically
             determined based on available GPUs and model size.
@@ -139,6 +141,7 @@ class TransformersModelConfig(ModelConfig):
     max_length: PositiveInt | None = None
     model_loading_kwargs: dict = Field(default_factory=dict)
     add_special_tokens: bool = True
+    skip_special_tokens: bool = False
     model_parallel: bool | None = None
     dtype: str | None = None
     device: Union[int, str] = "cuda"
@@ -187,6 +190,7 @@ def __init__(
         self._device = self.accelerator.device
         self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
         self._add_special_tokens = config.add_special_tokens or False
+        self.skip_special_tokens = config.skip_special_tokens or False
         self.pairwise_tokenization = config.pairwise_tokenization
         self.batch_size = config.batch_size
         self.continuous_batching = config.continuous_batching
@@ -244,6 +248,7 @@ def from_model(
         tokenizer_name: str = None,  # custom tokenizer
         trust_remote_code: bool = False,
         add_special_tokens: bool = True,
+        skip_special_tokens: bool = False,
         pairwise_tokenization: bool = False,
         multichoice_continuations_start_space: bool = None,
     ):
@@ -280,6 +285,7 @@ def from_model(
         self.use_chat_template = uses_chat_template(self._tokenizer)

         self._add_special_tokens = add_special_tokens if add_special_tokens is not None else False
+        self.skip_special_tokens = skip_special_tokens if skip_special_tokens is not None else False
         self.pairwise_tokenization = pairwise_tokenization
         self.multichoice_continuations_start_space = multichoice_continuations_start_space

@@ -395,7 +401,8 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
             self.config.model_name,
             revision=revision,
             max_memory=max_memory,
-            device_map=device_map,
+            tp_plan="auto",
+            # device_map=device_map,
             torch_dtype=torch_dtype,
             trust_remote_code=self.config.trust_remote_code,
             **kwargs,
@@ -595,7 +602,9 @@ def _continuous_greedy_until(
                 # for output in _output.outputs:
                 output_token_ids.append(_output.generated_tokens)
                 # logprobs_raw.append(output.logprobs)
-                result.append(self.tokenizer.decode(_output.generated_tokens))
+                result.append(
+                    self.tokenizer.decode(_output.generated_tokens, skip_special_tokens=self.skip_special_tokens)
+                )

                 if logprobs_raw and output_token_ids and False:
                     logprobs = [logprobs_raw[0][token_id].logprob for token_id in output_token_ids[0]]
@@ -646,7 +655,7 @@ def _padded_greedy_until(
             tokenized_context = self.tokenizer(context)

             # Longest context in the current split is the first item (since we sort reversed)
-            longest_context_continuation_size_in_split = len(tokenized_context) + split[0].generation_size
+            longest_context_continuation_size_in_split = len(tokenized_context["input_ids"]) + split[0].generation_size
             max_context_continuation_size_allowed = min(
                 longest_context_continuation_size_in_split, self.max_length
             )
@@ -669,12 +678,12 @@ def _padded_greedy_until(

             # For chat models, generation stops with EOS token, so we don't need to specify stop tokens
             if self.use_chat_template:
-                stop_tokens = []
+                stop_tokens = [self.tokenizer.eos_token]
             else:
                 # NOTE: we are assuming all items in a batch behave similarly (same
                 # stop_tokens and max_tokens generated) which is not necessarily
                 # the case! Because of that we only use batch size of 1
-                stop_tokens = batch[0].stop_sequences
+                stop_tokens = [self.tokenizer.eos_token] + batch[0].stop_sequences

             max_new_tokens = batch[0].generation_size
             num_samples = batch[0].num_samples
@@ -750,16 +759,27 @@ def _generate_continuous(
         num_samples: int = 1,
         generate: bool = True,
     ) -> Dict[str, ModelResponse]:
+        if num_samples > 1 and self.generation_config_dict["temperature"] == 0:
+            raise ValueError(
+                "You cannot generate multiple samples with temperature=0. Please set temperature > 0, or use a non-sampling metric."
+            )
+
         # Compute model generation
-        self.model.generation_config.use_cuda_graph = False  # Disable CUDA graph for batch generation
-        self.model.generation_config.max_batch_tokens = 256  # Disable CUDA graph for batch generation
-        # self.model.generation_config.do_sample = False  # Disable CUDA graph for batch generation
+        self.model.generation_config.max_new_tokens = max_new_tokens
+        self.model.generation_config.eos_token_id = self.tokenizer.eos_token_id
+        self.model.generation_config.num_return_sequences = num_samples
+        self.model.generation_config.output_logits = returns_logits
+        self.model.generation_config.renormalize_logits = True
+        self.model.generation_config.num_blocks = 4096  # paged KV-cache blocks for continuous batching
+        self.model.generation_config.block_size = 256  # tokens per KV-cache block
+        self.model.generation_config.max_batch_tokens = 16  # cap on tokens scheduled per batch
+        self.model.generation_config.do_sample = False  # greedy decoding
+        self.model.generation_config.use_cuda_graph = False  # disable CUDA graphs for batch generation
         batch_outputs = self.model.generate_batch(
             inputs=inputs,
             generation_config=self.model.generation_config,
             # You can pass request-specific overrides here, e.g., max_new_tokens=100
         )
-
         return batch_outputs

     def _generate_padded(
@@ -1189,6 +1209,9 @@ def pad_and_gather(
         output_tensor = self.accelerator.gather(output_tensor)
         return output_tensor, length_tensor

+    def tok_decode(self, tokens: torch.LongTensor) -> list[str]:
+        return self.tokenizer.batch_decode(tokens, skip_special_tokens=self.skip_special_tokens)
+

 class MultiTokenEOSCriteria(transformers.StoppingCriteria):
     """Criteria to stop on the specified multi-token sequence."""
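
For review context, a minimal sketch of the decode-time behavior the new `skip_special_tokens` flag toggles. This is illustration only: `gpt2` is a stand-in checkpoint, and the commented outputs are what its tokenizer happens to produce; for reasoning models the same mechanism decides whether markers such as `<think>`/`</think>` survive into the decoded generation.

from transformers import AutoTokenizer

# Stand-in checkpoint; any tokenizer with an EOS special token shows the same effect.
tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok.encode("The answer is 4.") + [tok.eos_token_id]

print(tok.decode(ids, skip_special_tokens=True))   # "The answer is 4."
print(tok.decode(ids, skip_special_tokens=False))  # "The answer is 4.<|endoftext|>"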
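
A sketch of setting the new field from user code, assuming only what the hunks above show (the `skip_special_tokens` field added to `TransformersModelConfig`); the checkpoint name is a hypothetical stand-in and other config fields keep their defaults.

from lighteval.models.transformers.transformers_model import TransformersModelConfig

config = TransformersModelConfig(
    model_name="HuggingFaceTB/SmolLM2-1.7B-Instruct",  # hypothetical stand-in checkpoint
    skip_special_tokens=False,  # keep <think>-style special tokens in decoded generations
)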