Skip to content

Commit a416d5a

Browse files
committed
Add benchmark scripts for Guru-Math
1 parent fbf6c96 commit a416d5a

File tree

12 files changed

+865
-223
lines changed

12 files changed

+865
-223
lines changed

benchmark/README.md

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,24 @@ The chart below shows performance based on this [commit](https://github.com/mode
6969
![View Results](../docs/sphinx_doc/assets/gsm8k-bench.png)
7070

7171
### 2. Countdown
72-
First generate data, then run the benchmark:
72+
To reproduce this experiment:
7373
```bash
74-
# Step 1: Generate data
75-
python benchmark/scripts/gen-countdown-data.py --local_dir /your/data/path
76-
# Step 2: Run benchmark
77-
python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct --taskset_path /your/data/path
74+
python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct
7875
```
7976
#### Countdown Results
8077
The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/068da409d215bb2450d93b6b7a56740d4751669d).
8178
![View Results](../docs/sphinx_doc/assets/countdown-bench.png)
8279

80+
### 3. Guru
81+
To reproduce this experiment:
82+
```bash
83+
python bench.py guru --model_path /path/to/Qwen/Qwen2.5-7B
84+
```
85+
86+
#### Guru Results
87+
The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/fbf6c967bcd637bfd9f81fb4d7dd4961d7d5a407).
88+
![View Results](../docs/sphinx_doc/assets/guru-bench.png)
89+
8390
*More benchmarks will be added soon!*
8491

8592
---

benchmark/bench.py

Lines changed: 103 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
import argparse
2+
import importlib
23
import os
34
import subprocess
5+
import sys
46
import time
57

68
import torch
79
import torch.distributed as dist
810
import yaml
911

1012
from trinity.algorithm.algorithm import ALGORITHM_TYPE
11-
from trinity.common.constants import MODEL_PATH_ENV_VAR
13+
from trinity.common.constants import MODEL_PATH_ENV_VAR, SyncStyle
1214
from trinity.utils.dlc_utils import get_dlc_env_vars
1315

1416

1517
def set_engine_num(config, args):
1618
config["cluster"]["node_num"] = args.node_num
1719
config["cluster"]["gpu_per_node"] = args.gpu_per_node
18-
batch_size = config["buffer"]["batch_size"]
20+
batch_size = config["buffer"]["batch_size"] * config["algorithm"]["repeat_times"]
1921
if config["mode"] == "train":
2022
return
2123

@@ -61,6 +63,84 @@ def update_opt_explorer_num(trainer_gpu_num, opt_explorer_num, opt_ratio_diff):
6163
config["explorer"]["rollout_model"]["engine_num"] = opt_explorer_num
6264

6365

66+
def check_taskset_path(dataset_name: str, taskset_path: str) -> str:
67+
"""Ensures the taskset path exists for the given dataset; generates it if necessary.
68+
69+
This function checks whether the 'path' specified in taskset_config exists. If not,
70+
it uses a corresponding data generation script (e.g., gen_countdown_data.py) to create
71+
the dataset at the default or provided location. The generator scripts are expected
72+
to be located in the 'scripts/' subdirectory relative to this file.
73+
74+
Args:
75+
dataset_name: Name of the dataset (e.g., "countdown", "guru").
76+
Must be one of the supported datasets defined in `dataset_script_map`.
77+
taskset_path: Path to the dataset.
78+
79+
Returns:
80+
str: The resolved path to the dataset.
81+
82+
Raises:
83+
ValueError: If the `dataset_name` is not supported.
84+
FileNotFoundError: If the corresponding generator script does not exist.
85+
ImportError: If the generator module fails to load.
86+
AttributeError: If the loaded module does not define 'DEFAULT_DATA_PATH'.
87+
subprocess.CalledProcessError: If the generation script fails (due to check=True).
88+
89+
Side Effects:
90+
- Modifies `taskset_config` by setting the "path" key to the resolved path.
91+
- May create directories and files on disk via the external generation script.
92+
- Executes a subprocess to run the dataset generation script.
93+
94+
Examples:
95+
For dataset_name='guru' and taskset_config={"path": None},
96+
this function will runs the following command and
97+
generate the guru dataset to default location (DEFAULT_DATA_PATH in scripts/gen_guru_data.py):
98+
99+
```bash
100+
python scripts/gen_guru_data.py --local_dir DEFAULT_DATA_PATH
101+
```
102+
"""
103+
if taskset_path:
104+
if os.path.exists(taskset_path):
105+
return taskset_path
106+
if dataset_name == "gsm8k" and taskset_path == "openai/gsm8k":
107+
return taskset_path
108+
109+
dataset_script_map = {
110+
"countdown": "gen_countdown_data.py",
111+
"guru": "gen_guru_data.py",
112+
}
113+
if dataset_name not in dataset_script_map:
114+
raise ValueError(
115+
f"Unsupported dataset: {dataset_name}. Please specify a valid taskset path."
116+
)
117+
118+
base_dir = os.path.dirname(__file__)
119+
script_filename = dataset_script_map[dataset_name]
120+
script_module_name = script_filename[:-3] # remove .py
121+
122+
script_file_path = os.path.join(base_dir, "scripts", script_filename)
123+
if not os.path.exists(script_file_path):
124+
raise FileNotFoundError(f"Generator script not found: {script_file_path}")
125+
126+
spec = importlib.util.spec_from_file_location(script_module_name, script_file_path)
127+
if spec is None or spec.loader is None:
128+
raise ImportError(f"Could not load spec for module: {script_module_name}")
129+
module = importlib.util.module_from_spec(spec)
130+
spec.loader.exec_module(module)
131+
132+
if taskset_path is None:
133+
if not hasattr(module, "DEFAULT_DATA_PATH"):
134+
raise AttributeError(f"{script_filename} is missing 'DEFAULT_DATA_PATH'")
135+
taskset_path = module.DEFAULT_DATA_PATH
136+
taskset_path = os.path.realpath(taskset_path)
137+
138+
gen_script_path = os.path.join(base_dir, "scripts", script_filename)
139+
subprocess.run([sys.executable, gen_script_path, "--local_dir", taskset_path], check=True)
140+
141+
return taskset_path
142+
143+
64144
def prepare_configs(args, rank, current_time):
65145
base_path = os.path.dirname(os.path.abspath(__file__))
66146

