Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/sphinx_doc/assets/toolace_3b_response.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/sphinx_doc/assets/toolace_3b_rewards.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/sphinx_doc/assets/toolace_length_curve.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/sphinx_doc/assets/toolace_reward_curve.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/tutorial/example_multi_turn.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ Our dataset follows the format in Huggingface datasets library, so we should cor
Just check the data preparation scripts and run the following command.
```bash
# For ALFworld env
python scripts/data_prepare/get_alfworld_data.py
python examples/grpo_alfworld/get_alfworld_data.py

# For WebShop env
python scripts/data_prepare/get_webshop_data.py
python examples/grpo_webshop/get_webshop_data.py
```

The task is described as an environment instead of a single prompt.
Expand Down
2 changes: 1 addition & 1 deletion examples/grpo_alfworld/alfworld.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ buffer:
taskset:
name: alfworld
storage_type: file
path: 'scripts/data_prepare/alfworld_data'
path: 'examples/grpo_alfworld/alfworld_data'
format:
prompt_key: 'game_file'
rollout_args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,6 @@ def create_dataset_files(output_dir, train_size=1024, test_size=100):


if __name__ == "__main__":
current_file_path = os.path.dirname(os.path.abspath(__file__))
output_dir = f"{current_file_path}/alfworld_data"
current_file_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = f"{current_file_dir}/alfworld_data"
create_dataset_files(output_dir, train_size=1024, test_size=100)
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def create_dataset_files(output_dir, train_task_names, test_task_names, jar_path
f"JAR file not found at {jar_path}, please set the jar path mannually."
)

