
Commit aab21eb

Include chat_template_kwargs in apply_chat_template (#4233)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent: b997a31

2 files changed: +61 -7 lines changed

tests/test_data_utils.py

Lines changed: 23 additions & 0 deletions
@@ -396,6 +396,29 @@ def test_maybe_apply_chat_template(self, tokenizer_id, example):
             assert isinstance(result["label"], bool)
             assert result["label"] == example["label"]
 
+    def test_apply_chat_template_with_chat_template_kwargs(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")
+
+        example = {
+            "prompt": [{"role": "user", "content": "What color is the sky?"}],
+            # with this tokenizer, when you pass enable_thinking=False, it will add "<think>\n\n</think>\n\n"
+            "chat_template_kwargs": {"enable_thinking": False},
+        }
+        result = apply_chat_template(example, tokenizer)
+
+        # docstyle-ignore
+        expected = textwrap.dedent("""\
+        <|im_start|>user
+        What color is the sky?<|im_end|>
+        <|im_start|>assistant
+        <think>
+
+        </think>
+
+        """)
+
+        assert result["prompt"] == expected
+
     def test_apply_chat_template_with_tools(self):
         tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")
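For context (not part of this commit): the new per-example "chat_template_kwargs" column can also be used when templating a whole dataset. Below is a minimal sketch assuming the public trl helper maybe_apply_chat_template and the same test checkpoint as above; enable_thinking is specific to Qwen3-style chat templates, and other templates may ignore or reject it.

from datasets import Dataset
from transformers import AutoTokenizer

from trl import maybe_apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")

dataset = Dataset.from_list(
    [
        {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            # Forwarded to tokenizer.apply_chat_template by the patched function
            # (assumption: Qwen3-style enable_thinking flag)
            "chat_template_kwargs": {"enable_thinking": False},
        }
    ]
)

# maybe_apply_chat_template leaves non-conversational rows untouched, so mapping is safe
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
print(dataset[0]["prompt"])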

trl/data_utils.py

Lines changed: 38 additions & 7 deletions
@@ -143,7 +143,13 @@ def apply_chat_template(
 
     # Apply the chat template to the whole conversation
     if "messages" in example:
-        messages = tokenizer.apply_chat_template(example["messages"], tools=tools, tokenize=False, **template_kwargs)
+        messages = tokenizer.apply_chat_template(
+            example["messages"],
+            tools=tools,
+            tokenize=False,
+            **example.get("chat_template_kwargs", {}),
+            **template_kwargs,
+        )
 
     # Apply the chat template to the prompt, adding the generation prompt
     if "prompt" in example:
@@ -162,14 +168,19 @@
             continue_final_message=continue_final_message,
             tokenize=False,
             add_generation_prompt=add_generation_prompt,
+            **example.get("chat_template_kwargs", {}),
             **template_kwargs,
         )
 
     # Apply the chat template to the entire prompt + completion
     if "prompt" in example:  # explicit prompt and prompt-completion case
         if "chosen" in example:
             prompt_chosen = tokenizer.apply_chat_template(
-                example["prompt"] + example["chosen"], tools=tools, tokenize=False, **template_kwargs
+                example["prompt"] + example["chosen"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
             )
             # DeepSeek-R1 inserts a <tool_call> token when using `add_generation_prompt`, which can cause discrepancies
             # between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the
@@ -179,24 +190,42 @@
             chosen = prompt_chosen[len(prompt) :]
         if "rejected" in example and "prompt" in example:  # explicit prompt
             prompt_rejected = tokenizer.apply_chat_template(
-                example["prompt"] + example["rejected"], tools=tools, tokenize=False, **template_kwargs
+                example["prompt"] + example["rejected"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
             )
             # Handle DeepSeek-R1 <tool_call> token, see the above comment for details
             prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected)))
             rejected = prompt_rejected[len(prompt) :]
         if "completion" in example:
             prompt_completion = tokenizer.apply_chat_template(
-                example["prompt"] + example["completion"], tools=tools, tokenize=False, **template_kwargs
+                example["prompt"] + example["completion"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
             )
             # Handle DeepSeek-R1 <tool_call> token, see the above comment for details
             prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
             completion = prompt_completion[len(prompt) :]
     else:  # implicit prompt case
         if "chosen" in example:
-            chosen = tokenizer.apply_chat_template(example["chosen"], tools=tools, tokenize=False, **template_kwargs)
+            chosen = tokenizer.apply_chat_template(
+                example["chosen"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
+            )
         if "rejected" in example:
             rejected = tokenizer.apply_chat_template(
-                example["rejected"], tools=tools, tokenize=False, **template_kwargs
+                example["rejected"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
             )
 
     # Extract the completion by removing the prompt part from the prompt-completion string
@@ -239,7 +268,9 @@ def maybe_apply_chat_template(
                 - Unpaired preference dataset: `"prompt"`, `"completion"`, and `"label"`.
 
             For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of
-            messages, where each message is a dictionary with keys `"role"` and `"content"`.
+            messages, where each message is a dictionary with keys `"role"` and `"content"`. Additionally, the example
+            may contain a `"chat_template_kwargs"` key, which is a dictionary of additional keyword arguments to pass
+            to the chat template renderer.
         tokenizer (`PreTrainedTokenizerBase`):
             Tokenizer to apply the chat template with.
         tools (`list[Union[dict, Callable]]`, *optional*):
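A minimal sketch of the docstring addition above: an explicit prompt-completion example whose "chat_template_kwargs" are forwarded to the tokenizer's chat template. The checkpoint name and the enable_thinking flag are assumptions carried over from the test, not part of this diff.

from transformers import AutoTokenizer

from trl import apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")

example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "completion": [{"role": "assistant", "content": "It is blue."}],
    # Extra keyword arguments for tokenizer.apply_chat_template (assumption: Qwen3-style flag)
    "chat_template_kwargs": {"enable_thinking": False},
}

result = apply_chat_template(example, tokenizer)
print(result["prompt"])      # templated prompt string
print(result["completion"])  # completion string, split off after the shared prompt prefix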
