@@ -92,6 +92,8 @@ class TransformersModelConfig(ModelConfig):
Additional keyword arguments passed to `from_pretrained`. Defaults to empty dict.
add_special_tokens (bool):
Whether to add special tokens during tokenization. Defaults to True.
+ skip_special_tokens (bool):
+ Whether the tokenizer should skip special tokens when decoding generated output. Set to False to keep them in the decoded text, which is needed for reasoning models. Defaults to True.
model_parallel (bool | None):
Whether to use model parallelism across multiple GPUs. If None, automatically
determined based on available GPUs and model size.
@@ -139,6 +141,7 @@ class TransformersModelConfig(ModelConfig):
max_length: PositiveInt | None = None
model_loading_kwargs: dict = Field(default_factory=dict)
add_special_tokens: bool = True
+ skip_special_tokens: bool = True
model_parallel: bool | None = None
dtype: str | None = None
device: Union[int, str] = "cuda"
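For context (not part of the diff): a minimal sketch of what the new flag changes at decode time. The checkpoint name and input are arbitrary; any Hugging Face tokenizer behaves the same way.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok("Hello world")["input_ids"] + [tok.eos_token_id]

print(tok.decode(ids, skip_special_tokens=True))   # "Hello world" -> special tokens dropped (what the new config flag defaults to)
print(tok.decode(ids, skip_special_tokens=False))  # "Hello world<|endoftext|>" -> kept, e.g. for reasoning traces that rely on special markers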
@@ -187,6 +190,7 @@ def __init__(
self._device = self.accelerator.device
self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
self._add_special_tokens = config.add_special_tokens or False
+ self.skip_special_tokens = config.skip_special_tokens
self.pairwise_tokenization = config.pairwise_tokenization
self.batch_size = config.batch_size
self.continuous_batching = config.continuous_batching
@@ -244,6 +248,7 @@ def from_model(
tokenizer_name: str = None,  # custom tokenizer
trust_remote_code: bool = False,
add_special_tokens: bool = True,
+ skip_special_tokens: bool = True,
pairwise_tokenization: bool = False,
multichoice_continuations_start_space: bool = None,
):
@@ -280,6 +285,7 @@ def from_model(
self.use_chat_template = uses_chat_template(self._tokenizer)
self._add_special_tokens = add_special_tokens if add_special_tokens is not None else False
+ self.skip_special_tokens = skip_special_tokens if skip_special_tokens is not None else True
self.pairwise_tokenization = pairwise_tokenization
self.multichoice_continuations_start_space = multichoice_continuations_start_space
@@ -396,6 +402,7 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
revision=revision,
max_memory=max_memory,
device_map=device_map,
+ # tp_plan="auto",
torch_dtype=torch_dtype,
trust_remote_code=self.config.trust_remote_code,
**kwargs,
@@ -595,7 +602,9 @@ def _continuous_greedy_until(
# for output in _output.outputs:
output_token_ids.append(_output.generated_tokens)
# logprobs_raw.append(output.logprobs)
- result.append(self.tokenizer.decode(_output.generated_tokens))
+ result.append(
+     self.tokenizer.decode(_output.generated_tokens, skip_special_tokens=self.skip_special_tokens)
+ )

if logprobs_raw and output_token_ids and False:
    logprobs = [logprobs_raw[0][token_id].logprob for token_id in output_token_ids[0]]
@@ -646,7 +655,9 @@ def _padded_greedy_until(
tokenized_context = self.tokenizer(context)

# Longest context in the current split is the first item (since we sort reversed)
- longest_context_continuation_size_in_split = len(tokenized_context) + split[0].generation_size
+ longest_context_continuation_size_in_split = (
+     len(tokenized_context["input_ids"]) + split[0].generation_size
+ )
max_context_continuation_size_allowed = min(
    longest_context_continuation_size_in_split, self.max_length
)
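For context (not part of the diff), the reason for this fix: `self.tokenizer(context)` returns a `BatchEncoding`, so `len()` on it counts dict keys rather than tokens. A quick sketch with an arbitrary tokenizer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
enc = tok("The quick brown fox")

print(len(enc))               # number of encoding fields (input_ids, attention_mask), not a token count
print(len(enc["input_ids"]))  # the actual context length in tokens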
@@ -669,12 +680,12 @@ def _padded_greedy_until(
# For chat models, generation stops with EOS token, so we don't need to specify stop tokens
if self.use_chat_template:
-    stop_tokens = []
+    stop_tokens = [self.tokenizer.eos_token]
else:
    # NOTE: we are assuming all items in a batch behave similarly (same
    # stop_tokens and max_tokens generated) which is not necessarily
    # the case! Because of that we only use batch size of 1
-    stop_tokens = batch[0].stop_sequences
+    stop_tokens = [self.tokenizer.eos_token] + batch[0].stop_sequences

max_new_tokens = batch[0].generation_size
num_samples = batch[0].num_samples
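For context (not part of the diff), the stop lists the two branches now produce; the EOS string and the per-request stop sequences below are made-up example values:

eos_token = "<|im_end|>"                    # whatever self.tokenizer.eos_token happens to be
request_stop_sequences = ["\n\nQuestion:"]  # hypothetical batch[0].stop_sequences

stop_tokens_chat = [eos_token]                            # chat-template branch
stop_tokens_plain = [eos_token] + request_stop_sequences  # plain branch, EOS now always included
print(stop_tokens_chat)   # ['<|im_end|>']
print(stop_tokens_plain)  # ['<|im_end|>', '\n\nQuestion:']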
@@ -1189,6 +1200,9 @@ def pad_and_gather(
output_tensor = self.accelerator.gather(output_tensor)
return output_tensor, length_tensor

+ def tok_decode(self, tokens: torch.LongTensor) -> list[str]:
+     return self.tokenizer.batch_decode(tokens, skip_special_tokens=self.skip_special_tokens)
+

class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""