@@ -89,18 +169,19 @@ def prepare_configs(args, rank, current_time):
89169
)
90170
if args.critic_lr:
91171
config["trainer"]["trainer_config"]["critic"]["optim"]["lr"] = args.critic_lr
92-
config["buffer"]["explorer_input"]["taskset"]["path"] = (
93-
args.taskset_path
94-
or os.environ.get("TASKSET_PATH")
95-
or config["buffer"]["explorer_input"]["taskset"]["path"]
172+
taskset_config = config["buffer"]["explorer_input"]["taskset"]
173+
taskset_config["path"] = check_taskset_path(
174+
args.dataset,
175+
args.taskset_path or os.environ.get("TASKSET_PATH") or taskset_config["path"],
96176
)
97-
assert (
98-
config["buffer"]["explorer_input"]["taskset"]["path"] is not None
99-
), "Please specify taskset path."
100177
if args.lr:
101178
config["algorithm"]["optimizer"]["lr"] = args.lr
102179
if args.sync_interval:
103180
config["synchronizer"]["sync_interval"] = args.sync_interval
181+
if args.sync_offset:
182+
config["synchronizer"]["sync_offset"] = args.sync_offset
183+
if args.sync_style:
184+
config["synchronizer"]["sync_style"] = args.sync_style
104185

105186
with open(config_path, "w") as f:
106187
yaml.dump(config, f, allow_unicode=True, sort_keys=False)
@@ -131,7 +212,7 @@ def main(args):
131212
rank, current_time = 0, time.time()
132213
config_path = prepare_configs(args, rank, current_time)
133214
cmd_list = [
134-
"python",
215+
sys.executable,
135216
"-m",
136217
"trinity.cli.launcher",
137218
"run",
@@ -142,12 +223,16 @@ def main(args):
142223
dist.barrier()
143224
dist.destroy_process_group()
144225
cmd_list.append("--dlc")
226+
if args.dataset == "guru":
227+
base_path = os.path.dirname(os.path.abspath(__file__))
228+
cmd_list.append("--plugin-dir")
229+
cmd_list.append(os.path.join(base_path, "plugins"))
145230
subprocess.run(cmd_list, check=True)
146231

147232

148233
if __name__ == "__main__":
149234
parser = argparse.ArgumentParser()
150-
parser.add_argument("dataset", type=str, choices=["gsm8k", "countdown", "openr1"])
235+
parser.add_argument("dataset", type=str.lower, choices=["gsm8k", "countdown", "guru"])
151236
parser.add_argument(
152237
"--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC."
153238
)
@@ -191,5 +276,12 @@ def main(args):
191276
parser.add_argument(
192277
"--sync_interval", type=int, default=None, help="Specify the sync interval."
193278
)
279+
parser.add_argument("--sync_offset", type=int, default=None, help="Specify the sync offset.")
280+
parser.add_argument(
281+
"--sync_style",
282+
type=str,
283+
default=None,
284+
choices=[sync_style.value for sync_style in SyncStyle],
285+
)
194286
args = parser.parse_args()
195287
main(args)

benchmark/config/countdown-template.yaml

Lines changed: 3 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
mode: both
22
project: Trinity-RFT
3-
group: countdown-bench
4-
name: countdown-qwen2.5-1.5B
3+
group: ${oc.env:TRINITY_GROUP,countdown-bench}
4+
name: ${oc.env:TRINITY_NAME,countdown}
55
checkpoint_root_dir: placeholder
66
algorithm:
77
algorithm_type: ppo
@@ -72,102 +72,16 @@ trainer:
7272
total_steps: 1000
7373
enable_preview: true
7474
grad_clip: 1.0
75+
max_token_len_per_gpu: 6400
7576
trainer_config:
76-
actor_rollout_ref:
77-
hybrid_engine: true
78-
model:
79-
external_lib: null
80-
override_config: {}
81-
enable_gradient_checkpointing: true
82-
use_remove_padding: true
83-
actor:
84-
strategy: fsdp
85-
ppo_micro_batch_size_per_gpu: 4
86-
use_dynamic_bsz: true
87-
ppo_max_token_len_per_gpu: 6400
88-
ppo_epochs: 1
89-
shuffle: false
90-
ulysses_sequence_parallel_size: 1
91-
checkpoint:
92-
load_contents:
93-
- model
94-
- optimizer
95-
- extra
96-
save_contents:
97-
- model
98-
- optimizer
99-
- extra
100-
fsdp_config:
101-
wrap_policy:
102-
min_num_params: 0
103-
param_offload: false
104-
optimizer_offload: false
105-
fsdp_size: -1
106-
ref:
107-
fsdp_config:
108-
wrap_policy:
109-
min_num_params: 0
110-
param_offload: false
111-
optimizer_offload: false
112-
fsdp_size: -1
113-
log_prob_micro_batch_size_per_gpu: 8
114-
log_prob_use_dynamic_bsz: true
115-
log_prob_max_token_len_per_gpu: 6400
116-
ulysses_sequence_parallel_size: 1
117-
custom_reward_function:
118-
path: null
119-
name: compute_score
120-
algorithm:
121-
kl_penalty: low_var_kl
122-
kl_ctrl:
123-
type: fixed
124-
kl_coef: 0.001
125-
trainer:
126-
balance_batch: true
127-
resume_mode: auto
128-
resume_from_path: ''
129-
critic_warmup: 0
130-
default_hdfs_dir: null
131-
remove_previous_ckpt_in_save: false
132-
del_local_ckpt_after_load: false
133-
max_actor_ckpt_to_keep: null
134-
max_critic_ckpt_to_keep: null
13577
critic:
136-
strategy: fsdp
13778
optim:
13879
lr: 1e-5
13980
lr_warmup_steps_ratio: 0.0
14081
warmup_style: constant
141-
model:
142-
override_config: {}
143-
external_lib: null
144-
enable_gradient_checkpointing: true
145-
use_remove_padding: true
146-
fsdp_config:
147-
wrap_policy:
148-
min_num_params: 0
149-
param_offload: false
150-
optimizer_offload: false
151-
fsdp_size: -1
152-
ppo_micro_batch_size_per_gpu: 8
153-
forward_micro_batch_size_per_gpu: 8
154-
use_dynamic_bsz: true
15582
ppo_max_token_len_per_gpu: 12800
15683
forward_max_token_len_per_gpu: 12800
157-
ulysses_sequence_parallel_size: 1
158-
ppo_epochs: 1
159-
shuffle: false
160-
grad_clip: 1.0
16184
cliprange_value: 0.5
162-
checkpoint:
163-
load_contents:
164-
- model
165-
- optimizer
166-
- extra
167-
save_contents:
168-
- model
169-
- optimizer
170-
- extra
17185
monitor:
17286
monitor_type: wandb
17387
synchronizer:

benchmark/config/gsm8k-template.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
mode: both
22
project: Trinity-RFT
3-
group: gsm8k-bench
4-
name: gsm8k-qwen2.5-1.5B
3+
group: ${oc.env:TRINITY_GROUP,gsm8k-bench}
4+
name: ${oc.env:TRINITY_NAME,gsm8k}
55
checkpoint_root_dir: placeholder
66
algorithm:
77
algorithm_type: grpo

0 commit comments

Comments
 (0)