Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions benchmark/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,24 @@ The chart below shows performance based on this [commit](https://github.com/mode
![View Results](../docs/sphinx_doc/assets/gsm8k-bench.png)

### 2. Countdown
First generate data, then run the benchmark:
To reproduce this experiment:
```bash
# Step 1: Generate data
python benchmark/scripts/gen-countdown-data.py --local_dir /your/data/path
# Step 2: Run benchmark
python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct --taskset_path /your/data/path
python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct
```
#### Countdown Results
The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/068da409d215bb2450d93b6b7a56740d4751669d).
![View Results](../docs/sphinx_doc/assets/countdown-bench.png)

### 3. Guru
To reproduce this experiment:
```bash
python bench.py guru --model_path /path/to/Qwen/Qwen2.5-7B
```

#### Guru Results
The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/fbf6c967bcd637bfd9f81fb4d7dd4961d7d5a407).
![View Results](../docs/sphinx_doc/assets/guru-bench.png)

*More benchmarks will be added soon!*

---
Expand Down
114 changes: 103 additions & 11 deletions benchmark/bench.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import argparse
import importlib
import os
import subprocess
import sys
import time

import torch
import torch.distributed as dist
import yaml

from trinity.algorithm.algorithm import ALGORITHM_TYPE
from trinity.common.constants import MODEL_PATH_ENV_VAR
from trinity.common.constants import MODEL_PATH_ENV_VAR, SyncStyle
from trinity.utils.dlc_utils import get_dlc_env_vars


def set_engine_num(config, args):
config["cluster"]["node_num"] = args.node_num
config["cluster"]["gpu_per_node"] = args.gpu_per_node
batch_size = config["buffer"]["batch_size"]
batch_size = config["buffer"]["batch_size"] * config["algorithm"]["repeat_times"]
if config["mode"] == "train":
return

Expand Down Expand Up @@ -61,6 +63,84 @@ def update_opt_explorer_num(trainer_gpu_num, opt_explorer_num, opt_ratio_diff):
config["explorer"]["rollout_model"]["engine_num"] = opt_explorer_num


