Skip to content

Commit 3d0f268

Browse files
committed
apply suggestions from reviews
1 parent a416d5a commit 3d0f268

File tree

5 files changed

+19
-14
lines changed

5 files changed

+19
-14
lines changed

benchmark/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct
7777
The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/068da409d215bb2450d93b6b7a56740d4751669d).
7878
![View Results](../docs/sphinx_doc/assets/countdown-bench.png)
7979

80-
### 3. Guru
80+
### 3. Guru-Math
8181
To reproduce this experiment:
8282
```bash
83-
python bench.py guru --model_path /path/to/Qwen/Qwen2.5-7B
83+
python bench.py guru_math --model_path /path/to/Qwen/Qwen2.5-7B
8484
```
8585

8686
#### Guru Results

benchmark/bench.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,16 @@ def check_taskset_path(dataset_name: str, taskset_path: str) -> str:
8787
subprocess.CalledProcessError: If the generation script fails (due to check=True).
8888
8989
Side Effects:
90-
- Modifies `taskset_config` by setting the "path" key to the resolved path.
9190
- May create directories and files on disk via the external generation script.
9291
- Executes a subprocess to run the dataset generation script.
9392
9493
Examples:
95-
For dataset_name='guru' and taskset_config={"path": None},
94+
For dataset_name='guru_math' and taskset_config={"path": None},
9695
this function will run the following command and
97-
generate the guru dataset to default location (DEFAULT_DATA_PATH in scripts/gen_guru_data.py):
96+
generate the guru_math dataset to the default location (DEFAULT_DATA_PATH in scripts/gen_guru_math_data.py):
9897
9998
```bash
100-
python scripts/gen_guru_data.py --local_dir DEFAULT_DATA_PATH
99+
python scripts/gen_guru_math_data.py --local_dir DEFAULT_DATA_PATH
101100
```
102101
"""
103102
if taskset_path:
@@ -108,7 +107,7 @@ def check_taskset_path(dataset_name: str, taskset_path: str) -> str:
108107

109108
dataset_script_map = {
110109
"countdown": "gen_countdown_data.py",
111-
"guru": "gen_guru_data.py",
110+
"guru_math": "gen_guru_math_data.py",
112111
}
113112
if dataset_name not in dataset_script_map:
114113
raise ValueError(
@@ -223,16 +222,21 @@ def main(args):
223222
dist.barrier()
224223
dist.destroy_process_group()
225224
cmd_list.append("--dlc")
226-
if args.dataset == "guru":
227-
base_path = os.path.dirname(os.path.abspath(__file__))
225+
226+
# load plugins
227+
base_path = os.path.dirname(os.path.abspath(__file__))
228+
plugin_dir = os.path.join(base_path, "plugins", args.dataset)
229+
if os.path.exists(plugin_dir):
228230
cmd_list.append("--plugin-dir")
229-
cmd_list.append(os.path.join(base_path, "plugins"))
231+
cmd_list.append(plugin_dir)
232+
233+
# run command
230234
subprocess.run(cmd_list, check=True)
231235

232236

233237
if __name__ == "__main__":
234238
parser = argparse.ArgumentParser()
235-
parser.add_argument("dataset", type=str.lower, choices=["gsm8k", "countdown", "guru"])
239+
parser.add_argument("dataset", type=str.lower, choices=["gsm8k", "countdown", "guru_math"])
236240
parser.add_argument(
237241
"--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC."
238242
)
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ def compute_score(solution_str: str, ground_truth: str, extra_info: dict) -> dic
485485
extra_info: dict with additional info for the score computation
486486
487487
Returns:
488-
Reward score (1.0 for correct, -1.0 for incorrect)
488+
Reward score (1.0 for correct, 0.0 for incorrect)
489489
"""
490490
# First assert intended generation and gt type
491491
model_output = str(solution_str)
@@ -513,7 +513,6 @@ def compute_score(solution_str: str, ground_truth: str, extra_info: dict) -> dic
513513
except Exception:
514514
correct = False
515515

516-
# reward = 1.0 if correct else -1.0
517516
reward = 1.0 if correct else 0.0
518517
acc = correct
519518

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from datasets import load_dataset
55
from huggingface_hub import hf_hub_download
66

7-
DEFAULT_DATA_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data", "guru")
7+
DEFAULT_DATA_PATH = os.path.join(
8+
os.path.dirname(os.path.abspath(__file__)), "..", "data", "guru_math"
9+
)
810

911

1012
def process_fn(example, idx):

0 commit comments

Comments
 (0)