Commit f818188

Update tool handling to support JSON string schemas in trainers (#5118)

1 parent 9fc9a7d

File tree

8 files changed: +113 −43 lines

docs/source/dataset_formats.md

Lines changed: 8 additions & 18 deletions

````diff
@@ -159,6 +159,7 @@ When preparing datasets for Supervised Fine-Tuning (SFT) with tool calling, it i
 The tools must be specified in a codified JSON schema format. You can automatically generate this schema from Python function signatures using the [`~transformers.utils.get_json_schema`] utility:

 ```python
+import json
 from transformers.utils import get_json_schema

 def control_light(room: str, state: str) -> str:
@@ -175,38 +176,27 @@ def control_light(room: str, state: str) -> str:
     return f"The lights in {room} are now {state}."

 # Generate JSON schema
-json_schema = get_json_schema(control_light)
+json_schema = json.dumps([get_json_schema(control_light)])
 ```

 The generated schema would look like:

 ```python
-{
-    "type": "function",
-    "function": {
-        "name": "control_light",
-        "description": "Controls the lights in a room.",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "room": {"type": "string", "description": "The name of the room."},
-                "state": {"type": "string", "description": 'The desired state of the light ("on" or "off").'},
-            },
-            "required": ["room", "state"],
-        },
-        "return": {"type": "string", "description": "str: A message indicating the new state of the lights."},
-    },
-}
+'[{"type": "function", "function": {"name": "control_light", "description": "Controls the lights in a room.", "parameters": {"type": "object", "properties": {"room": {"type": "string", "description": "The name of the room."}, "state": {"type": "string", "description": "The desired state of the light (\\"on\\" or \\"off\\")."}}, "required": ["room", "state"]}, "return": {"type": "string", "description": "str: A message indicating the new state of the lights."}}}]'
 ```

 A complete dataset entry for SFT might look like:

 ```python
-{"messages": messages, "tools": [json_schema]}
+{"messages": messages, "tools": json_schema}
 ```

 For more detailed information on tool calling, refer to the [Tool Calling section in the `transformers` documentation](https://huggingface.co/docs/transformers/chat_extras#tools-and-rag) and the blog post [Tool Use, Unified](https://huggingface.co/blog/unified-tool-use).

+> [!NOTE]
+> TRL also accepts `tools` as a Python `list[dict]` (for backward compatibility).
+> This is a legacy format and is **not recommended** for new datasets. Prefer storing `tools` as a JSON `str` (with `json.dumps([...])`).
+
 ### Harmony

 The [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) was introduced with the [OpenAI GPT OSS models](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4). It extends the conversational format by adding richer structure for reasoning, function calls, and metadata about the model’s behavior. Key features include:
````
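To make the recommended storage format concrete, here is a minimal, self-contained sketch of the round-trip the docs above describe. It builds the `control_light` schema dict by hand (copied from the documentation) instead of calling `get_json_schema`, so it runs without `transformers` installed:

```python
import json

# Hand-written equivalent of get_json_schema(control_light), copied from the docs above.
control_light_schema = {
    "type": "function",
    "function": {
        "name": "control_light",
        "description": "Controls the lights in a room.",
        "parameters": {
            "type": "object",
            "properties": {
                "room": {"type": "string", "description": "The name of the room."},
                "state": {"type": "string", "description": 'The desired state of the light ("on" or "off").'},
            },
            "required": ["room", "state"],
        },
    },
}

# Recommended storage format: one JSON string per example, wrapping a list of schemas.
tools_column = json.dumps([control_light_schema])
assert isinstance(tools_column, str)

# Consumers can recover the original list of dicts losslessly.
assert json.loads(tools_column) == [control_light_schema]
```

Because the whole list is one opaque string, every row of the `tools` column has the same type regardless of which tools it contains.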

docs/source/reward_trainer.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -218,7 +218,7 @@ trainer.train()
 The [`RewardTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include:

 * The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages)
-* The list of available tools in the `tools` column, typically provided as JSON schemas
+* The list of available tools in the `tools` column, typically provided as JSON `str` schemas

 For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section.
```

docs/source/sft_trainer.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -289,7 +289,7 @@ Alternatively, use the structured conversation format (recommended):
 The [`SFTTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include:

 * The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages)
-* The list of available tools in the `tools` column, typically provided as JSON schemas
+* The list of available tools in the `tools` column, typically provided as JSON `str` schemas

 For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section.
```

scripts/generate_toolcall_dataset.py

Lines changed: 17 additions & 16 deletions

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 from dataclasses import dataclass, field

 from datasets import Dataset
@@ -212,14 +213,14 @@ def get_wind_conditions(city: str, unit: str) -> tuple[int, str]:
         ]
     ],
     "tools": [
-        [start_timer, create_reminder],
-        [get_current_time],
-        [get_air_quality_index, get_weather_forecast, get_wind_conditions],
-        [play_music, control_light],
-        [get_weather_forecast, get_wind_conditions],
-        [control_light],
-        [start_timer, create_reminder],
-        [get_weather_forecast, get_wind_conditions],
+        json.dumps([start_timer, create_reminder]),
+        json.dumps([get_current_time]),
+        json.dumps([get_air_quality_index, get_weather_forecast, get_wind_conditions]),
+        json.dumps([play_music, control_light]),
+        json.dumps([get_weather_forecast, get_wind_conditions]),
+        json.dumps([control_light]),
+        json.dumps([start_timer, create_reminder]),
+        json.dumps([get_weather_forecast, get_wind_conditions]),
     ]
 })
 language_modeling_dataset = language_modeling_dataset.train_test_split(test_size=test_size, shuffle=False)
@@ -318,14 +319,14 @@ def get_wind_conditions(city: str, unit: str) -> tuple[int, str]:
         ],
     ],
     "tools": [
-        [start_timer],
-        [get_current_time],
-        [get_air_quality_index],
-        [play_music],
-        [get_weather_forecast],
-        [control_light],
-        [create_reminder],
-        [get_wind_conditions],
+        json.dumps([start_timer]),
+        json.dumps([get_current_time]),
+        json.dumps([get_air_quality_index]),
+        json.dumps([play_music]),
+        json.dumps([get_weather_forecast]),
+        json.dumps([control_light]),
+        json.dumps([create_reminder]),
+        json.dumps([get_wind_conditions]),
     ],
 })
 preference_dataset = preference_dataset.train_test_split(test_size=test_size, shuffle=False)
```
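The point of serializing each row's tool list independently, as the script above now does, is that rows no longer need to share one nested column schema; the schemas here are hypothetical minimal stand-ins for the script's real tool definitions. A sketch of why the string form sidesteps the padding problem that tabular backends otherwise introduce:

```python
import json

# Hypothetical minimal schemas for two tools with different parameter keys.
start_timer = {"type": "function", "function": {"name": "start_timer", "parameters": {"duration": {"type": "integer"}}}}
control_light = {"type": "function", "function": {"name": "control_light", "parameters": {"room": {"type": "string"}}}}

# Each row's tool list becomes one opaque string, so rows never have to agree on a
# common nested layout (which is what leads Arrow/Parquet to pad missing keys with None).
tools_rows = [
    json.dumps([start_timer, control_light]),
    json.dumps([control_light]),
]

# Round-tripping a row recovers exactly the dicts that went in, with no injected None values.
assert json.loads(tools_rows[0]) == [start_timer, control_light]
assert json.loads(tools_rows[1]) == [control_light]
```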

tests/test_reward_trainer.py

Lines changed: 36 additions & 1 deletion

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import pathlib

 import pytest
@@ -666,7 +667,41 @@ def test_train_with_set_chat_template_from_path(self, lazy_shared_datadir):

     def test_train_toolcall_data(self):
         # Get the dataset
-        dataset = load_dataset("trl-internal-testing/toolcall", "preference", split="train")
+        dataset = load_dataset("trl-internal-testing/toolcall", "preference", split="train", revision="refs/pr/3")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    def test_train_toolcall_data_as_json(self):
+        # Tabular backends (Arrow/Parquet) can insert `None` for missing keys in nested structures.
+        # If `tools` is stored as a list of dicts and examples use different dict schemas, nulls may
+        # be introduced and break tool processing. This test ensures we also support `tools` provided
+        # as a list of dicts.
+        dataset = load_dataset("trl-internal-testing/toolcall", "preference", split="train", revision="refs/pr/3")
+
+        def convert_to_json(example):
+            return {"tools": json.loads(example["tools"])}
+
+        dataset = dataset.map(convert_to_json)

         # Initialize the trainer
         training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
```

tests/test_sft_trainer.py

Lines changed: 39 additions & 1 deletion

```diff
@@ -13,6 +13,7 @@
 # limitations under the License.

 import gc
+import json
 import pathlib
 from unittest.mock import MagicMock, patch

@@ -1312,7 +1313,44 @@ def test_train_with_set_chat_template_from_path(self, lazy_shared_datadir):

     def test_train_toolcall_data(self):
         # Get the dataset
-        dataset = load_dataset("trl-internal-testing/toolcall", "language_modeling", split="train")
+        dataset = load_dataset(
+            "trl-internal-testing/toolcall", "language_modeling", split="train", revision="refs/pr/2"
+        )
+
+        # Initialize the trainer
+        training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = SFTTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+
+    def test_train_toolcall_data_as_json(self):
+        # Tabular backends (Arrow/Parquet) can insert `None` for missing keys in nested structures.
+        # If `tools` is stored as a list of dicts and examples use different dict schemas, nulls may
+        # be introduced and break tool processing. This test ensures we also support `tools` provided
+        # as a list of dicts.
+        # Get the dataset
+        dataset = load_dataset(
+            "trl-internal-testing/toolcall", "language_modeling", split="train", revision="refs/pr/2"
+        )
+
+        def convert_to_json(example):
+            return {"tools": json.loads(example["tools"])}
+
+        dataset = dataset.map(convert_to_json)

         # Initialize the trainer
         training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none")
```

trl/trainer/reward_trainer.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -13,6 +13,7 @@
 # limitations under the License.

 import contextlib
+import json
 import logging
 import os
 import re
@@ -565,20 +566,22 @@ def add_eos(example, eos_token):
     map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

     def tokenize_fn(example, processing_class):
+        tools = example.get("tools")
+        tools = json.loads(tools) if isinstance(tools, str) else tools
         if "prompt" in example:  # explicit prompt case
             example["chosen"] = example["prompt"] + example["chosen"]
             example["rejected"] = example["prompt"] + example["rejected"]

         if is_conversational(example):
             chosen_input_ids = processing_class.apply_chat_template(
                 example["chosen"],
-                tools=example.get("tools"),
+                tools=tools,
                 return_dict=True,
                 **example.get("chat_template_kwargs", {}),
             )["input_ids"]
             rejected_input_ids = processing_class.apply_chat_template(
                 example["rejected"],
-                tools=example.get("tools"),
+                tools=tools,
                 return_dict=True,
                 **example.get("chat_template_kwargs", {}),
             )["input_ids"]
```
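The two lines added at the top of `tokenize_fn` are the whole compatibility shim: decode `tools` when it arrives as a JSON string, pass it through otherwise. A standalone sketch of that normalization (the helper name is illustrative, not part of the TRL API):

```python
import json

def normalize_tools(tools):
    """Accept `tools` as a JSON string (recommended), a legacy list of dicts, or None."""
    return json.loads(tools) if isinstance(tools, str) else tools

# A JSON-string column value is decoded into the list of schema dicts...
assert normalize_tools('[{"type": "function"}]') == [{"type": "function"}]

# ...while legacy list[dict] values and missing (None) values pass through unchanged.
assert normalize_tools([{"type": "function"}]) == [{"type": "function"}]
assert normalize_tools(None) is None
```

Either way, `apply_chat_template` always receives a `list[dict]` (or `None`), so the chat template itself never sees the string form.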

trl/trainer/sft_trainer.py

Lines changed: 6 additions & 3 deletions

```diff
@@ -13,6 +13,7 @@
 # limitations under the License.

 import contextlib
+import json
 import os
 import warnings
 from collections import defaultdict
@@ -1016,6 +1017,8 @@ def add_eos(example, eos_token):
     map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

     def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_loss):
+        tools = example.get("tools")
+        tools = json.loads(tools) if isinstance(tools, str) else tools
         if "prompt" in example:  # prompt-completion case
             output = {}
             if is_conversational(example):
@@ -1027,7 +1030,7 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
                 completion = example["completion"]
                 prompt_ids = processing_class.apply_chat_template(
                     prompt,
-                    tools=example.get("tools"),
+                    tools=tools,
                     add_generation_prompt=True,
                     tokenize=True,
                     return_dict=False,
@@ -1038,7 +1041,7 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
                 prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids
                 prompt_completion_processed = processing_class.apply_chat_template(
                     prompt + completion,
-                    tools=example.get("tools"),
+                    tools=tools,
                     tokenize=True,
                     return_dict=True,
                     return_assistant_tokens_mask=assistant_only_loss,
@@ -1088,7 +1091,7 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
             messages = example["messages"]
             processed = processing_class.apply_chat_template(
                 messages,
-                tools=example.get("tools"),
+                tools=tools,
                 tokenize=True,
                 return_dict=True,
                 return_assistant_tokens_mask=assistant_only_loss,
```

0 commit comments