1515from typing import Callable
1616
1717from datasets import Dataset , load_dataset
18- from transformers import AutoModelForCausalLM , AutoTokenizer
18+ from transformers import AutoModelForCausalLM , AutoModelForSequenceClassification , AutoTokenizer
1919
2020from trl .extras .dataset_formatting import get_formatting_func_from_dataset
2121from trl .models .utils import ChatMlSpecialTokens , clone_chat_template , setup_chat_format
@@ -159,47 +159,59 @@ def test_example_with_setup_model(self):
159159
160160
161161class CloneChatTemplateTestCase (TrlTestCase ):
162- def setUp (self ):
163- super ().setUp ()
162+ def test_clone (self ):
164163 # This tokenizer doesn't have a chat_template by default
165- self . tokenizer = AutoTokenizer .from_pretrained ("trl-internal-testing/tiny-BloomForCausalLM" )
166- self . model = AutoModelForCausalLM .from_pretrained ("trl-internal-testing/tiny-BloomForCausalLM" )
164+ tokenizer = AutoTokenizer .from_pretrained ("trl-internal-testing/tiny-BloomForCausalLM" )
165+ model = AutoModelForCausalLM .from_pretrained ("trl-internal-testing/tiny-BloomForCausalLM" )
167166 # This one has a chat_template by default
168- self .source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
169-
170- def test_clone (self ):
171- _ , modified_tokenizer , _ = clone_chat_template (self .model , self .tokenizer , self .source )
167+ source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
168+ _ , modified_tokenizer , _ = clone_chat_template (model , tokenizer , source )
172169
173170 # Check if special tokens are correctly set
174171 self .assertEqual (modified_tokenizer .eos_token , "<|im_end|>" )
175172
176173 def test_clone_with_resize (self ):
174+ # This tokenizer doesn't have a chat_template by default
175+ tokenizer = AutoTokenizer .from_pretrained ("trl-internal-testing/tiny-BloomForCausalLM" )
176+ model = AutoModelForCausalLM .from_pretrained ("trl-internal-testing/tiny-BloomForCausalLM" )
177+ # This one has a chat_template by default
178+ source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
177179 modified_model , modified_tokenizer , _ = clone_chat_template (
178- self . model , self . tokenizer , self . source , resize_to_multiple_of = 123
180+ model , tokenizer , source , resize_to_multiple_of = 123
179181 )
180182
181183 # Check that the input embeddings have been resized to a multiple of 123
182184 self .assertEqual ((modified_model .vocab_size % 123 ), 0 )
183185 # Check that the input embeddings size matches the tokenizer vocabulary size
184- self .assertEqual (self . model .vocab_size , len (modified_tokenizer .vocab ))
186+ self .assertEqual (model .vocab_size , len (modified_tokenizer .vocab ))
185187
186188 def test_clone_with_resize_and_extra_tokens_already_in_vocab (self ):
189+ # This tokenizer doesn't have a chat_template by default
190+ tokenizer = AutoTokenizer .from_pretrained ("trl-internal-testing/tiny-BloomForCausalLM" )
191+ model = AutoModelForCausalLM .from_pretrained ("trl-internal-testing/tiny-BloomForCausalLM" )
192+ # This one has a chat_template by default
193+ source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
187194 # This will add <extra_id_0>, <extra_id_1>, ... to the tokenizer
188195 modified_model , modified_tokenizer , _ = clone_chat_template (
189- self . model , self . tokenizer , self . source , resize_to_multiple_of = 123
196+ model , tokenizer , source , resize_to_multiple_of = 123
190197 )
191198 # Check that we can resize a tokenizer that already has these extra tokens
192199 modified_model , modified_tokenizer , _ = clone_chat_template (
193- modified_model , modified_tokenizer , self . source , resize_to_multiple_of = 124
200+ modified_model , modified_tokenizer , source , resize_to_multiple_of = 124
194201 )
195202
196203 # Check that the input embeddings have been resized to a multiple of 124
197204 self .assertEqual ((modified_model .vocab_size % 124 ), 0 )
198205 # Check that the input embeddings size matches the tokenizer vocabulary size
199- self .assertEqual (self . model .vocab_size , len (modified_tokenizer .vocab ))
206+ self .assertEqual (model .vocab_size , len (modified_tokenizer .vocab ))
200207
201208 def test_apply_new_chat_template (self ):
202- _ , modified_tokenizer , _ = clone_chat_template (self .model , self .tokenizer , self .source )
209+ # This tokenizer doesn't have a chat_template by default
210+ tokenizer = AutoTokenizer .from_pretrained ("trl-internal-testing/tiny-BloomForCausalLM" )
211+ model = AutoModelForCausalLM .from_pretrained ("trl-internal-testing/tiny-BloomForCausalLM" )
212+ # This one has a chat_template by default
213+ source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
214+ _ , modified_tokenizer , _ = clone_chat_template (model , tokenizer , source )
203215 messages = [
204216 {"role" : "system" , "content" : "You are helpful" },
205217 {"role" : "user" , "content" : "Hello" },
@@ -211,3 +223,16 @@ def test_apply_new_chat_template(self):
211223 prompt ,
212224 "<|im_start|>system\n You are helpful<|im_end|>\n <|im_start|>user\n Hello<|im_end|>\n <|im_start|>assistant\n <think>\n \n </think>\n \n Hi, how can I help you?<|im_end|>\n " ,
213225 )
226+
227+ def test_clone_with_sequence_classification_model (self ):
228+ # This tokenizer doesn't have a chat_template by default
229+ tokenizer = AutoTokenizer .from_pretrained ("trl-internal-testing/tiny-GptNeoXForSequenceClassification" )
230+ model = AutoModelForSequenceClassification .from_pretrained (
231+ "trl-internal-testing/tiny-GptNeoXForSequenceClassification"
232+ )
233+ # This one has a chat_template by default
234+ source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
235+ _ , modified_tokenizer , _ = clone_chat_template (model , tokenizer , source )
236+
237+ # Check if special tokens are correctly set
238+ self .assertEqual (modified_tokenizer .eos_token , "<|im_end|>" )
0 commit comments