diff --git a/docs/sphinx_doc/assets/toolace_3b_response.png b/docs/sphinx_doc/assets/toolace_3b_response.png
new file mode 100644
index 0000000000..0409e2b77d
Binary files /dev/null and b/docs/sphinx_doc/assets/toolace_3b_response.png differ
diff --git a/docs/sphinx_doc/assets/toolace_3b_rewards.png b/docs/sphinx_doc/assets/toolace_3b_rewards.png
new file mode 100644
index 0000000000..8c8920227b
Binary files /dev/null and b/docs/sphinx_doc/assets/toolace_3b_rewards.png differ
diff --git a/docs/sphinx_doc/assets/toolace_length_curve.png b/docs/sphinx_doc/assets/toolace_length_curve.png
new file mode 100644
index 0000000000..d1bc605462
Binary files /dev/null and b/docs/sphinx_doc/assets/toolace_length_curve.png differ
diff --git a/docs/sphinx_doc/assets/toolace_reward_curve.png b/docs/sphinx_doc/assets/toolace_reward_curve.png
index c062466783..e32df0cbc3 100644
Binary files a/docs/sphinx_doc/assets/toolace_reward_curve.png and b/docs/sphinx_doc/assets/toolace_reward_curve.png differ
diff --git a/docs/sphinx_doc/source/tutorial/example_multi_turn.md b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
index 3cf5b89145..7169731b9c 100644
--- a/docs/sphinx_doc/source/tutorial/example_multi_turn.md
+++ b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
@@ -24,10 +24,10 @@ Our dataset follows the format in Huggingface datasets library, so we should cor
Just check the data preparation scripts and run the following command.
```bash
# For ALFworld env
-python scripts/data_prepare/get_alfworld_data.py
+python examples/grpo_alfworld/get_alfworld_data.py
# For WebShop env
-python scripts/data_prepare/get_webshop_data.py
+python examples/grpo_webshop/get_webshop_data.py
```
The task is described as an environment instead of a single prompt.
diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml
index f6079ad55e..7691c13015 100644
--- a/examples/grpo_alfworld/alfworld.yaml
+++ b/examples/grpo_alfworld/alfworld.yaml
@@ -20,7 +20,7 @@ buffer:
taskset:
name: alfworld
storage_type: file
- path: 'scripts/data_prepare/alfworld_data'
+ path: 'examples/grpo_alfworld/alfworld_data'
format:
prompt_key: 'game_file'
rollout_args:
diff --git a/scripts/data_prepare/get_alfworld_data.py b/examples/grpo_alfworld/get_alfworld_data.py
similarity index 94%
rename from scripts/data_prepare/get_alfworld_data.py
rename to examples/grpo_alfworld/get_alfworld_data.py
index b55a04435a..93e6c8a1b3 100644
--- a/scripts/data_prepare/get_alfworld_data.py
+++ b/examples/grpo_alfworld/get_alfworld_data.py
@@ -60,6 +60,6 @@ def create_dataset_files(output_dir, train_size=1024, test_size=100):
if __name__ == "__main__":
- current_file_path = os.path.dirname(os.path.abspath(__file__))
- output_dir = f"{current_file_path}/alfworld_data"
+ current_file_dir = os.path.dirname(os.path.abspath(__file__))
+ output_dir = f"{current_file_dir}/alfworld_data"
create_dataset_files(output_dir, train_size=1024, test_size=100)
diff --git a/scripts/data_prepare/get_sciworld_data.py b/examples/grpo_sciworld/get_sciworld_data.py
similarity index 97%
rename from scripts/data_prepare/get_sciworld_data.py
rename to examples/grpo_sciworld/get_sciworld_data.py
index e94624c155..a10da8d2ee 100644
--- a/scripts/data_prepare/get_sciworld_data.py
+++ b/examples/grpo_sciworld/get_sciworld_data.py
@@ -107,8 +107,8 @@ def create_dataset_files(output_dir, train_task_names, test_task_names, jar_path
f"JAR file not found at {jar_path}, please set the jar path mannually."
)
- current_file_path = os.path.dirname(os.path.abspath(__file__))
- output_dir = f"{current_file_path}/sciworld_data"
+ current_file_dir = os.path.dirname(os.path.abspath(__file__))
+ output_dir = f"{current_file_dir}/sciworld_data"
train_task_names = [
"boil",
"melt",
diff --git a/examples/grpo_sciworld/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml
index 799b2df800..1ec022a7d3 100644
--- a/examples/grpo_sciworld/sciworld.yaml
+++ b/examples/grpo_sciworld/sciworld.yaml
@@ -20,7 +20,7 @@ buffer:
taskset:
name: sciworld
storage_type: file
- path: 'scripts/data_prepare/sciworld_data'
+ path: 'examples/grpo_sciworld/sciworld_data'
format:
prompt_key: 'task_desc'
rollout_args:
diff --git a/examples/grpo_toolcall/README.md b/examples/grpo_toolcall/README.md
index 3418178b05..ddcbd57c07 100644
--- a/examples/grpo_toolcall/README.md
+++ b/examples/grpo_toolcall/README.md
@@ -8,10 +8,74 @@ The config files are located in [`toolace.yaml`](toolace.yaml) and [`train_toola
## How to run
-To preprocess the data into the format required by our `toolcall_workflow`, run the following command: `python scripts/data_prepare/get_toolace_data.py`.
+To preprocess the data into the format required by our `toolcall_workflow`, run the following command: `python examples/grpo_toolcall/get_toolace_data.py`.
Then fill in the config file `toolace.yaml` and run the following command: `trinity run --config examples/grpo_toolcall/toolace.yaml`.
+## Preventing reward hacking
+In our initial experiments on the 3B model, we found that the reward structure of ToolAce will result in reward hacking, so we add modification on the reward structure of the toolcall workflow. (Our workflow design is flexible, and please feel free to modify it according to your needs.)
+
+The original reward for format checking is:
+```python
+
+def compute_score_v0(solution_str, ground_truth):
+ ....
+ if "" not in output_string or "" not in output_string:
+ return 0
+ ...
+```
+
+It simply checks whether the output contains `` and ``.
+But it did not check whether there are only one `` and one `` in the output, nor did it check whether the `` and `` are before the `` tags.
+
+This results in reward hacking in 3B model.
+While the reward curve seems to converge, the model will generate endless `` and `` tags, resulting in overlength outputs.
+
+
+

+

+
+
+The response looks like this:
+```
+To gather the number of live football events this week, I need to call the Sports Number live events function. [{"name": "Sports Number live events", "arguments": {"sport": "football", "string_range": "this week"}}] The arguments include the sport as football and the string_range as this week to filter the events. This should provide the number of live football events happening this week. ...
+```
+
+To fix this, we add the following code to the `compute_score_v0` function:
+```python
+ ...
+ # added rule1
+ if solution_str.count("") != 1 or solution_str.count("") != 1:
+ return 0
+
+ # added rule2
+ think_end_pos = solution_str.find("")
+ tool_call_start_pos = solution_str.find("")
+
+ if tool_call_start_pos != -1 and think_end_pos > tool_call_start_pos:
+ return 0
+ ...
+```
+
+With this fix on reward hacking, the training will successfully converges. With Qwen2.5-7B-Instruct model, it takes around 3 hours on 8 H20 GPUs to train 9.6k data for one epoch.
+
## Reward curve results
+Below is the reward curve of the trained Qwen2.5-7B-Instruct model.

+
+The response length is also steady.
+
+
+To view the model output, you can either use the wandb built-in table we provided, or directly use the `sqlite3` command to query the replay buffer and see the model response by running the following command:
+```bash
+sqlite3 /PATH/TO/YOUR/BUFFER/toolace.db
+
+> SELECT id, response FROM toolace_buffer ORDER BY id DESC LIMIT 1;
+```
+
+The resulting id and response:
+```text
+76800|To calculate the potential financial exposure of the portfolio, I need to use the `risk.calculate_derivative_exposure` function. The portfolio consists of 10 options with a maturity date of April 15, 2023, and 5 futures that mature on June 30, 2023. We need to calculate the exposure for today and also for the end of this quarter. Today's date can be inferred from the current system date. The end of the quarter can be calculated based on the current date.
+[{"name": "risk.calculate_derivative_exposure", "arguments": {"portfolio": [{"derivative_type": "Option", "positions": [{"quantity": 10, "maturity_date": "2023-04-15"}]}, {"derivative_type": "Future", "positions": [{"quantity": 5, "maturity_date": "2023-06-30"}]}], "evaluation_date": "2023-03-31"}}, {"name": "risk.calculate_derivative_exposure", "arguments": {"portfolio": [{"derivative_type": "Option", "positions": [{"quantity": 10, "maturity_date": "2023-04-15"}]}, {"derivative_type": "Future", "positions": [{"quantity": 5, "maturity_date": "2023-06-30"}]}], "evaluation_date": "2023-06-30"}}]
+```
diff --git a/scripts/data_prepare/get_toolace_data.py b/examples/grpo_toolcall/get_toolace_data.py
similarity index 100%
rename from scripts/data_prepare/get_toolace_data.py
rename to examples/grpo_toolcall/get_toolace_data.py
diff --git a/examples/grpo_toolcall/toolace.yaml b/examples/grpo_toolcall/toolace.yaml
index 1923bb032b..9d002bc641 100644
--- a/examples/grpo_toolcall/toolace.yaml
+++ b/examples/grpo_toolcall/toolace.yaml
@@ -21,7 +21,7 @@ buffer:
taskset:
name: toolace_data
storage_type: file
- path: scripts/data_prepare/toolace_data
+ path: examples/grpo_toolcall/toolace_data
# format: []
rollout_args:
n: 8
diff --git a/scripts/data_prepare/get_webshop_data.py b/examples/grpo_webshop/get_webshop_data.py
similarity index 92%
rename from scripts/data_prepare/get_webshop_data.py
rename to examples/grpo_webshop/get_webshop_data.py
index 7d9847b707..61ea93a5ea 100644
--- a/scripts/data_prepare/get_webshop_data.py
+++ b/examples/grpo_webshop/get_webshop_data.py
@@ -42,6 +42,6 @@ def create_dataset_files(output_dir, train_size=4096, test_size=100):
if __name__ == "__main__":
- current_file_path = os.path.dirname(os.path.abspath(__file__))
- output_dir = f"{current_file_path}/webshop_data"
+ current_file_dir = os.path.dirname(os.path.abspath(__file__))
+ output_dir = f"{current_file_dir}/webshop_data"
create_dataset_files(output_dir, train_size=4096, test_size=100)
diff --git a/examples/grpo_webshop/webshop.yaml b/examples/grpo_webshop/webshop.yaml
index a5ea2a310e..c8722ed54b 100644
--- a/examples/grpo_webshop/webshop.yaml
+++ b/examples/grpo_webshop/webshop.yaml
@@ -20,7 +20,7 @@ buffer:
taskset:
name: webshop
storage_type: file
- path: 'scripts/data_prepare/webshop_data'
+ path: 'examples/grpo_webshop/webshop_data'
format:
prompt_key: 'task_id'
rollout_args:
diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
index 7c0db027a0..40173454e3 100644
--- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
+++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py
@@ -146,7 +146,7 @@ def generate_env_inference_samples(self, env, rollout_num) -> List[Experience]:
def run(self) -> List[Experience]:
# assume the task_description is the game_file_path generated.
- # see Trinity-RFT/script/data_prepare/get_alfworld_data.py
+ # see Trinity-RFT/examples/grpo_alfworld/get_alfworld_data.py
game_file_path = self.task_desc
rollout_n = self.repeat_times
# TODO: Make parallel envs