Commit 70e2017

🎞️ Support sequence classification models in clone_chat_template (#4097)
1 parent 4368f54 commit 70e2017

File tree

2 files changed: 42 additions & 16 deletions
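Before the diffs, a minimal usage sketch of what this commit enables: cloning a chat template onto a sequence classification model (e.g. a reward model). This is a sketch, not part of the commit; the checkpoint names are the tiny test models used in the diff below, and the three-tuple return shape follows the tests (the trailing value is discarded here).

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl.models.utils import clone_chat_template

# Tiny test checkpoints from the diff below; this one ships no chat_template.
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptNeoXForSequenceClassification")
model = AutoModelForSequenceClassification.from_pretrained(
    "trl-internal-testing/tiny-GptNeoXForSequenceClassification"
)

# Clone the template and special tokens from a model that has one.
model, tokenizer, _ = clone_chat_template(model, tokenizer, "trl-internal-testing/tiny-Qwen3ForCausalLM")

print(tokenizer.eos_token)  # "<|im_end|>", as the new test asserts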

tests/test_dataset_formatting.py

Lines changed: 40 additions & 15 deletions
@@ -15,7 +15,7 @@
 from typing import Callable

 from datasets import Dataset, load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

 from trl.extras.dataset_formatting import get_formatting_func_from_dataset
 from trl.models.utils import ChatMlSpecialTokens, clone_chat_template, setup_chat_format
@@ -159,47 +159,59 @@ def test_example_with_setup_model(self):


 class CloneChatTemplateTestCase(TrlTestCase):
-    def setUp(self):
-        super().setUp()
+    def test_clone(self):
         # This tokenizer doesn't have a chat_template by default
-        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
-        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
         # This one has a chat_template by default
-        self.source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
-
-    def test_clone(self):
-        _, modified_tokenizer, _ = clone_chat_template(self.model, self.tokenizer, self.source)
+        source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
+        _, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source)

         # Check if special tokens are correctly set
         self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>")

     def test_clone_with_resize(self):
+        # This tokenizer doesn't have a chat_template by default
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        # This one has a chat_template by default
+        source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
         modified_model, modified_tokenizer, _ = clone_chat_template(
-            self.model, self.tokenizer, self.source, resize_to_multiple_of=123
+            model, tokenizer, source, resize_to_multiple_of=123
         )

         # Check that the input embeddings have been resized to a multiple of 123
         self.assertEqual((modified_model.vocab_size % 123), 0)
         # Check that the input embeddings size matches the tokenizer vocabulary size
-        self.assertEqual(self.model.vocab_size, len(modified_tokenizer.vocab))
+        self.assertEqual(model.vocab_size, len(modified_tokenizer.vocab))

     def test_clone_with_resize_and_extra_tokens_already_in_vocab(self):
+        # This tokenizer doesn't have a chat_template by default
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        # This one has a chat_template by default
+        source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
         # This will add <extra_id_0>, <extra_id_1>, ... to the tokenizer
         modified_model, modified_tokenizer, _ = clone_chat_template(
-            self.model, self.tokenizer, self.source, resize_to_multiple_of=123
+            model, tokenizer, source, resize_to_multiple_of=123
         )
         # Check that we can resize a tokenizer that already has these extra tokens
         modified_model, modified_tokenizer, _ = clone_chat_template(
-            modified_model, modified_tokenizer, self.source, resize_to_multiple_of=124
+            modified_model, modified_tokenizer, source, resize_to_multiple_of=124
         )

         # Check that the input embeddings have been resized to a multiple of 124
         self.assertEqual((modified_model.vocab_size % 124), 0)
         # Check that the input embeddings size matches the tokenizer vocabulary size
-        self.assertEqual(self.model.vocab_size, len(modified_tokenizer.vocab))
+        self.assertEqual(model.vocab_size, len(modified_tokenizer.vocab))

     def test_apply_new_chat_template(self):
-        _, modified_tokenizer, _ = clone_chat_template(self.model, self.tokenizer, self.source)
+        # This tokenizer doesn't have a chat_template by default
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        # This one has a chat_template by default
+        source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
+        _, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source)
         messages = [
             {"role": "system", "content": "You are helpful"},
             {"role": "user", "content": "Hello"},
@@ -211,3 +223,16 @@ def test_apply_new_chat_template(self):
             prompt,
             "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nHi, how can I help you?<|im_end|>\n",
         )
+
+    def test_clone_with_sequence_classification_model(self):
+        # This tokenizer doesn't have a chat_template by default
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptNeoXForSequenceClassification")
+        model = AutoModelForSequenceClassification.from_pretrained(
+            "trl-internal-testing/tiny-GptNeoXForSequenceClassification"
+        )
+        # This one has a chat_template by default
+        source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
+        _, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source)
+
+        # Check if special tokens are correctly set
+        self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>")
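The resize tests above pin down a rounding contract: with resize_to_multiple_of=123 the resulting vocabulary size must satisfy vocab_size % 123 == 0, and a second clone with 124 must re-round cleanly. A small sketch of that round-up arithmetic; this helper is my own illustration, not a TRL function, and the vocabulary size is illustrative.

def round_up_to_multiple(n: int, multiple: int) -> int:
    # Round n up to the next multiple; the tests only assert that the
    # resulting vocab_size % multiple == 0, which this satisfies.
    return ((n + multiple - 1) // multiple) * multiple

assert round_up_to_multiple(250_880, 123) % 123 == 0  # padded to a clean multiple
assert round_up_to_multiple(250_880, 123) >= 250_880  # never shrinks the vocabulary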

trl/models/utils.py

Lines changed: 2 additions & 1 deletion
@@ -223,7 +223,8 @@ def clone_chat_template(
     # Set the EOS token from the source tokenizer (important for generation)
     tokenizer.eos_token = tokenizer_source.eos_token
     model.config.eos_token_id = tokenizer.eos_token_id
-    model.generation_config.eos_token_id = tokenizer.eos_token_id
+    if model.generation_config is not None:  # for SequenceClassification models, generation_config is None
+        model.generation_config.eos_token_id = tokenizer.eos_token_id

     # Resize model embeddings to include any new tokens, optionally rounding up to a multiple
     model.resize_token_embeddings(
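The guard exists because, as the new inline comment notes, SequenceClassification models carry no generation_config, so the old unconditional assignment was an attribute write on None. A small repro sketch reusing the tiny test checkpoints from this commit:

from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification

clf = AutoModelForSequenceClassification.from_pretrained(
    "trl-internal-testing/tiny-GptNeoXForSequenceClassification"
)
print(clf.generation_config)  # None: classification heads don't generate

lm = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")
print(lm.generation_config is not None)  # True: causal LMs carry one

# Pre-fix behavior: setting an attribute on None raises AttributeError,
# which is exactly what the new `is not None` check avoids.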
