diff --git a/benchmark/README.md b/benchmark/README.md
index 28aff1a1bb..4fb67d06e5 100644
--- a/benchmark/README.md
+++ b/benchmark/README.md
@@ -69,17 +69,24 @@ The chart below shows performance based on this [commit](https://github.com/mode

### 2. Countdown
-First generate data, then run the benchmark:
+To reproduce this experiment:
```bash
-# Step 1: Generate data
-python benchmark/scripts/gen-countdown-data.py --local_dir /your/data/path
-# Step 2: Run benchmark
-python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct --taskset_path /your/data/path
+python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct
```
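+
+The countdown data is generated automatically on the first run. If you already have a generated dataset, you can point the benchmark at it instead (the path is resolved from `--taskset_path`, then the `TASKSET_PATH` environment variable, then the template default), for example:
+```bash
+# Optional: reuse an existing countdown dataset instead of auto-generating it
+python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct --taskset_path /your/data/path
+```
+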
#### Countdown Results
The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/068da409d215bb2450d93b6b7a56740d4751669d).

+### 3. Guru-Math
+To reproduce this experiment:
+```bash
+python bench.py guru_math --model_path /path/to/Qwen/Qwen2.5-7B
+```
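+
+The template's synchronizer settings and run name can also be overridden without editing the YAML: `bench.py` accepts `--sync_interval`, `--sync_offset`, and `--sync_style`, and the templates read the run group/name from the `TRINITY_GROUP`/`TRINITY_NAME` environment variables. A sketch (the values shown are just the template defaults):
+```bash
+# Optional: tag the run and override synchronizer settings
+TRINITY_NAME=guru_math-qwen2.5-7B python bench.py guru_math \
+    --model_path /path/to/Qwen/Qwen2.5-7B \
+    --sync_interval 8 --sync_style fixed
+```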
+
+#### Guru-Math Results
+The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/fbf6c967bcd637bfd9f81fb4d7dd4961d7d5a407).
+
+
*More benchmarks will be added soon!*
---
diff --git a/benchmark/bench.py b/benchmark/bench.py
index d4fbdf761f..6adf062d9d 100644
--- a/benchmark/bench.py
+++ b/benchmark/bench.py
@@ -1,6 +1,8 @@
import argparse
+import importlib.util
import os
import subprocess
+import sys
import time
import torch
@@ -8,14 +10,14 @@
import yaml
from trinity.algorithm.algorithm import ALGORITHM_TYPE
-from trinity.common.constants import MODEL_PATH_ENV_VAR
+from trinity.common.constants import MODEL_PATH_ENV_VAR, SyncStyle
from trinity.utils.dlc_utils import get_dlc_env_vars
def set_engine_num(config, args):
config["cluster"]["node_num"] = args.node_num
config["cluster"]["gpu_per_node"] = args.gpu_per_node
- batch_size = config["buffer"]["batch_size"]
+ batch_size = config["buffer"]["batch_size"] * config["algorithm"]["repeat_times"]
if config["mode"] == "train":
return
@@ -61,6 +63,83 @@ def update_opt_explorer_num(trainer_gpu_num, opt_explorer_num, opt_ratio_diff):
config["explorer"]["rollout_model"]["engine_num"] = opt_explorer_num
+def check_taskset_path(dataset_name: str, taskset_path: str) -> str:
+ """Ensures the taskset path exists for the given dataset; generates it if necessary.
+
+ This function checks whether `taskset_path` exists. If not,
+ it uses a corresponding data generation script (e.g., gen_countdown_data.py) to create
+ the dataset at the default or provided location. The generator scripts are expected
+ to be located in the 'scripts/' subdirectory relative to this file.
+
+ Args:
+        dataset_name: Name of the dataset (e.g., "countdown", "guru_math").
+ Must be one of the supported datasets defined in `dataset_script_map`.
+ taskset_path: Path to the dataset.
+
+ Returns:
+ str: The resolved path to the dataset.
+
+ Raises:
+ ValueError: If the `dataset_name` is not supported.
+ FileNotFoundError: If the corresponding generator script does not exist.
+ ImportError: If the generator module fails to load.
+ AttributeError: If the loaded module does not define 'DEFAULT_DATA_PATH'.
+ subprocess.CalledProcessError: If the generation script fails (due to check=True).
+
+ Side Effects:
+ - May create directories and files on disk via the external generation script.
+ - Executes a subprocess to run the dataset generation script.
+
+ Examples:
+        For dataset_name='guru_math' and taskset_path=None, this function will run the
+        following command and generate the guru_math dataset at the default location
+ (DEFAULT_DATA_PATH in scripts/gen_guru_math_data.py):
+
+ ```bash
+ python scripts/gen_guru_math_data.py --local_dir DEFAULT_DATA_PATH
+ ```
+ """
+ if taskset_path:
+ if os.path.exists(taskset_path):
+ return taskset_path
+ if dataset_name == "gsm8k" and taskset_path == "openai/gsm8k":
+ return taskset_path
+
+ dataset_script_map = {
+ "countdown": "gen_countdown_data.py",
+ "guru_math": "gen_guru_math_data.py",
+ }
+ if dataset_name not in dataset_script_map:
+ raise ValueError(
+ f"Unsupported dataset: {dataset_name}. Please specify a valid taskset path."
+ )
+
+ base_dir = os.path.dirname(__file__)
+ script_filename = dataset_script_map[dataset_name]
+ script_module_name = script_filename[:-3] # remove .py
+
+ script_file_path = os.path.join(base_dir, "scripts", script_filename)
+ if not os.path.exists(script_file_path):
+ raise FileNotFoundError(f"Generator script not found: {script_file_path}")
+
+ spec = importlib.util.spec_from_file_location(script_module_name, script_file_path)
+ if spec is None or spec.loader is None:
+ raise ImportError(f"Could not load spec for module: {script_module_name}")
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+
+ if taskset_path is None:
+ if not hasattr(module, "DEFAULT_DATA_PATH"):
+ raise AttributeError(f"{script_filename} is missing 'DEFAULT_DATA_PATH'")
+ taskset_path = module.DEFAULT_DATA_PATH
+ taskset_path = os.path.realpath(taskset_path)
+
+    subprocess.run([sys.executable, script_file_path, "--local_dir", taskset_path], check=True)
+
+ return taskset_path
+
+
def prepare_configs(args, rank, current_time):
base_path = os.path.dirname(os.path.abspath(__file__))
@@ -89,18 +168,19 @@ def prepare_configs(args, rank, current_time):
)
if args.critic_lr:
config["trainer"]["trainer_config"]["critic"]["optim"]["lr"] = args.critic_lr
- config["buffer"]["explorer_input"]["taskset"]["path"] = (
- args.taskset_path
- or os.environ.get("TASKSET_PATH")
- or config["buffer"]["explorer_input"]["taskset"]["path"]
+ taskset_config = config["buffer"]["explorer_input"]["taskset"]
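+    # Path resolution: --taskset_path > TASKSET_PATH env > template default; generate if missing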
+ taskset_config["path"] = check_taskset_path(
+ args.dataset,
+ args.taskset_path or os.environ.get("TASKSET_PATH") or taskset_config["path"],
)
- assert (
- config["buffer"]["explorer_input"]["taskset"]["path"] is not None
- ), "Please specify taskset path."
if args.lr:
config["algorithm"]["optimizer"]["lr"] = args.lr
if args.sync_interval:
config["synchronizer"]["sync_interval"] = args.sync_interval
+ if args.sync_offset:
+ config["synchronizer"]["sync_offset"] = args.sync_offset
+ if args.sync_style:
+ config["synchronizer"]["sync_style"] = args.sync_style
with open(config_path, "w") as f:
yaml.dump(config, f, allow_unicode=True, sort_keys=False)
@@ -131,7 +211,7 @@ def main(args):
rank, current_time = 0, time.time()
config_path = prepare_configs(args, rank, current_time)
cmd_list = [
- "python",
+ sys.executable,
"-m",
"trinity.cli.launcher",
"run",
@@ -142,12 +222,21 @@ def main(args):
dist.barrier()
dist.destroy_process_group()
cmd_list.append("--dlc")
+
+ # load plugins
+ base_path = os.path.dirname(os.path.abspath(__file__))
+ plugin_dir = os.path.join(base_path, "plugins", args.dataset)
+ if os.path.exists(plugin_dir):
+ cmd_list.append("--plugin-dir")
+ cmd_list.append(plugin_dir)
+
+ # run command
subprocess.run(cmd_list, check=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument("dataset", type=str, choices=["gsm8k", "countdown", "openr1"])
+ parser.add_argument("dataset", type=str.lower, choices=["gsm8k", "countdown", "guru_math"])
parser.add_argument(
"--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC."
)
@@ -191,5 +280,12 @@ def main(args):
parser.add_argument(
"--sync_interval", type=int, default=None, help="Specify the sync interval."
)
+ parser.add_argument("--sync_offset", type=int, default=None, help="Specify the sync offset.")
+ parser.add_argument(
+ "--sync_style",
+ type=str,
+ default=None,
+ choices=[sync_style.value for sync_style in SyncStyle],
+ )
args = parser.parse_args()
main(args)
diff --git a/benchmark/config/countdown-template.yaml b/benchmark/config/countdown-template.yaml
index 96726c51e6..d81cfb6759 100644
--- a/benchmark/config/countdown-template.yaml
+++ b/benchmark/config/countdown-template.yaml
@@ -1,7 +1,7 @@
mode: both
project: Trinity-RFT
-group: countdown-bench
-name: countdown-qwen2.5-1.5B
+group: ${oc.env:TRINITY_GROUP,countdown-bench}
+name: ${oc.env:TRINITY_NAME,countdown}
checkpoint_root_dir: placeholder
algorithm:
algorithm_type: ppo
@@ -72,102 +72,16 @@ trainer:
total_steps: 1000
enable_preview: true
grad_clip: 1.0
+ max_token_len_per_gpu: 6400
trainer_config:
- actor_rollout_ref:
- hybrid_engine: true
- model:
- external_lib: null
- override_config: {}
- enable_gradient_checkpointing: true
- use_remove_padding: true
- actor:
- strategy: fsdp
- ppo_micro_batch_size_per_gpu: 4
- use_dynamic_bsz: true
- ppo_max_token_len_per_gpu: 6400
- ppo_epochs: 1
- shuffle: false
- ulysses_sequence_parallel_size: 1
- checkpoint:
- load_contents:
- - model
- - optimizer
- - extra
- save_contents:
- - model
- - optimizer
- - extra
- fsdp_config:
- wrap_policy:
- min_num_params: 0
- param_offload: false
- optimizer_offload: false
- fsdp_size: -1
- ref:
- fsdp_config:
- wrap_policy:
- min_num_params: 0
- param_offload: false
- optimizer_offload: false
- fsdp_size: -1
- log_prob_micro_batch_size_per_gpu: 8
- log_prob_use_dynamic_bsz: true
- log_prob_max_token_len_per_gpu: 6400
- ulysses_sequence_parallel_size: 1
- custom_reward_function:
- path: null
- name: compute_score
- algorithm:
- kl_penalty: low_var_kl
- kl_ctrl:
- type: fixed
- kl_coef: 0.001
- trainer:
- balance_batch: true
- resume_mode: auto
- resume_from_path: ''
- critic_warmup: 0
- default_hdfs_dir: null
- remove_previous_ckpt_in_save: false
- del_local_ckpt_after_load: false
- max_actor_ckpt_to_keep: null
- max_critic_ckpt_to_keep: null
critic:
- strategy: fsdp
optim:
lr: 1e-5
lr_warmup_steps_ratio: 0.0
warmup_style: constant
- model:
- override_config: {}
- external_lib: null
- enable_gradient_checkpointing: true
- use_remove_padding: true
- fsdp_config:
- wrap_policy:
- min_num_params: 0
- param_offload: false
- optimizer_offload: false
- fsdp_size: -1
- ppo_micro_batch_size_per_gpu: 8
- forward_micro_batch_size_per_gpu: 8
- use_dynamic_bsz: true
ppo_max_token_len_per_gpu: 12800
forward_max_token_len_per_gpu: 12800
- ulysses_sequence_parallel_size: 1
- ppo_epochs: 1
- shuffle: false
- grad_clip: 1.0
cliprange_value: 0.5
- checkpoint:
- load_contents:
- - model
- - optimizer
- - extra
- save_contents:
- - model
- - optimizer
- - extra
monitor:
monitor_type: wandb
synchronizer:
diff --git a/benchmark/config/gsm8k-template.yaml b/benchmark/config/gsm8k-template.yaml
index 9e602bfe52..a967589fe9 100644
--- a/benchmark/config/gsm8k-template.yaml
+++ b/benchmark/config/gsm8k-template.yaml
@@ -1,7 +1,7 @@
mode: both
project: Trinity-RFT
-group: gsm8k-bench
-name: gsm8k-qwen2.5-1.5B
+group: ${oc.env:TRINITY_GROUP,gsm8k-bench}
+name: ${oc.env:TRINITY_NAME,gsm8k}
checkpoint_root_dir: placeholder
algorithm:
algorithm_type: grpo
diff --git a/benchmark/config/guru_math-template.yaml b/benchmark/config/guru_math-template.yaml
new file mode 100644
index 0000000000..7538171826
--- /dev/null
+++ b/benchmark/config/guru_math-template.yaml
@@ -0,0 +1,74 @@
+mode: both
+project: Trinity-RFT
+group: ${oc.env:TRINITY_GROUP,guru_math-bench}
+name: ${oc.env:TRINITY_NAME,guru_math}
+checkpoint_root_dir: placeholder
+model:
+ model_path: Qwen/Qwen2.5-7B
+ max_prompt_tokens: 4096
+ max_response_tokens: 8192
+algorithm:
+ algorithm_type: grpo
+ repeat_times: 16
+ kl_loss_fn_args:
+ kl_coef: 0.0
+ optimizer:
+ lr: 1e-6
+ weight_decay: 0.1
+ lr_warmup_steps: 80
+ warmup_style: constant
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+data_processor:
+ experience_pipeline:
+ save_input: false
+buffer:
+ total_epochs: 1
+ batch_size: 60
+ explorer_input:
+ default_workflow_type: math_boxed_workflow
+ default_reward_fn_type: math_boxed_reward_naive_dapo
+ taskset:
+ name: math
+ storage_type: file
+ path: null
+ format:
+ prompt_key: question
+ response_key: ground_truth
+      system_prompt: "You are a helpful assistant. To answer a query from the user, please first think through the question step-by-step inside <think> ... </think>, then provide the final response to the user."
+ reply_prefix: ""
+ rollout_args:
+ temperature: 1.0
+ logprobs: 0
+ eval_tasksets: []
+ trainer_input:
+ experience_buffer:
+ name: math_buffer
+ storage_type: queue
+ replay_buffer:
+ enable: false
+explorer:
+ eval_interval: 10
+ runner_per_model: 8
+ rollout_model:
+ engine_type: vllm_async
+ engine_num: 3
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ seed: 42
+synchronizer:
+ sync_style: fixed
+ sync_method: nccl
+ sync_interval: 8
+ sync_timeout: 2400
+trainer:
+ trainer_type: verl
+ save_interval: 80
+ enable_preview: true
+ grad_clip: 1.0
+ max_token_len_per_gpu: 24576
+monitor:
+ monitor_type: wandb
diff --git a/benchmark/plugins/guru_math/naive_dapo.py b/benchmark/plugins/guru_math/naive_dapo.py
new file mode 100644
index 0000000000..b1dce7c597
--- /dev/null
+++ b/benchmark/plugins/guru_math/naive_dapo.py
@@ -0,0 +1,522 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
+
+import concurrent.futures
+import math
+import os
+import re
+import resource
+
+import sympy
+from pylatexenc import latex2text
+from sympy.parsing import sympy_parser
+from verl.utils.reward_score.prime_math import math_normalize
+from verl.utils.reward_score.prime_math.grader import math_equal
+
+# Constants for normalization
+SUBSTITUTIONS = [
+ ("an ", ""),
+ ("a ", ""),
+ (".$", "$"),
+ ("\\$", ""),
+ (r"\ ", ""),
+ (" ", ""),
+ ("mbox", "text"),
+ (",\\text{and}", ","),
+ ("\\text{and}", ","),
+ ("\\text{m}", "\\text{}"),
+]
+
+REMOVED_EXPRESSIONS = [
+ "square",
+ "ways",
+ "integers",
+ "dollars",
+ "mph",
+ "inches",
+ "hours",
+ "km",
+ "units",
+ "\\ldots",
+ "sue",
+ "points",
+ "feet",
+ "minutes",
+ "digits",
+ "cents",
+ "degrees",
+ "cm",
+ "gm",
+ "pounds",
+ "meters",
+ "meals",
+ "edges",
+ "students",
+ "childrentickets",
+ "multiples",
+ "\\text{s}",
+ "\\text{.}",
+ "\\text{\ns}",
+ "\\text{}^2",
+ "\\text{}^3",
+ "\\text{\n}",
+ "\\text{}",
+ r"\mathrm{th}",
+ r"^\circ",
+ r"^{\circ}",
+ r"\;",
+ r",\!",
+ "{,}",
+ '"',
+ "\\dots",
+]
+
+
+def normalize_final_answer(final_answer: str) -> str:
+ """Normalize a final answer to a quantitative reasoning question.
+
+ Args:
+ final_answer: The answer string to normalize
+
+ Returns:
+ Normalized answer string
+ """
+ final_answer = final_answer.split("=")[-1]
+
+ # Apply substitutions and removals
+ for before, after in SUBSTITUTIONS:
+ final_answer = final_answer.replace(before, after)
+ for expr in REMOVED_EXPRESSIONS:
+ final_answer = final_answer.replace(expr, "")
+
+ # Extract and normalize LaTeX math
+ final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
+ final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
+ final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
+ final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
+ final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
+
+ # Normalize shorthand TeX:
+ # \fracab -> \frac{a}{b}
+ # \frac{abc}{bef} -> \frac{abc}{bef}
+ # \fracabc -> \frac{a}{b}c
+ # \sqrta -> \sqrt{a}
+ # \sqrtab -> sqrt{a}b
+ final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
+ final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
+ final_answer = final_answer.replace("$", "")
+
+ # Normalize numbers
+ if final_answer.replace(",", "").isdigit():
+ final_answer = final_answer.replace(",", "")
+
+ return final_answer.strip()
+
+
+# sympy might hang -- we don't care about trying to be lenient in these cases
+BAD_SUBSTRINGS = ["^{", "^("]
+BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"]
+TUPLE_CHARS = "()[]"
+
+
+def timeout(timeout_seconds: int = 8):
+ if os.name == "posix":
+ import signal
+
+ def decorator(func):
+ def handler(signum, frame):
+ raise TimeoutError("Operation timed out!")
+
+ def wrapper(*args, **kwargs):
+ old_handler = signal.getsignal(signal.SIGALRM)
+ signal.signal(signal.SIGALRM, handler)
+ signal.alarm(timeout_seconds)
+
+ try:
+ return func(*args, **kwargs)
+ finally:
+ signal.alarm(0)
+ signal.signal(signal.SIGALRM, old_handler)
+
+ return wrapper
+
+ return decorator
+ else:
+ raise NotImplementedError(f"Unsupported OS: {os.name}")
+
+
+def _sympy_parse(expr: str):
+ """Parses an expression with sympy."""
+ py_expr = expr.replace("^", "**")
+ return sympy_parser.parse_expr(
+ py_expr,
+ transformations=(
+ sympy_parser.standard_transformations
+ + (sympy_parser.implicit_multiplication_application,)
+ ),
+ )
+
+
+def _parse_latex(expr: str) -> str:
+ """Attempts to parse latex to an expression sympy can read."""
+ expr = expr.replace("\\tfrac", "\\frac")
+ expr = expr.replace("\\dfrac", "\\frac")
+ expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers.
+ expr = latex2text.LatexNodes2Text().latex_to_text(expr)
+
+ # Replace the specific characters that this parser uses.
+ expr = expr.replace("√", "sqrt")
+ expr = expr.replace("π", "pi")
+ expr = expr.replace("∞", "inf")
+ expr = expr.replace("∪", "U")
+ expr = expr.replace("·", "*")
+ expr = expr.replace("×", "*")
+
+ return expr.strip()
+
+
+def _is_float(num: str) -> bool:
+ try:
+ float(num)
+ return True
+ except ValueError:
+ return False
+
+
+def _is_int(x: float) -> bool:
+ try:
+ return abs(x - int(round(x))) <= 1e-7
+ except Exception:
+ return False
+
+
+def _is_frac(expr: str) -> bool:
+ return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr))
+
+
+def _str_is_int(x: str) -> bool:
+ try:
+ x = _strip_properly_formatted_commas(x)
+ x = float(x)
+ return abs(x - int(round(x))) <= 1e-7 # type: ignore
+ except Exception:
+ return False
+
+
+def _str_to_int(x: str) -> int:
+ x = x.replace(",", "")
+ x = float(x)
+ return int(x)
+
+
+def _inject_implicit_mixed_number(step: str):
+ """
+ Automatically make a mixed number evalable
+ e.g. 7 3/4 => 7+3/4
+ """
+ p1 = re.compile("([0-9]) +([0-9])")
+ step = p1.sub("\\1+\\2", step) # implicit mults
+ return step
+
+
+def _strip_properly_formatted_commas(expr: str):
+ # We want to be careful because we don't want to strip tuple commas
+ p1 = re.compile(r"(\d)(,)(\d\d\d)($|\D)")
+ while True:
+ next_expr = p1.sub("\\1\\3\\4", expr)
+ if next_expr == expr:
+ break
+ expr = next_expr
+ return next_expr
+
+
+def _normalize(expr: str) -> str:
+ """Normalize answer expressions."""
+ if expr is None:
+ return None
+
+ # Remove enclosing `\text{}`.
+    m = re.search(r"^\\text\{(?P<text>.+?)\}$", expr)
+ if m is not None:
+ expr = m.group("text")
+
+ expr = expr.replace("\\%", "%")
+ expr = expr.replace("\\$", "$")
+ expr = expr.replace("$", "")
+ expr = expr.replace("%", "")
+ expr = expr.replace(" or ", " , ")
+ expr = expr.replace(" and ", " , ")
+
+ expr = expr.replace("million", "*10^6")
+ expr = expr.replace("billion", "*10^9")
+ expr = expr.replace("trillion", "*10^12")
+
+ for unit in [
+ "degree",
+ "cm",
+ "centimeter",
+ "meter",
+ "mile",
+ "second",
+ "minute",
+ "hour",
+ "day",
+ "week",
+ "month",
+ "year",
+ "foot",
+ "feet",
+ "inch",
+ "yard",
+ "liter",
+ ]:
+ expr = re.sub(rf"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
+ expr = re.sub("\\^ *\\\\circ", "", expr)
+
+ if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
+ expr = expr[1:-1]
+
+ expr = re.sub(",\\\\! *", "", expr)
+ if _is_float(expr) and _is_int(float(expr)):
+ expr = str(int(round(float(expr))))
+ if "\\" in expr:
+ try:
+ expr = _parse_latex(expr)
+ except Exception:
+ pass
+
+ # edge case with mixed numbers and negative signs
+ expr = re.sub("- *", "-", expr)
+
+ expr = _inject_implicit_mixed_number(expr)
+
+ # don't be case sensitive for text answers
+ expr = expr.lower()
+
+ if _str_is_int(expr):
+ expr = str(_str_to_int(expr))
+
+ return expr
+
+
+def count_unknown_letters_in_expr(expr: str):
+ expr = expr.replace("sqrt", "")
+ expr = expr.replace("frac", "")
+ letters_in_expr = set([x for x in expr if x.isalpha()])
+ return len(letters_in_expr)
+
+
+def should_allow_eval(expr: str):
+ # we don't want to try parsing unknown text or functions of more than two variables
+ if count_unknown_letters_in_expr(expr) > 2:
+ return False
+
+ for bad_string in BAD_SUBSTRINGS:
+ if bad_string in expr:
+ return False
+
+ for bad_regex in BAD_REGEXES:
+ if re.search(bad_regex, expr) is not None:
+ return False
+
+ return True
+
+
+# @timeout(timeout_seconds=10)
+def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
+ @timeout(timeout_seconds=10)
+ def check_equal():
+ memory_size = 1024**3
+ resource.setrlimit(resource.RLIMIT_AS, (memory_size, memory_size))
+
+ expr = f"({ground_truth_normalized})-({given_normalized})"
+ if should_allow_eval(expr):
+ sympy_diff = _sympy_parse(expr)
+ simplified = sympy.simplify(sympy_diff)
+ if simplified == 0:
+ return True
+ return False
+
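+    # Run in a separate process so a hung or memory-hungry sympy call can be abandoned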
+ with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(check_equal)
+ try:
+ return future.result(timeout=10)
+ except (concurrent.futures.TimeoutError, Exception):
+ future.cancel()
+ return False
+
+
+def split_tuple(expr: str):
+ """
+ Split the elements in a tuple/interval, while handling well-formatted commas in large numbers
+ """
+ expr = _strip_properly_formatted_commas(expr)
+ if len(expr) == 0:
+ return []
+ if (
+ len(expr) > 2
+ and expr[0] in TUPLE_CHARS
+ and expr[-1] in TUPLE_CHARS
+ and all([ch not in expr[1:-1] for ch in TUPLE_CHARS])
+ ):
+ elems = [elem.strip() for elem in expr[1:-1].split(",")]
+ else:
+ elems = [expr]
+ return elems
+
+
+def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]:
+ """
+ The answer will be considered correct if:
+ (a) it normalizes to the same string as the ground truth answer
+ OR
+ (b) sympy can simplify the difference between the expressions to 0
+ """
+ if given_answer is None:
+        return False, ""
+
+ ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth)
+ given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer)
+
+ # be at least as lenient as mathd
+ if ground_truth_normalized_mathd == given_answer_normalized_mathd:
+ return True, given_answer_normalized_mathd
+
+ ground_truth_normalized = _normalize(ground_truth)
+ given_normalized = _normalize(given_answer)
+
+ if ground_truth_normalized is None:
+ return False, given_normalized
+
+ if ground_truth_normalized == given_normalized:
+ return True, given_normalized
+
+ if len(given_normalized) == 0:
+ return False, given_normalized
+
+ ground_truth_elems = split_tuple(ground_truth_normalized)
+ given_elems = split_tuple(given_normalized)
+
+ if len(ground_truth_elems) > 1 and (
+ ground_truth_normalized[0] != given_normalized[0]
+ or ground_truth_normalized[-1] != given_normalized[-1]
+ ):
+ is_correct = False
+ elif len(ground_truth_elems) != len(given_elems):
+ is_correct = False
+ else:
+ for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems):
+ if _is_frac(ground_truth_elem) and _is_frac(given_elem):
+ # if fractions aren't reduced, then shouldn't be marked as correct
+ # so, we don't want to allow sympy.simplify in this case
+ is_correct = ground_truth_elem == given_elem
+ elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):
+ # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify)
+ is_correct = False
+ else:
+ is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
+ if not is_correct:
+ break
+
+ return is_correct, given_normalized
+
+
+def _last_boxed_only_string(string):
+ idx = string.rfind("\\boxed")
+ if idx < 0:
+ idx = string.rfind("\\fbox")
+ if idx < 0:
+ return None
+
+ i = idx
+ left_brace_idx = None
+ right_brace_idx = None
+ num_left_braces_open = 0
+ while i < len(string):
+ if string[i] == "{":
+ num_left_braces_open += 1
+ if left_brace_idx is None:
+ left_brace_idx = i
+ elif string[i] == "}":
+ num_left_braces_open -= 1
+ if num_left_braces_open == 0:
+ right_brace_idx = i
+ break
+
+ i += 1
+
+ if left_brace_idx is None or right_brace_idx is None:
+ return None
+
+ return string[left_brace_idx + 1 : right_brace_idx].strip()
+
+
+def match_answer(response):
+ is_matched = False
+    response = response.split("</think>")[-1]
+
+ # Find boxed
+ ans_boxed = _last_boxed_only_string(response)
+ if ans_boxed:
+ is_matched = True
+ response = ans_boxed
+
+ return is_matched, response
+
+
+def compute_score(solution_str: str, ground_truth: str, extra_info: dict) -> dict:
+ """Compute the reward score for a solution. This draws heavily from the LLM-as-judge and PRIME reward functions
+
+ Args:
+ solution_str: The solution string
+ ground_truth: The ground truth answer
+ extra_info: dict with additional info for the score computation
+
+ Returns:
+        dict with "score" (1.0 for correct, 0.0 for incorrect) and "acc" (whether the answer is correct)
+ """
+ # First assert intended generation and gt type
+ model_output = str(solution_str)
+ ground_truth = str(ground_truth)
+
+ # Extract answer from generated output
+ is_matched, extracted_model_output = match_answer(model_output)
+
+ # TWK NOTE: WE REMOVED THE RESPONSE TRUNCATION FROM math_dapo.compute_score
+
+ # Verify the solution, first check simple comparisons.
+ correct, pred = grade_answer(extracted_model_output, ground_truth)
+
+ if not correct:
+ try:
+ if "\\pi" in extracted_model_output or "\\pi" in ground_truth:
+ equivs = []
+ for pi in [math.pi, 3.14]:
+ equivs.append(
+                        math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi)
+ )
+ correct = any(equivs)
+ else:
+ correct = math_equal(extracted_model_output, ground_truth, timeout=True)
+ except Exception:
+ correct = False
+
+ reward = 1.0 if correct else 0.0
+ acc = correct
+
+ return {
+ "score": reward,
+ "acc": acc,
+ }
diff --git a/benchmark/plugins/guru_math/reward.py b/benchmark/plugins/guru_math/reward.py
new file mode 100644
index 0000000000..9cc22d9c73
--- /dev/null
+++ b/benchmark/plugins/guru_math/reward.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+from trinity.common.rewards.math_reward import MathBoxedRewardFn
+from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS
+
+
+@REWARD_FUNCTIONS.register_module("math_boxed_reward_naive_dapo")
+class NaiveDapoRewardFn(MathBoxedRewardFn):
+ def __call__( # type: ignore
+ self,
+ response: str,
+ truth: Optional[str] = None,
+ with_think: Optional[bool] = False,
+ format_score_coef: Optional[float] = 0.1,
+ **kwargs,
+ ) -> dict[str, float]:
+ from .naive_dapo import compute_score
+
+ ret = compute_score(response, truth, None) # type: ignore
+ return {"accuracy": ret["score"], "format_score": 0}
+
+
+@REWARD_FUNCTIONS.register_module("math_boxed_reward_prime_math")
+class PrimeMathRewardFn(MathBoxedRewardFn):
+ def __call__( # type: ignore
+ self,
+ response: str,
+ truth: Optional[str] = None,
+ with_think: Optional[bool] = False,
+ format_score_coef: Optional[float] = 0.1,
+ **kwargs,
+ ) -> dict[str, float]:
+ from verl.utils.reward_score.prime_math import compute_score
+
+ ret = compute_score(response, truth)
+ return {"accuracy": ret["score"], "format_score": 0}
diff --git a/benchmark/scripts/gen-countdown-data.py b/benchmark/scripts/gen-countdown-data.py
deleted file mode 100644
index ffaf41f00e..0000000000
--- a/benchmark/scripts/gen-countdown-data.py
+++ /dev/null
@@ -1,115 +0,0 @@
-"""
-Modified from https://github.com/Jiayi-Pan/TinyZero/blob/main/examples/data_preprocess/countdown.py
-Preprocess dataset for countdown task - given a target number and N numbers, generate equations to reach target
-"""
-
-import argparse
-import json
-import os
-from random import randint, seed
-from typing import List, Tuple
-
-from datasets import load_dataset
-from tqdm import tqdm
-from verl.utils.hdfs_io import copy, makedirs
-
-
-def gen_dataset(
- num_samples: int,
- num_operands: int = 6,
- max_target: int = 1000,
- min_number: int = 1,
- max_number: int = 100,
- operations: List[str] = ["+", "-", "*", "/"],
- seed_value: int = 42,
-) -> List[Tuple]:
- """Generate dataset for countdown task.
-
- Args:
- num_samples: Number of samples to generate
- num_operands: Number of numbers provided in each sample
- max_target: Maximum value for target number
- min_number: Minimum value for provided numbers
- max_number: Maximum value for provided numbers
- operations: List of allowed operations
- seed_value: Random seed for reproducibility
-
- Returns:
- List of tuples containing (target, numbers, solution)
- """
- seed(seed_value)
- samples = []
-
- for _ in tqdm(range(num_samples)):
- # Generate random target
- target = randint(1, max_target)
-
- # Generate random numbers
- numbers = [randint(min_number, max_number) for _ in range(num_operands)]
-
- samples.append((target, numbers))
-
- return samples
-
-
-def make_prefix(dp):
- target = dp["target"]
- numbers = dp["nums"]
- system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer."""
-    task_desc = f"""User: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n"""
- final_prompt = f"{system_prompt}\n{task_desc}"
- return final_prompt
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--local_dir", default="~/data/countdown")
- parser.add_argument("--hdfs_dir", default=None)
- parser.add_argument("--num_samples", type=int, default=100000)
- parser.add_argument("--num_operands", type=int, default=6)
- parser.add_argument("--max_target", type=int, default=1000)
- parser.add_argument("--min_number", type=int, default=1)
- parser.add_argument("--max_number", type=int, default=100)
- parser.add_argument("--train_size", type=int, default=320000)
- parser.add_argument("--test_size", type=int, default=7680)
-
- args = parser.parse_args()
-
- data_source = "countdown"
- TRAIN_SIZE = args.train_size
- TEST_SIZE = args.test_size
-
- raw_dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train")
-
- assert len(raw_dataset) > TRAIN_SIZE + TEST_SIZE
- train_dataset = raw_dataset.select(range(TRAIN_SIZE))
- test_dataset = raw_dataset.select(range(TRAIN_SIZE, TRAIN_SIZE + TEST_SIZE))
-
- def make_map_fn(split):
- def process_fn(example, idx):
- question = make_prefix(example)
- data = {
- "question": question,
- "answer": json.dumps(
- {
- "numbers": example["nums"],
- "target": example["target"],
- }
- ),
- }
- return data
-
- return process_fn
-
- train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
- test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
-
- local_dir = args.local_dir
- hdfs_dir = args.hdfs_dir
-
- train_dataset.to_json(os.path.join(local_dir, "train.jsonl"))
- test_dataset.to_json(os.path.join(local_dir, "test.jsonl"))
-
- if hdfs_dir is not None:
- makedirs(hdfs_dir)
- copy(src=local_dir, dst=hdfs_dir)
diff --git a/benchmark/scripts/gen_countdown_data.py b/benchmark/scripts/gen_countdown_data.py
new file mode 100644
index 0000000000..922cea6d1d
--- /dev/null
+++ b/benchmark/scripts/gen_countdown_data.py
@@ -0,0 +1,74 @@
+"""
+Modified from https://github.com/Jiayi-Pan/TinyZero/blob/main/examples/data_preprocess/countdown.py
+Preprocess dataset for countdown task - given a target number and N numbers, generate equations to reach target
+"""
+
+import argparse
+import json
+import os
+
+from datasets import load_dataset
+from verl.utils.hdfs_io import copy, makedirs
+
+DEFAULT_DATA_PATH = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), "..", "data", "countdown"
+)
+
+
+def make_prefix(dp):
+ target = dp["target"]
+ numbers = dp["nums"]
+ system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer."""
+    task_desc = f"""User: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n"""
+ final_prompt = f"{system_prompt}\n{task_desc}"
+ return final_prompt
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--local_dir", default=DEFAULT_DATA_PATH)
+ parser.add_argument("--hdfs_dir", default=None)
+ parser.add_argument("--train_size", type=int, default=320000)
+ parser.add_argument("--test_size", type=int, default=7680)
+
+ args = parser.parse_args()
+
+ data_source = "countdown"
+ TRAIN_SIZE = args.train_size
+ TEST_SIZE = args.test_size
+
+ raw_dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train")
+
+ assert len(raw_dataset) > TRAIN_SIZE + TEST_SIZE
+ train_dataset = raw_dataset.select(range(TRAIN_SIZE))
+ test_dataset = raw_dataset.select(range(TRAIN_SIZE, TRAIN_SIZE + TEST_SIZE))
+
+ def process_fn(example, idx):
+ question = make_prefix(example)
+ data = {
+ "question": question,
+ "answer": json.dumps(
+ {
+ "numbers": example["nums"],
+ "target": example["target"],
+ }
+ ),
+ }
+ return data
+
+ train_dataset = train_dataset.map(
+ function=process_fn, with_indices=True, remove_columns=train_dataset.column_names
+ )
+ test_dataset = test_dataset.map(
+ function=process_fn, with_indices=True, remove_columns=test_dataset.column_names
+ )
+
+ local_dir = args.local_dir
+ hdfs_dir = args.hdfs_dir
+
+ train_dataset.to_json(os.path.join(local_dir, "train.jsonl"))
+ test_dataset.to_json(os.path.join(local_dir, "test.jsonl"))
+
+ if hdfs_dir is not None:
+ makedirs(hdfs_dir)
+ copy(src=local_dir, dst=hdfs_dir)
diff --git a/benchmark/scripts/gen_guru_math_data.py b/benchmark/scripts/gen_guru_math_data.py
new file mode 100644
index 0000000000..4b7c2e5c84
--- /dev/null
+++ b/benchmark/scripts/gen_guru_math_data.py
@@ -0,0 +1,35 @@
+import argparse
+import os
+
+from datasets import load_dataset
+from huggingface_hub import hf_hub_download
+
+DEFAULT_DATA_PATH = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), "..", "data", "guru_math"
+)
+
+
+def process_fn(example, idx):
+ data = {
+ "question": example["prompt"][0]["content"],
+ "ground_truth": example["reward_model"]["ground_truth"],
+ }
+ return data
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--local_dir", default=DEFAULT_DATA_PATH)
+ args = parser.parse_args()
+
+ downloaded_file_path = hf_hub_download(
+ repo_id="LLM360/guru-RL-92k",
+ filename="train/math__combined_54.4k.parquet",
+ repo_type="dataset",
+ )
+ dataset = load_dataset("parquet", data_files=downloaded_file_path, split="train")
+ new_dataset = dataset.map(
+ function=process_fn, with_indices=True, remove_columns=dataset.column_names
+ ).shuffle()
+ os.makedirs(args.local_dir, exist_ok=True)
+ new_dataset.to_json(os.path.join(args.local_dir, "train.jsonl"))
diff --git a/docs/sphinx_doc/assets/guru-bench.png b/docs/sphinx_doc/assets/guru-bench.png
new file mode 100644
index 0000000000..aa22793b11
Binary files /dev/null and b/docs/sphinx_doc/assets/guru-bench.png differ
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index b157c343e7..24a8de0296 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -244,7 +244,7 @@ class Critic:
ppo_max_token_len_per_gpu: Optional[int] = None
forward_max_token_len_per_gpu: Optional[int] = None
ulysses_sequence_parallel_size: Optional[int] = None
- ppo_epochs: int = 0
+ ppo_epochs: int = 1
shuffle: bool = False
grad_clip: Optional[float] = None
cliprange_value: float = 0.0