def check_taskset_path(dataset_name: str, taskset_path: str) -> str:
    """Return a usable taskset path for the dataset, generating the data if needed.

    If `taskset_path` is given and already exists on disk it is returned as-is.
    The special value "openai/gsm8k" for the "gsm8k" dataset is also returned
    unchanged even though it is not a local path. Otherwise the dataset is
    generated by running the matching generator script found in the
    'scripts/' subdirectory next to this file.

    Args:
        dataset_name: Name of the dataset (e.g., "countdown", "guru").
            Must be a key of `dataset_script_map` whenever generation is needed.
        taskset_path: Path to the dataset. May be None or empty, in which case
            the generator script's `DEFAULT_DATA_PATH` is used as the target.

    Returns:
        str: The resolved path to the dataset.

    Raises:
        ValueError: If data must be generated but `dataset_name` has no generator.
        FileNotFoundError: If the corresponding generator script does not exist.
        ImportError: If the generator module cannot be loaded to read its default path.
        AttributeError: If the loaded module does not define 'DEFAULT_DATA_PATH'.
        subprocess.CalledProcessError: If the generation script fails (due to check=True).

    Side Effects:
        - May create directories and files on disk via the external generation script.
        - Executes a subprocess to run the dataset generation script.
        - Executes the generator script as a module when its default path is needed.

    Examples:
        For dataset_name='guru' and taskset_path=None, this function runs the
        following command and generates the guru dataset at the default location
        (DEFAULT_DATA_PATH in scripts/gen_guru_data.py):

        ```bash
        python scripts/gen_guru_data.py --local_dir DEFAULT_DATA_PATH
        ```
    """
    # `import importlib` alone does not guarantee that the `util` submodule is
    # loaded, so import it explicitly where it is used.
    import importlib.util

    if taskset_path:
        if os.path.exists(taskset_path):
            return taskset_path
        # Not a local path: passed through untouched for the gsm8k benchmark.
        if dataset_name == "gsm8k" and taskset_path == "openai/gsm8k":
            return taskset_path

    dataset_script_map = {
        "countdown": "gen_countdown_data.py",
        "guru": "gen_guru_data.py",
    }
    if dataset_name not in dataset_script_map:
        raise ValueError(
            f"Unsupported dataset: {dataset_name}. Please specify a valid taskset path."
        )

    # abspath keeps the lookup independent of the current working directory.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    script_filename = dataset_script_map[dataset_name]
    script_file_path = os.path.join(base_dir, "scripts", script_filename)
    if not os.path.exists(script_file_path):
        raise FileNotFoundError(f"Generator script not found: {script_file_path}")

    # Treat an empty string like None; otherwise realpath("") would silently
    # resolve to the current working directory and data would be written there.
    if not taskset_path:
        # Only execute the generator module when its default path is required.
        script_module_name = os.path.splitext(script_filename)[0]
        spec = importlib.util.spec_from_file_location(script_module_name, script_file_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Could not load spec for module: {script_module_name}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        if not hasattr(module, "DEFAULT_DATA_PATH"):
            raise AttributeError(f"{script_filename} is missing 'DEFAULT_DATA_PATH'")
        taskset_path = module.DEFAULT_DATA_PATH

    taskset_path = os.path.realpath(taskset_path)

    # Use the current interpreter so the script runs in the same environment.
    subprocess.run([sys.executable, script_file_path, "--local_dir", taskset_path], check=True)

    return taskset_path


def prepare_configs(args, rank, current_time):
base_path = os.path.dirname(os.path.abspath(__file__))

Expand Down Expand Up @@ -89,18 +169,19 @@ def prepare_configs(args, rank, current_time):
)
if args.critic_lr:
config["trainer"]["trainer_config"]["critic"]["optim"]["lr"] = args.critic_lr
config["buffer"]["explorer_input"]["taskset"]["path"] = (
args.taskset_path
or os.environ.get("TASKSET_PATH")
or config["buffer"]["explorer_input"]["taskset"]["path"]
taskset_config = config["buffer"]["explorer_input"]["taskset"]
taskset_config["path"] = check_taskset_path(
args.dataset,
args.taskset_path or os.environ.get("TASKSET_PATH") or taskset_config["path"],
)
assert (
config["buffer"]["explorer_input"]["taskset"]["path"] is not None
), "Please specify taskset path."
if args.lr:
config["algorithm"]["optimizer"]["lr"] = args.lr
if args.sync_interval:
config["synchronizer"]["sync_interval"] = args.sync_interval
if args.sync_offset:
config["synchronizer"]["sync_offset"] = args.sync_offset
if args.sync_style:
config["synchronizer"]["sync_style"] = args.sync_style

with open(config_path, "w") as f:
yaml.dump(config, f, allow_unicode=True, sort_keys=False)
Expand Down Expand Up @@ -131,7 +212,7 @@ def main(args):
rank, current_time = 0, time.time()
config_path = prepare_configs(args, rank, current_time)
cmd_list = [
"python",
sys.executable,
"-m",
"trinity.cli.launcher",
"run",
Expand All @@ -142,12 +223,16 @@ def main(args):
dist.barrier()
dist.destroy_process_group()
cmd_list.append("--dlc")
if args.dataset == "guru":
base_path = os.path.dirname(os.path.abspath(__file__))
cmd_list.append("--plugin-dir")
cmd_list.append(os.path.join(base_path, "plugins"))
subprocess.run(cmd_list, check=True)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("dataset", type=str, choices=["gsm8k", "countdown", "openr1"])
parser.add_argument("dataset", type=str.lower, choices=["gsm8k", "countdown", "guru"])
parser.add_argument(
"--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC."
)
Expand Down Expand Up @@ -191,5 +276,12 @@ def main(args):
parser.add_argument(
"--sync_interval", type=int, default=None, help="Specify the sync interval."
)
parser.add_argument("--sync_offset", type=int, default=None, help="Specify the sync offset.")
parser.add_argument(
"--sync_style",
type=str,
default=None,
choices=[sync_style.value for sync_style in SyncStyle],
)
args = parser.parse_args()
main(args)
92 changes: 3 additions & 89 deletions benchmark/config/countdown-template.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mode: both
project: Trinity-RFT
group: countdown-bench
name: countdown-qwen2.5-1.5B
group: ${oc.env:TRINITY_GROUP,countdown-bench}
name: ${oc.env:TRINITY_NAME,countdown}
checkpoint_root_dir: placeholder
algorithm:
algorithm_type: ppo
Expand Down Expand Up @@ -72,102 +72,16 @@ trainer:
total_steps: 1000
enable_preview: true
grad_clip: 1.0
max_token_len_per_gpu: 6400
trainer_config:
actor_rollout_ref:
hybrid_engine: true
model:
external_lib: null
override_config: {}
enable_gradient_checkpointing: true
use_remove_padding: true
actor:
strategy: fsdp
ppo_micro_batch_size_per_gpu: 4
use_dynamic_bsz: true
ppo_max_token_len_per_gpu: 6400
ppo_epochs: 1
shuffle: false
ulysses_sequence_parallel_size: 1
checkpoint:
load_contents:
- model
- optimizer
- extra
save_contents:
- model
- optimizer
- extra
fsdp_config:
wrap_policy:
min_num_params: 0
param_offload: false
optimizer_offload: false
fsdp_size: -1
ref:
fsdp_config:
wrap_policy:
min_num_params: 0
param_offload: false
optimizer_offload: false
fsdp_size: -1
log_prob_micro_batch_size_per_gpu: 8
log_prob_use_dynamic_bsz: true
log_prob_max_token_len_per_gpu: 6400
ulysses_sequence_parallel_size: 1
custom_reward_function:
path: null
name: compute_score
algorithm:
kl_penalty: low_var_kl
kl_ctrl:
type: fixed
kl_coef: 0.001
trainer:
balance_batch: true
resume_mode: auto
resume_from_path: ''
critic_warmup: 0
default_hdfs_dir: null
remove_previous_ckpt_in_save: false
del_local_ckpt_after_load: false
max_actor_ckpt_to_keep: null
max_critic_ckpt_to_keep: null
critic:
strategy: fsdp
optim:
lr: 1e-5
lr_warmup_steps_ratio: 0.0
warmup_style: constant
model:
override_config: {}
external_lib: null
enable_gradient_checkpointing: true
use_remove_padding: true
fsdp_config:
wrap_policy:
min_num_params: 0
param_offload: false
optimizer_offload: false
fsdp_size: -1
ppo_micro_batch_size_per_gpu: 8
forward_micro_batch_size_per_gpu: 8
use_dynamic_bsz: true
ppo_max_token_len_per_gpu: 12800
forward_max_token_len_per_gpu: 12800
ulysses_sequence_parallel_size: 1
ppo_epochs: 1
shuffle: false
grad_clip: 1.0
cliprange_value: 0.5
checkpoint:
load_contents:
- model
- optimizer
- extra
save_contents:
- model
- optimizer
- extra
monitor:
monitor_type: wandb
synchronizer:
Expand Down
4 changes: 2 additions & 2 deletions benchmark/config/gsm8k-template.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mode: both
project: Trinity-RFT
group: gsm8k-bench
name: gsm8k-qwen2.5-1.5B
group: ${oc.env:TRINITY_GROUP,gsm8k-bench}
name: ${oc.env:TRINITY_NAME,gsm8k}
checkpoint_root_dir: placeholder
algorithm:
algorithm_type: grpo
Expand Down
Loading