Skip to content

Commit a3c5b36

Browse files
ngxson and kashif authored
🚩 setup_chat_format: throw error if there is already a template in base model (huggingface#2252)
* setup_chat_format: throw error if there was already a template
* fix lint
* clarify in docs
* fix test?

---------

Co-authored-by: Kashif Rasul <[email protected]>
1 parent 755242d commit a3c5b36

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

‎tests/test_dataset_formatting.py‎

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,8 @@ class SetupChatFormatTestCase(unittest.TestCase):
119119
def setUp(self):
120120
self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
121121
self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
122+
# remove built-in chat_template to simulate a model having no chat_template
123+
self.tokenizer.chat_template = None
122124

123125
def test_setup_chat_format(self):
124126
original_tokenizer_len = len(self.tokenizer)

‎trl/models/utils.py‎

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,8 @@ def setup_chat_format(
8484
"""
8585
Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens.
8686
87+
If the model already has a chat template, this will throw an error. If you want to overwrite it, please set `tokenizer.chat_template` to `None`.
88+
8789
Args:
8890
model (`~transformers.PreTrainedModel`): The model to be modified.
8991
tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
@@ -94,6 +96,12 @@ def setup_chat_format(
9496
model (`~transformers.PreTrainedModel`): The modified model.
9597
tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
9698
"""
99+
# check if model already had a chat template
100+
if tokenizer.chat_template is not None:
101+
raise ValueError(
102+
"Chat template is already added to the tokenizer. If you want to overwrite it, please set it to None"
103+
)
104+
97105
# check if format available and retrieve
98106
if format not in FORMAT_MAPPING:
99107
raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")

0 commit comments

Comments
 (0)