current_file_path = os.path.dirname(os.path.abspath(__file__))
output_dir = f"{current_file_path}/sciworld_data"
current_file_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = f"{current_file_dir}/sciworld_data"
train_task_names = [
"boil",
"melt",
Expand Down
2 changes: 1 addition & 1 deletion examples/grpo_sciworld/sciworld.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ buffer:
taskset:
name: sciworld
storage_type: file
path: 'scripts/data_prepare/sciworld_data'
path: 'examples/grpo_sciworld/sciworld_data'
format:
prompt_key: 'task_desc'
rollout_args:
Expand Down
66 changes: 65 additions & 1 deletion examples/grpo_toolcall/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,74 @@ The config files are located in [`toolace.yaml`](toolace.yaml) and [`train_toola


## How to run
To preprocess the data into the format required by our `toolcall_workflow`, run the following command: `python scripts/data_prepare/get_toolace_data.py`.
To preprocess the data into the format required by our `toolcall_workflow`, run the following command: `python examples/grpo_toolcall/get_toolace_data.py`.

Then fill in the config file `toolace.yaml` and run the following command: `trinity run --config examples/grpo_toolcall/toolace.yaml`.

## Preventing reward hacking
In our initial experiments on the 3B model, we found that the reward structure of ToolAce will result in reward hacking, so we add modification on the reward structure of the toolcall workflow. (Our workflow design is flexible, and please feel free to modify it according to your needs.)

The original reward for format checking is:
```python

def compute_score_v0(solution_str, ground_truth):
....
if "<think>" not in output_string or "</think>" not in output_string:
return 0
...
```

It simply checks whether the output contains `<think>` and `</think>`.
But it did not check whether there are only one `<think>` and one `</think>` in the output, nor did it check whether the `<think>` and `</think>` are before the `<tool_call>` tags.

This results in reward hacking in 3B model.
While the reward curve seems to converge, the model will generate endless `<think>` and `</think>` tags, resulting in overlength outputs.

<div style="display: flex; justify-content: space-around; align-items: center;">
<img src="../../docs/sphinx_doc/assets/toolace_3b_rewards.png" alt="Reward Curve" style="width: 40%;">
<img src="../../docs/sphinx_doc/assets/toolace_3b_response.png" alt="Response Image" style="width: 40%;">
</div>

The response looks like this:
```
<think>To gather the number of live football events this week, I need to call the Sports Number live events function.</think> <tool_call> [{"name": "Sports Number live events", "arguments": {"sport": "football", "string_range": "this week"}}]</tool_call> <think>The arguments include the sport as football and the string_range as this week to filter the events.</think> <think>This should provide the number of live football events happening this week.</think> <think> </think> <think> </think> <think> </think> <think> </think> <think> </think> <think> </think> <think> </think> <think> </think> <think> </think> <think> </think> <think> ...
```

To fix this, we add the following code to the `compute_score_v0` function:
```python
...
# added rule1
if solution_str.count("<think>") != 1 or solution_str.count("</think>") != 1:
return 0

# added rule2
think_end_pos = solution_str.find("</think>")
tool_call_start_pos = solution_str.find("<tool_call>")

if tool_call_start_pos != -1 and think_end_pos > tool_call_start_pos:
return 0
...
```

With this fix on reward hacking, the training will successfully converges. With Qwen2.5-7B-Instruct model, it takes around 3 hours on 8 H20 GPUs to train 9.6k data for one epoch.

## Reward curve results

Below is the reward curve of the trained Qwen2.5-7B-Instruct model.
![](../../docs/sphinx_doc/assets/toolace_reward_curve.png)

The response length is also steady.
![](../../docs/sphinx_doc/assets/toolace_length_curve.png)

To view the model output, you can either use the wandb built-in table we provided, or directly use the `sqlite3` command to query the replay buffer and see the model response by running the following command:
```bash
sqlite3 /PATH/TO/YOUR/BUFFER/toolace.db

> SELECT id, response FROM toolace_buffer ORDER BY id DESC LIMIT 1;
```

The resulting id and response:
```text
76800|<think>To calculate the potential financial exposure of the portfolio, I need to use the `risk.calculate_derivative_exposure` function. The portfolio consists of 10 options with a maturity date of April 15, 2023, and 5 futures that mature on June 30, 2023. We need to calculate the exposure for today and also for the end of this quarter. Today's date can be inferred from the current system date. The end of the quarter can be calculated based on the current date.</think>
<tool_call>[{"name": "risk.calculate_derivative_exposure", "arguments": {"portfolio": [{"derivative_type": "Option", "positions": [{"quantity": 10, "maturity_date": "2023-04-15"}]}, {"derivative_type": "Future", "positions": [{"quantity": 5, "maturity_date": "2023-06-30"}]}], "evaluation_date": "2023-03-31"}}, {"name": "risk.calculate_derivative_exposure", "arguments": {"portfolio": [{"derivative_type": "Option", "positions": [{"quantity": 10, "maturity_date": "2023-04-15"}]}, {"derivative_type": "Future", "positions": [{"quantity": 5, "maturity_date": "2023-06-30"}]}], "evaluation_date": "2023-06-30"}}]</tool_call>
```
2 changes: 1 addition & 1 deletion examples/grpo_toolcall/toolace.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ buffer:
taskset:
name: toolace_data
storage_type: file
path: scripts/data_prepare/toolace_data
path: examples/grpo_toolcall/toolace_data
# format: []
rollout_args:
n: 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ def create_dataset_files(output_dir, train_size=4096, test_size=100):


if __name__ == "__main__":
current_file_path = os.path.dirname(os.path.abspath(__file__))
output_dir = f"{current_file_path}/webshop_data"
current_file_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = f"{current_file_dir}/webshop_data"
create_dataset_files(output_dir, train_size=4096, test_size=100)
2 changes: 1 addition & 1 deletion examples/grpo_webshop/webshop.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ buffer:
taskset:
name: webshop
storage_type: file
path: 'scripts/data_prepare/webshop_data'
path: 'examples/grpo_webshop/webshop_data'
format:
prompt_key: 'task_id'
rollout_args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:

def run(self) -> List[Experience]:
# assume the task_description is the game_file_path generated.
# see Trinity-RFT/script/data_prepare/get_alfworld_data.py
# see Trinity-RFT/examples/grpo_alfworld/get_alfworld_data.py
game_file_path = self.task_desc
rollout_n = self.repeat_times
# TODO: Make parallel envs
Expand Down