
Commit 64f93b0

Number of fixes to run accelerate evaluations (#898)
* init
* fix
* cb fix
* rm sb
* style
* true by default
1 parent 865335e commit 64f93b0

File tree

1 file changed: +18, -4 lines changed

src/lighteval/models/transformers/transformers_model.py

Lines changed: 18 additions & 4 deletions
@@ -92,6 +92,8 @@ class TransformersModelConfig(ModelConfig):
             Additional keyword arguments passed to `from_pretrained`. Defaults to empty dict.
         add_special_tokens (bool):
             Whether to add special tokens during tokenization. Defaults to True.
+        skip_special_tokens (bool):
+            Whether the tokenizer should output special tokens back during generation. Needed for reasoning models. Defaults to True
         model_parallel (bool | None):
             Whether to use model parallelism across multiple GPUs. If None, automatically
             determined based on available GPUs and model size.
@@ -139,6 +141,7 @@ class TransformersModelConfig(ModelConfig):
     max_length: PositiveInt | None = None
     model_loading_kwargs: dict = Field(default_factory=dict)
     add_special_tokens: bool = True
+    skip_special_tokens: bool = True
     model_parallel: bool | None = None
     dtype: str | None = None
     device: Union[int, str] = "cuda"
@@ -187,6 +190,7 @@ def __init__(
         self._device = self.accelerator.device
         self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
         self._add_special_tokens = config.add_special_tokens or False
+        self.skip_special_tokens = config.skip_special_tokens or True
         self.pairwise_tokenization = config.pairwise_tokenization
         self.batch_size = config.batch_size
         self.continuous_batching = config.continuous_batching
@@ -244,6 +248,7 @@ def from_model(
         tokenizer_name: str = None,  # custom tokenizer
         trust_remote_code: bool = False,
         add_special_tokens: bool = True,
+        skip_special_tokens: bool = True,
         pairwise_tokenization: bool = False,
         multichoice_continuations_start_space: bool = None,
     ):
@@ -280,6 +285,7 @@ def from_model(

         self.use_chat_template = uses_chat_template(self._tokenizer)
         self._add_special_tokens = add_special_tokens if add_special_tokens is not None else False
+        self.skip_special_tokens = skip_special_tokens if skip_special_tokens is not None else True
         self.pairwise_tokenization = pairwise_tokenization
         self.multichoice_continuations_start_space = multichoice_continuations_start_space

@@ -396,6 +402,7 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
             revision=revision,
             max_memory=max_memory,
             device_map=device_map,
+            # tp_plan="auto",
             torch_dtype=torch_dtype,
             trust_remote_code=self.config.trust_remote_code,
             **kwargs,
@@ -595,7 +602,9 @@ def _continuous_greedy_until(
             # for output in _output.outputs:
             output_token_ids.append(_output.generated_tokens)
             # logprobs_raw.append(output.logprobs)
-            result.append(self.tokenizer.decode(_output.generated_tokens))
+            result.append(
+                self.tokenizer.decode(_output.generated_tokens, skip_special_tokens=self.skip_special_tokens)
+            )

         if logprobs_raw and output_token_ids and False:
             logprobs = [logprobs_raw[0][token_id].logprob for token_id in output_token_ids[0]]
@@ -646,7 +655,9 @@ def _padded_greedy_until(
             tokenized_context = self.tokenizer(context)

             # Longest context in the current split is the first item (since we sort reversed)
-            longest_context_continuation_size_in_split = len(tokenized_context) + split[0].generation_size
+            longest_context_continuation_size_in_split = (
+                len(tokenized_context["input_ids"]) + split[0].generation_size
+            )
             max_context_continuation_size_allowed = min(
                 longest_context_continuation_size_in_split, self.max_length
             )
@@ -669,12 +680,12 @@ def _padded_greedy_until(

             # For chat models, generation stops with EOS token, so we don't need to specify stop tokens
             if self.use_chat_template:
-                stop_tokens = []
+                stop_tokens = [self.tokenizer.eos_token]
             else:
                 # NOTE: we are assuming all items in a batch behave similarly (same
                 # stop_tokens and max_tokens genrated) which is not necessarily
                 # the case! Because of that we only use batch size of 1
-                stop_tokens = batch[0].stop_sequences
+                stop_tokens = [self.tokenizer.eos_token] + batch[0].stop_sequences

             max_new_tokens = batch[0].generation_size
             num_samples = batch[0].num_samples
@@ -1189,6 +1200,9 @@ def pad_and_gather(
         output_tensor = self.accelerator.gather(output_tensor)
         return output_tensor, length_tensor

+    def tok_decode(self, tokens: torch.LongTensor) -> list[str]:
+        return self.tokenizer.batch_decode(tokens, skip_special_tokens=self.skip_special_tokens)
+

 class MultiTokenEOSCriteria(transformers.StoppingCriteria):
     """Criteria to stop on the specified multi-token sequence."""
