diff --git a/docs/sphinx_doc/assets/toolace_reward_curve.png b/docs/sphinx_doc/assets/toolace_reward_curve.png
new file mode 100644
index 0000000000..c062466783
Binary files /dev/null and b/docs/sphinx_doc/assets/toolace_reward_curve.png differ
diff --git a/examples/grpo_toolcall/README.md b/examples/grpo_toolcall/README.md
new file mode 100644
index 0000000000..3418178b05
--- /dev/null
+++ b/examples/grpo_toolcall/README.md
@@ -0,0 +1,17 @@
+# GRPO on the ToolAce dataset
+
+This example demonstrates how to run GRPO on the [ToolAce](https://huggingface.co/datasets/Team-ACE/ToolACE) dataset.
+
+The data preprocessing script and the workflow construction are adapted from [Tool-N1](https://github.com/NVlabs/Tool-N1).
+
+The config files are located in [`toolace.yaml`](toolace.yaml) and [`train_toolace.yaml`](train_toolace.yaml).
+
+
+## How to run
+To preprocess the data into the format required by our `toolcall_workflow`, run the following command: `python scripts/data_prepare/get_toolace_data.py`.
+
+Then fill in the placeholder paths (e.g. `model_path` and `checkpoint_root_dir`) in `toolace.yaml` and run the following command: `trinity run --config examples/grpo_toolcall/toolace.yaml`.
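+
+After preprocessing, each record in `scripts/data_prepare/toolace_data/toolace_clean_data.json` follows the layout sketched below (illustrative values, not actual dataset content):
+
+```json
+{
+    "raw_system": "You are an expert in composing functions. ...",
+    "tools": "[{\"name\": \"email\", \"description\": \"Send an email.\", \"parameters\": {}}]",
+    "conversations": [{"from": "user", "value": "Send Bob an email."}],
+    "answer": "[{\"name\": \"email\", \"arguments\": {\"receiver\": \"Bob\"}}]"
+}
+```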
+
+## Reward curve results
+
+![ToolAce reward curve](../../docs/sphinx_doc/assets/toolace_reward_curve.png)
diff --git a/examples/grpo_toolcall/toolace.yaml b/examples/grpo_toolcall/toolace.yaml
new file mode 100644
index 0000000000..1923bb032b
--- /dev/null
+++ b/examples/grpo_toolcall/toolace.yaml
@@ -0,0 +1,55 @@
+project: "Trinity-RFT-toolace"
+name: "qwen2.5-7B-toolace"
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+algorithm:
+ algorithm_type: grpo
+ repeat_times: 8
+
+model:
+ model_path: /PATH/TO/MODEL/
+ max_prompt_tokens: 4096
+ max_response_tokens: 8192
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 1
+ batch_size: 128
+ max_retry_times: 3
+ max_retry_interval: 1
+ explorer_input:
+ taskset:
+ name: toolace_data
+ storage_type: file
+ path: scripts/data_prepare/toolace_data
+ # format: []
+ rollout_args:
+ n: 8
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets: []
+ default_workflow_type: 'toolcall_workflow'
+ trainer_input:
+ experience_buffer:
+ name: toolace_buffer
+ storage_type: queue
+ path: 'sqlite:///toolace.db'
+explorer:
+ eval_interval: 50
+ runner_num: 32
+ rollout_model:
+ engine_type: vllm_async
+ engine_num: 4
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: true
+ dtype: bfloat16
+ seed: 42
+synchronizer:
+ sync_method: 'nccl'
+ sync_interval: 1
+ sync_timeout: 3600
+trainer:
+ trainer_type: 'verl'
+ trainer_config_path: 'examples/grpo_toolcall/train_toolace.yaml'
+ save_interval: 100
diff --git a/examples/grpo_toolcall/train_toolace.yaml b/examples/grpo_toolcall/train_toolace.yaml
new file mode 100644
index 0000000000..19eb75361b
--- /dev/null
+++ b/examples/grpo_toolcall/train_toolace.yaml
@@ -0,0 +1,49 @@
+actor_rollout_ref:
+ hybrid_engine: True
+ model:
+ external_lib: null
+ override_config: { }
+ enable_gradient_checkpointing: True
+ use_remove_padding: True # False
+ actor:
+ strategy: fsdp # This is for backward-compatibility
+ ppo_micro_batch_size_per_gpu: 1
+ use_dynamic_bsz: True # False
+ ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
+ grad_clip: 1.0
+ ppo_epochs: 1
+ shuffle: False
+ ulysses_sequence_parallel_size: 2 # sp size
+ optim:
+ lr: 1e-6
+ lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+ # min_lr_ratio: null # only useful for warmup with cosine
+ warmup_style: constant # select from constant/cosine
+      total_training_steps: -1 # must be overridden by the program
+ fsdp_config:
+ wrap_policy:
+ # transformer_layer_cls_to_wrap: None
+ min_num_params: 0
+ param_offload: False
+ optimizer_offload: False
+ fsdp_size: -1
+ ref:
+ fsdp_config:
+ param_offload: False
+ wrap_policy:
+ # transformer_layer_cls_to_wrap: None
+ min_num_params: 0
+ log_prob_micro_batch_size_per_gpu: 1
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+
+trainer:
+ balance_batch: True
+ # total_training_steps: null
+ # auto: find the last ckpt to resume. If can't find, start from scratch
+  resume_mode: auto # or disable or resume_path
+ default_hdfs_dir: null
+ remove_previous_ckpt_in_save: False
+ del_local_ckpt_after_load: False
+ val_before_train: False
diff --git a/scripts/data_prepare/get_toolace_data.py b/scripts/data_prepare/get_toolace_data.py
new file mode 100644
index 0000000000..813190b8c5
--- /dev/null
+++ b/scripts/data_prepare/get_toolace_data.py
@@ -0,0 +1,276 @@
+"""
+We use this script to create the toolace dataset files.
+This script is adapted from https://github.com/NVlabs/Tool-N1/
+The original license and copyright notice are retained below.
+#
+# =================================================================================
+#
+# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+"""
+
+import ast
+import json
+import os
+import random
+import re
+
+from datasets import Dataset, load_dataset
+
+random.seed(42)
+
+DEFAULT_PROMPT = """You are an expert in composing functions. You are given a question and a set of possible functions. \nBased on the question, you will need to make one or more function/tool calls to achieve the purpose. \nIf none of the function can be used, point it out. If the given question lacks the parameters required by the function,\nalso point it out. You should only return the function call in tools call sections. Here is a list of functions in JSON format that you can invoke:
+"""
+
+
+def _split_top_level_args(params_str):
+ items = []
+ current = []
+
+ bracket_level = 0
+ brace_level = 0
+ paren_level = 0
+
+ quote_char = None
+
+ for char in params_str:
+ if quote_char:
+ if char == quote_char:
+ quote_char = None
+ current.append(char)
+ else:
+ current.append(char)
+ continue
+ else:
+ if char in ["'", '"']:
+ quote_char = char
+ current.append(char)
+ continue
+
+ if char == "(":
+ paren_level += 1
+ current.append(char)
+ elif char == ")":
+ paren_level -= 1
+ current.append(char)
+ elif char == "[":
+ bracket_level += 1
+ current.append(char)
+ elif char == "]":
+ bracket_level -= 1
+ current.append(char)
+ elif char == "{":
+ brace_level += 1
+ current.append(char)
+ elif char == "}":
+ brace_level -= 1
+ current.append(char)
+ elif char == "," and bracket_level == 0 and brace_level == 0 and paren_level == 0:
+ items.append("".join(current).strip())
+ current = []
+ else:
+ current.append(char)
+
+ if current:
+ items.append("".join(current).strip())
+
+ return items
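+
+# Illustrative behaviour (hypothetical input): commas inside brackets or quotes
+# do not split the string, e.g.
+#   _split_top_level_args("a=1, b=[1, 2], c='x,y'")
+#   -> ["a=1", "b=[1, 2]", "c='x,y'"]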
+
+
+def _parse_params(params_str):
+ result = {}
+ pairs = _split_top_level_args(params_str)
+
+ for item in pairs:
+ if "=" in item:
+ key, val = item.split("=", 1)
+ key = key.strip()
+ val = val.strip()
+ try:
+ parsed_val = ast.literal_eval(val)
+ except Exception:
+ parsed_val = val.strip("'\"")
+ result[key] = parsed_val
+ return result
+
+
+def _parse_function_string(function_string):
+ function_pattern = r"([a-zA-Z0-9\._\?\|\s\-]+)\((.*?)\)"
+ matches = re.findall(function_pattern, function_string)
+ parsed_functions = []
+ for func_name, params_str in matches:
+ func_name = func_name.strip()
+ params_dict = _parse_params(params_str)
+ parsed_functions.append((func_name, params_dict))
+
+ return parsed_functions
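+
+# A minimal sketch of the parsing pipeline above (hypothetical input):
+#   _parse_function_string("[email(receiver='Bob'), walmart(input='banana')]")
+#   -> [("email", {"receiver": "Bob"}), ("walmart", {"input": "banana"})]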
+
+
+def _extract_functions_from_system(text):
+ pattern = r"Here is a list of functions in JSON format that you can invoke:\n(.*?)\nShould you decide to return the function call\(s\)."
+ match = re.search(pattern, text, re.DOTALL) # re.DOTALL allows '.' to match newlines
+ if match:
+ s = match.group(1).strip()
+ s = s[:-2] + "]"
+ s = json.loads(s)
+ return s
+ else:
+ return None
+
+
+def _validate_function_format(s):
+ if not s or not isinstance(s, str):
+ return False
+
+ pattern = (
+ r"^\[\s*([a-zA-Z0-9\._\?\|\s\-/]+\(.*?\)\s*)(,\s*[a-zA-Z0-9\._\?\|\s\-/]+\(.*?\)\s*)*\]$"
+ )
+ return bool(re.match(pattern, s.strip()))
+
+
+def _process_toolace(data, include_nocall=False):
+ filtered = []
+ for d in data:
+ tools = _extract_functions_from_system(d["system"])
+ if not tools:
+ continue
+ item = {"tools": json.dumps(tools)}
+ marker = "JSON format that you can invoke:\n"
+ idx = d["system"].find(marker)
+ item["system"] = d["system"][: idx + len(marker)]
+ convs = []
+ for c in d["conversations"]:
+ # if c["from"] == "gpt" and c["value"].startswith("[") and c["value"].endswith("]"):
+ if c["from"] == "assistant" and c["value"].startswith("[") and c["value"].endswith("]"):
+ funcs = _parse_function_string(c["value"])
+ if funcs and _validate_function_format(c["value"]):
+ c["from"] = "function_call"
+ c["value"] = json.dumps([{"name": f[0], "arguments": f[1]} for f in funcs])
+ convs.append(c)
+ item["conversations"] = convs
+ filtered.append(item)
+
+ print("filtered len")
+ print(len(filtered))
+
+ results = []
+ for d in filtered:
+ cs, sys, tools = d["conversations"], d["system"], d["tools"]
+ for i, c in enumerate(cs):
+ if (
+ c["from"] == "function_call"
+ and c["value"].startswith("[")
+ and c["value"].endswith("]")
+ ):
+ results.append(
+ {
+ "system": sys,
+ "conversations": cs[:i],
+ "answer": json.loads(c["value"]),
+ "tools": tools,
+ }
+ )
+ if c["from"] == "assistant" and include_nocall:
+ results.append(
+ {"system": sys, "conversations": cs[:i], "answer": [], "tools": tools}
+ )
+
+ results = [r for r in results if r["tools"]]
+ final = []
+ for r in results:
+ if len(r["conversations"]) >= 2:
+ id = "toolace_multiple_turn"
+ else:
+ id = "toolace_single_turn"
+ out = {
+ "raw_system": r["system"],
+ "tools": json.loads(r["tools"]),
+ "conversations": r["conversations"],
+ "answer": r["answer"],
+ "id": id,
+ }
+ for c in out["conversations"]:
+ if (
+ c["from"] == "function_call"
+ and c["value"].startswith("[")
+ and c["value"].endswith("]")
+ ):
+ c["from"] = "function_call"
+ final.append(out)
+
+ return [
+ f
+ for f in final
+ if all(ans["name"] in {tool["name"] for tool in f["tools"]} for ans in f["answer"])
+ ]
+
+
+def pre_process_data(data_names, include_nocall=False):
+ # toolace - single turn
+ if data_names == "toolace_single_turn":
+ dataset_dict = load_dataset("Team-ACE/ToolACE", split="train").to_list()
+ print("raw_toolace_single_turn", len(dataset_dict))
+ result = _process_toolace(dataset_dict)
+ else:
+ raise ValueError(
+ f"Error: The dataset '{data_names}' is not supported. Please check the dataset name or implement support for it."
+ )
+
+ return result
+
+
+def dict2hg(data):
+ hg_data = {"tools": [], "conversations": [], "answer": [], "raw_system": []}
+ for item in data:
+ tool = json.dumps(item["tools"])
+ hg_data["tools"].append(tool)
+ hg_data["conversations"].append(item["conversations"])
+ answer = json.dumps(item["answer"])
+ hg_data["answer"].append(answer)
+ hg_data["raw_system"].append(item["raw_system"])
+ hg_data = Dataset.from_dict(hg_data)
+ return hg_data
+
+
+def create_dataset_files(output_dir):
+ # make the output directory
+ os.makedirs(output_dir, exist_ok=True)
+
+ data_sum = []
+ single_turn_toolace = pre_process_data("toolace_single_turn")
+ print("toolace_single_turn", len(single_turn_toolace))
+ data_sum.extend(single_turn_toolace)
+
+ hg_data = dict2hg(data_sum)
+ list_new_data = hg_data.to_list()
+ # randomly shuffle
+ random.shuffle(list_new_data)
+
+ with open(os.path.join(output_dir, "toolace_clean_data.json"), "w", encoding="utf-8") as f:
+ json.dump(list_new_data, f, indent=4, ensure_ascii=False)
+
+
+if __name__ == "__main__":
+ current_file_path = os.path.dirname(os.path.abspath(__file__))
+ output_dir = f"{current_file_path}/toolace_data"
+ create_dataset_files(output_dir)
diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index 496996a05d..b2d418126c 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
"""Workflow module"""
from .customized_math_workflows import MathBoxedWorkflow
+from .customized_toolcall_workflows import ToolCallWorkflow
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
from .envs.webshop.webshop_workflow import WebShopWorkflow
@@ -18,4 +19,5 @@
"SciWorldWorkflow",
"MathBoxedWorkflow",
"MathRMWorkflow",
+ "ToolCallWorkflow",
]
diff --git a/trinity/common/workflows/customized_toolcall_workflows.py b/trinity/common/workflows/customized_toolcall_workflows.py
new file mode 100644
index 0000000000..a3a8a83554
--- /dev/null
+++ b/trinity/common/workflows/customized_toolcall_workflows.py
@@ -0,0 +1,271 @@
+# -*- coding: utf-8 -*-
+"""
+This file contains the customized toolcall workflows.
+The code is adapted from https://github.com/NVlabs/Tool-N1;
+see the paper https://arxiv.org/pdf/2505.00024 for further details.
+"""
+
+import json
+import re
+from collections import Counter
+from dataclasses import asdict
+from typing import List
+
+from trinity.common.experience import Experience
+from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+# Adapted from https://github.com/NVlabs/Tool-N1
+qwen_tool_prompts = """# Tool
+
+
+{tools}
+
+
+In each action step, you MUST:
+1. Think about the reasoning process in the mind and enclosed your reasoning within XML tags.
+2. Then, provide a json object with function names and arguments within XML tags. i.e., [{{"name": , "arguments": }}, {{"name": , "arguments": }}, ...]
+3. Make sure both the reasoning and the tool call steps are included together in one single reply.
+A complete reply example is: To address the query, I need to send the email to Bob and then buy the banana through walmart. [{{"name": "email", "arguments": {{"receiver": "Bob", "content": "I will bug banana through walmart"}}}}, {{"name": "walmart", "arguments": {{"input": "banana"}}}}]. Please make sure the type of the arguments is correct.
+"""
+
+
+# Adapted from https://github.com/NVlabs/Tool-N1
+def construct_prompt(dp):
+ def format_tools(tools):
+ tools = json.loads(tools)
+ string = ""
+ for tool in tools:
+ string += json.dumps({"type": "function", "function": tool}) + "\n"
+        string = string.rstrip("\n")  # also safe when the tool list is empty
+ return string
+
+ tools = format_tools(dp["tools"])
+ tool_prompt = qwen_tool_prompts.format(tools=tools)
+ system = dp["raw_system"]
+ conversations = dp["conversations"]
+ prompt = []
+ prompt.append({"role": "system", "content": system + tool_prompt})
+ for tem in conversations:
+ if tem["from"] == "human" or tem["from"] == "user":
+ prompt.append({"role": "user", "content": tem["value"]})
+ elif tem["from"] == "gpt" or tem["from"] == "assistant":
+ prompt.append({"role": "assistant", "content": tem["value"]})
+ elif tem["from"] == "observation" or tem["from"] == "tool":
+ prompt.append({"role": "tool", "content": tem["value"]})
+ elif tem["from"] == "function_call":
+ prompt.append({"role": "assistant", "content": json.dumps(tem["value"])})
+ return prompt
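+
+# construct_prompt returns a plain chat message list, roughly (hypothetical data point):
+#   [{"role": "system", "content": raw_system + tool_prompt},
+#    {"role": "user", "content": "Send Bob an email."}]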
+
+
+# Adapted from https://github.com/NVlabs/Tool-N1
+def validate_result(result, answer):
+ if len(result) == 0 or len(answer) == 0:
+ if len(result) == len(answer):
+ return 2
+ else:
+ return 0
+ else:
+ try:
+ counter1_full = Counter(
+ (item["name"], json.dumps(item["arguments"], sort_keys=True)) for item in result
+ )
+ counter2_full = Counter(
+ (item["name"], json.dumps(item["arguments"], sort_keys=True)) for item in answer
+ )
+ except TypeError:
+ return 0
+ if counter1_full == counter2_full:
+ return 2
+
+ counter1_name = Counter(item["name"] for item in result)
+ counter2_name = Counter(item["name"] for item in answer)
+
+ if counter1_name == counter2_name:
+ return 1
+
+ return 0
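+
+# Scoring sketch (hypothetical calls): an exact name+argument match scores 2,
+# a name-only match scores 1, anything else 0:
+#   validate_result([{"name": "email", "arguments": {"to": "Bob"}}],
+#                   [{"name": "email", "arguments": {"to": "Bob"}}]) -> 2
+#   validate_result([{"name": "email", "arguments": {"to": "Ann"}}],
+#                   [{"name": "email", "arguments": {"to": "Bob"}}]) -> 1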
+
+
+# Adapted from https://github.com/NVlabs/Tool-N1
+def validate_format(tool_call_list):
+    for item in tool_call_list:
+        if not isinstance(item, dict):
+            return 0
+        if "name" not in item or "arguments" not in item:
+            return 0
+    return 1
+
+
+# Adapted from https://github.com/NVlabs/Tool-N1
+def extract_solution_v0(tool_call_str):
+ output_string = tool_call_str
+
+ pattern = r"(.*?)"
+ matches = list(re.finditer(pattern, tool_call_str, flags=re.DOTALL))
+ if not matches:
+ return None, output_string
+ last_content = matches[-1].group(1).strip()
+ try:
+ return json.loads(last_content), output_string
+ except json.JSONDecodeError:
+ return None, output_string
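+
+# Extraction sketch (hypothetical model output): the JSON inside the last
+# <tool_call></tool_call> block is parsed, and the raw output is returned alongside it:
+#   calls, raw = extract_solution_v0(
+#       '<think>plan</think><tool_call>[{"name": "email", "arguments": {}}]</tool_call>')
+#   calls -> [{"name": "email", "arguments": {}}]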
+
+
+def compute_score_v0( # noqa: C901
+ solution_str,
+ ground_truth,
+ do_print=False,
+):
+ answer = json.loads(ground_truth)
+
+ result, output_string = extract_solution_v0(solution_str)
+
+ if isinstance(result, str):
+ try:
+ result = json.loads(result)
+ except json.JSONDecodeError:
+ result = None
+
+    if isinstance(result, dict):
+        # a single tool call may come back as a bare dict; wrap it in a list
+        result = [result]
+
+ if isinstance(answer, str):
+ answer = json.loads(answer)
+
+ if do_print:
+ print("************solution_str************")
+ print(solution_str)
+ print(f"Extracted result: {result}")
+ print(f"Solution string: {answer}")
+
+ if result is not None:
+ if "" not in output_string or "" not in output_string:
+ if do_print:
+ print("--------" * 5 + "\n\n")
+ print("not thinking:", -1)
+ return 0
+
+ if result is None:
+ if do_print:
+ print("--------" * 5 + "\n\n")
+ print("result is None:", -1)
+ return 0
+
+ # added rule1
+ if solution_str.count("") != 1 or solution_str.count("") != 1:
+ if do_print:
+ print("--------" * 5 + "\n\n")
+ print(
+ f"Fail, think tag appear not once: "
+ f" appear {solution_str.count('')} times, "
+ f" appear {solution_str.count('')} times",
+ -1,
+ )
+ return 0
+
+ # added rule2
+ think_end_pos = solution_str.find("")
+ tool_call_start_pos = solution_str.find("")
+
+ if tool_call_start_pos != -1:
+ if think_end_pos > tool_call_start_pos:
+ if do_print:
+ print("--------" * 5 + "\n\n")
+ print("Fail: tag must before tag", -1)
+ return 0
+
+ if not validate_format(result):
+ if do_print:
+ print("--------" * 5 + "\n\n")
+ print("result wrong formate:", -1)
+ return 0
+
+ if validate_result(result, answer) == 2:
+ if do_print:
+ print("--------" * 5 + "\n\n")
+ print("get full core:", 1)
+ return 1
+ else:
+ if do_print:
+ print("--------" * 5 + "\n\n")
+ print("wrong answer", -1)
+ return 0
+
+
+def compute_toolcall_reward(
+ solution_str: str,
+ ground_truth: str,
+) -> float:
+ res = compute_score_v0(solution_str, ground_truth)
+ if isinstance(res, (int, float, bool)):
+ return float(res)
+ else:
+ return float(res[0])
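+
+# End-to-end reward sketch (hypothetical strings): a reply with exactly one
+# well-ordered <think>/<tool_call> pair whose calls match the ground truth earns 1.0:
+#   compute_toolcall_reward(
+#       '<think>plan</think><tool_call>[{"name": "email", "arguments": {"to": "Bob"}}]</tool_call>',
+#       '[{"name": "email", "arguments": {"to": "Bob"}}]',
+#   )  # -> 1.0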
+
+
+@WORKFLOWS.register_module("toolcall_workflow")
+class ToolCallWorkflow(SimpleWorkflow):
+ """
+ A workflow for toolcall tasks.
+ Prompt construction and reward function from https://github.com/NVlabs/Tool-N1
+
+    Only Qwen models are supported for now. For other models, adapt the prompt construction and reward calculation accordingly.
+ """
+
+ def reset(self, task: Task):
+ self.format_args = task.format_args
+ self.system_prompt = task.format_args.system_prompt
+ self.reply_prefix = task.format_args.reply_prefix
+
+ self.raw_task = task.raw_task
+ self.task_desc = task.task_desc
+ self.truth = task.truth
+
+ # Rollout args
+ rollout_args = asdict(task.rollout_args)
+ self.rollout_args = rollout_args
+ self.is_eval = task.is_eval
+
+ self.workflow_args = task.workflow_args
+ self.reward_fn_args = task.reward_fn_args
+
+ def format_prompt(self):
+ raw_task = self.raw_task
+ messages = construct_prompt(raw_task)
+ return messages
+
+ def run(self) -> List[Experience]:
+ messages = self.format_prompt()
+
+ logger.debug("start chat")
+ responses = self.model.chat(messages, **self.rollout_args)
+
+ for i, response in enumerate(responses):
+ reward = 0.0
+
+ if self.raw_task is not None:
+ ground_truth = self.raw_task.get("answer")
+ if ground_truth is not None:
+ reward = compute_toolcall_reward(
+ solution_str=response.response_text,
+ ground_truth=ground_truth,
+ )
+ else:
+ logger.error(
+ "Key 'answer' not found in self.raw_task. Assigning default reward."
+ )
+ else:
+ logger.error("self.raw_task is None. Assigning default reward.")
+ logger.debug(
+ f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
+ )
+ response.reward = reward
+ response.eid.run = i
+ return responses