diff --git a/benchmark/README.md b/benchmark/README.md index 28aff1a1bb..4fb67d06e5 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -69,17 +69,24 @@ The chart below shows performance based on this [commit](https://github.com/mode ![View Results](../docs/sphinx_doc/assets/gsm8k-bench.png) ### 2. Countdown -First generate data, then run the benchmark: +To reproduce this experiment: ```bash -# Step 1: Generate data -python benchmark/scripts/gen-countdown-data.py --local_dir /your/data/path -# Step 2: Run benchmark -python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct --taskset_path /your/data/path +python bench.py countdown --model_path /path/to/Qwen/Qwen2.5-1.5B-Instruct ``` #### Countdown Results The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/068da409d215bb2450d93b6b7a56740d4751669d). ![View Results](../docs/sphinx_doc/assets/countdown-bench.png) +### 3. Guru-Math +To reproduce this experiment: +```bash +python bench.py guru_math --model_path /path/to/Qwen/Qwen2.5-7B +``` + +#### Guru Results +The chart below shows performance based on this [commit](https://github.com/modelscope/Trinity-RFT/tree/fbf6c967bcd637bfd9f81fb4d7dd4961d7d5a407). +![View Results](../docs/sphinx_doc/assets/guru-bench.png) + *More benchmarks will be added soon!* --- diff --git a/benchmark/bench.py b/benchmark/bench.py index d4fbdf761f..6adf062d9d 100644 --- a/benchmark/bench.py +++ b/benchmark/bench.py @@ -1,6 +1,8 @@ import argparse +import importlib import os import subprocess +import sys import time import torch @@ -8,14 +10,14 @@ import yaml from trinity.algorithm.algorithm import ALGORITHM_TYPE -from trinity.common.constants import MODEL_PATH_ENV_VAR +from trinity.common.constants import MODEL_PATH_ENV_VAR, SyncStyle from trinity.utils.dlc_utils import get_dlc_env_vars def set_engine_num(config, args): config["cluster"]["node_num"] = args.node_num config["cluster"]["gpu_per_node"] = args.gpu_per_node - batch_size = config["buffer"]["batch_size"] + batch_size = config["buffer"]["batch_size"] * config["algorithm"]["repeat_times"] if config["mode"] == "train": return @@ -61,6 +63,83 @@ def update_opt_explorer_num(trainer_gpu_num, opt_explorer_num, opt_ratio_diff): config["explorer"]["rollout_model"]["engine_num"] = opt_explorer_num +def check_taskset_path(dataset_name: str, taskset_path: str) -> str: + """Ensures the taskset path exists for the given dataset; generates it if necessary. + + This function checks whether `taskset_path` exists. If not, + it uses a corresponding data generation script (e.g., gen_countdown_data.py) to create + the dataset at the default or provided location. The generator scripts are expected + to be located in the 'scripts/' subdirectory relative to this file. + + Args: + dataset_name: Name of the dataset (e.g., "countdown", "guru"). + Must be one of the supported datasets defined in `dataset_script_map`. + taskset_path: Path to the dataset. + + Returns: + str: The resolved path to the dataset. + + Raises: + ValueError: If the `dataset_name` is not supported. + FileNotFoundError: If the corresponding generator script does not exist. + ImportError: If the generator module fails to load. + AttributeError: If the loaded module does not define 'DEFAULT_DATA_PATH'. + subprocess.CalledProcessError: If the generation script fails (due to check=True). + + Side Effects: + - May create directories and files on disk via the external generation script. 
+        - Executes a subprocess to run the dataset generation script.
+
+    Examples:
+        For dataset_name='guru_math' and taskset_path=None, this function will run the
+        following command and generate the guru_math dataset at the default location
+        (DEFAULT_DATA_PATH in scripts/gen_guru_math_data.py):
+
+        ```bash
+        python scripts/gen_guru_math_data.py --local_dir DEFAULT_DATA_PATH
+        ```
+    """
+    if taskset_path:
+        if os.path.exists(taskset_path):
+            return taskset_path
+        if dataset_name == "gsm8k" and taskset_path == "openai/gsm8k":
+            return taskset_path
+
+    dataset_script_map = {
+        "countdown": "gen_countdown_data.py",
+        "guru_math": "gen_guru_math_data.py",
+    }
+    if dataset_name not in dataset_script_map:
+        raise ValueError(
+            f"Unsupported dataset: {dataset_name}. Please specify a valid taskset path."
+        )
+
+    base_dir = os.path.dirname(__file__)
+    script_filename = dataset_script_map[dataset_name]
+    script_module_name = script_filename[:-3]  # remove .py
+
+    script_file_path = os.path.join(base_dir, "scripts", script_filename)
+    if not os.path.exists(script_file_path):
+        raise FileNotFoundError(f"Generator script not found: {script_file_path}")
+
+    spec = importlib.util.spec_from_file_location(script_module_name, script_file_path)
+    if spec is None or spec.loader is None:
+        raise ImportError(f"Could not load spec for module: {script_module_name}")
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+
+    if taskset_path is None:
+        if not hasattr(module, "DEFAULT_DATA_PATH"):
+            raise AttributeError(f"{script_filename} is missing 'DEFAULT_DATA_PATH'")
+        taskset_path = module.DEFAULT_DATA_PATH
+    taskset_path = os.path.realpath(taskset_path)
+
+    # Run the generator script to materialize the dataset at taskset_path.
+    subprocess.run([sys.executable, script_file_path, "--local_dir", taskset_path], check=True)
+
+    return taskset_path
+
+
 def prepare_configs(args, rank, current_time):
     base_path = os.path.dirname(os.path.abspath(__file__))
@@ -89,18 +168,19 @@ def prepare_configs(args, rank, current_time):
         )
     if args.critic_lr:
         config["trainer"]["trainer_config"]["critic"]["optim"]["lr"] = args.critic_lr
-    config["buffer"]["explorer_input"]["taskset"]["path"] = (
-        args.taskset_path
-        or os.environ.get("TASKSET_PATH")
-        or config["buffer"]["explorer_input"]["taskset"]["path"]
+    taskset_config = config["buffer"]["explorer_input"]["taskset"]
+    taskset_config["path"] = check_taskset_path(
+        args.dataset,
+        args.taskset_path or os.environ.get("TASKSET_PATH") or taskset_config["path"],
     )
-    assert (
-        config["buffer"]["explorer_input"]["taskset"]["path"] is not None
-    ), "Please specify taskset path."
if args.lr: config["algorithm"]["optimizer"]["lr"] = args.lr if args.sync_interval: config["synchronizer"]["sync_interval"] = args.sync_interval + if args.sync_offset: + config["synchronizer"]["sync_offset"] = args.sync_offset + if args.sync_style: + config["synchronizer"]["sync_style"] = args.sync_style with open(config_path, "w") as f: yaml.dump(config, f, allow_unicode=True, sort_keys=False) @@ -131,7 +211,7 @@ def main(args): rank, current_time = 0, time.time() config_path = prepare_configs(args, rank, current_time) cmd_list = [ - "python", + sys.executable, "-m", "trinity.cli.launcher", "run", @@ -142,12 +222,21 @@ def main(args): dist.barrier() dist.destroy_process_group() cmd_list.append("--dlc") + + # load plugins + base_path = os.path.dirname(os.path.abspath(__file__)) + plugin_dir = os.path.join(base_path, "plugins", args.dataset) + if os.path.exists(plugin_dir): + cmd_list.append("--plugin-dir") + cmd_list.append(plugin_dir) + + # run command subprocess.run(cmd_list, check=True) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("dataset", type=str, choices=["gsm8k", "countdown", "openr1"]) + parser.add_argument("dataset", type=str.lower, choices=["gsm8k", "countdown", "guru_math"]) parser.add_argument( "--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC." ) @@ -191,5 +280,12 @@ def main(args): parser.add_argument( "--sync_interval", type=int, default=None, help="Specify the sync interval." ) + parser.add_argument("--sync_offset", type=int, default=None, help="Specify the sync offset.") + parser.add_argument( + "--sync_style", + type=str, + default=None, + choices=[sync_style.value for sync_style in SyncStyle], + ) args = parser.parse_args() main(args) diff --git a/benchmark/config/countdown-template.yaml b/benchmark/config/countdown-template.yaml index 96726c51e6..d81cfb6759 100644 --- a/benchmark/config/countdown-template.yaml +++ b/benchmark/config/countdown-template.yaml @@ -1,7 +1,7 @@ mode: both project: Trinity-RFT -group: countdown-bench -name: countdown-qwen2.5-1.5B +group: ${oc.env:TRINITY_GROUP,countdown-bench} +name: ${oc.env:TRINITY_NAME,countdown} checkpoint_root_dir: placeholder algorithm: algorithm_type: ppo @@ -72,102 +72,16 @@ trainer: total_steps: 1000 enable_preview: true grad_clip: 1.0 + max_token_len_per_gpu: 6400 trainer_config: - actor_rollout_ref: - hybrid_engine: true - model: - external_lib: null - override_config: {} - enable_gradient_checkpointing: true - use_remove_padding: true - actor: - strategy: fsdp - ppo_micro_batch_size_per_gpu: 4 - use_dynamic_bsz: true - ppo_max_token_len_per_gpu: 6400 - ppo_epochs: 1 - shuffle: false - ulysses_sequence_parallel_size: 1 - checkpoint: - load_contents: - - model - - optimizer - - extra - save_contents: - - model - - optimizer - - extra - fsdp_config: - wrap_policy: - min_num_params: 0 - param_offload: false - optimizer_offload: false - fsdp_size: -1 - ref: - fsdp_config: - wrap_policy: - min_num_params: 0 - param_offload: false - optimizer_offload: false - fsdp_size: -1 - log_prob_micro_batch_size_per_gpu: 8 - log_prob_use_dynamic_bsz: true - log_prob_max_token_len_per_gpu: 6400 - ulysses_sequence_parallel_size: 1 - custom_reward_function: - path: null - name: compute_score - algorithm: - kl_penalty: low_var_kl - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: - balance_batch: true - resume_mode: auto - resume_from_path: '' - critic_warmup: 0 - default_hdfs_dir: null - remove_previous_ckpt_in_save: false - del_local_ckpt_after_load: false - 
max_actor_ckpt_to_keep: null - max_critic_ckpt_to_keep: null critic: - strategy: fsdp optim: lr: 1e-5 lr_warmup_steps_ratio: 0.0 warmup_style: constant - model: - override_config: {} - external_lib: null - enable_gradient_checkpointing: true - use_remove_padding: true - fsdp_config: - wrap_policy: - min_num_params: 0 - param_offload: false - optimizer_offload: false - fsdp_size: -1 - ppo_micro_batch_size_per_gpu: 8 - forward_micro_batch_size_per_gpu: 8 - use_dynamic_bsz: true ppo_max_token_len_per_gpu: 12800 forward_max_token_len_per_gpu: 12800 - ulysses_sequence_parallel_size: 1 - ppo_epochs: 1 - shuffle: false - grad_clip: 1.0 cliprange_value: 0.5 - checkpoint: - load_contents: - - model - - optimizer - - extra - save_contents: - - model - - optimizer - - extra monitor: monitor_type: wandb synchronizer: diff --git a/benchmark/config/gsm8k-template.yaml b/benchmark/config/gsm8k-template.yaml index 9e602bfe52..a967589fe9 100644 --- a/benchmark/config/gsm8k-template.yaml +++ b/benchmark/config/gsm8k-template.yaml @@ -1,7 +1,7 @@ mode: both project: Trinity-RFT -group: gsm8k-bench -name: gsm8k-qwen2.5-1.5B +group: ${oc.env:TRINITY_GROUP,gsm8k-bench} +name: ${oc.env:TRINITY_NAME,gsm8k} checkpoint_root_dir: placeholder algorithm: algorithm_type: grpo diff --git a/benchmark/config/guru_math-template.yaml b/benchmark/config/guru_math-template.yaml new file mode 100644 index 0000000000..7538171826 --- /dev/null +++ b/benchmark/config/guru_math-template.yaml @@ -0,0 +1,74 @@ +mode: both +project: Trinity-RFT +group: ${oc.env:TRINITY_GROUP,guru_math-bench} +name: ${oc.env:TRINITY_NAME,guru_math} +checkpoint_root_dir: placeholder +model: + model_path: Qwen/Qwen2.5-7B + max_prompt_tokens: 4096 + max_response_tokens: 8192 +algorithm: + algorithm_type: grpo + repeat_times: 16 + kl_loss_fn_args: + kl_coef: 0.0 + optimizer: + lr: 1e-6 + weight_decay: 0.1 + lr_warmup_steps: 80 + warmup_style: constant +cluster: + node_num: 1 + gpu_per_node: 8 +data_processor: + experience_pipeline: + save_input: false +buffer: + total_epochs: 1 + batch_size: 60 + explorer_input: + default_workflow_type: math_boxed_workflow + default_reward_fn_type: math_boxed_reward_naive_dapo + taskset: + name: math + storage_type: file + path: null + format: + prompt_key: question + response_key: ground_truth + system_prompt: "You are a helpful assistant. To answer a query from the user, please first thinks through the question step-by-step inside ..., then provides the final response to user." + reply_prefix: "" + rollout_args: + temperature: 1.0 + logprobs: 0 + eval_tasksets: [] + trainer_input: + experience_buffer: + name: math_buffer + storage_type: queue + replay_buffer: + enable: false +explorer: + eval_interval: 10 + runner_per_model: 8 + rollout_model: + engine_type: vllm_async + engine_num: 3 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: false + dtype: bfloat16 + seed: 42 +synchronizer: + sync_style: fixed + sync_method: nccl + sync_interval: 8 + sync_timeout: 2400 +trainer: + trainer_type: verl + save_interval: 80 + enable_preview: true + grad_clip: 1.0 + max_token_len_per_gpu: 24576 +monitor: + monitor_type: wandb diff --git a/benchmark/plugins/guru_math/naive_dapo.py b/benchmark/plugins/guru_math/naive_dapo.py new file mode 100644 index 0000000000..b1dce7c597 --- /dev/null +++ b/benchmark/plugins/guru_math/naive_dapo.py @@ -0,0 +1,522 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py + +import concurrent +import math +import os +import re +import resource + +import sympy +from pylatexenc import latex2text +from sympy.parsing import sympy_parser +from verl.utils.reward_score.prime_math import math_normalize +from verl.utils.reward_score.prime_math.grader import math_equal + +# Constants for normalization +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] + +REMOVED_EXPRESSIONS = [ + "square", + "ways", + "integers", + "dollars", + "mph", + "inches", + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + + +def normalize_final_answer(final_answer: str) -> str: + """Normalize a final answer to a quantitative reasoning question. 
+ + Args: + final_answer: The answer string to normalize + + Returns: + Normalized answer string + """ + final_answer = final_answer.split("=")[-1] + + # Apply substitutions and removals + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract and normalize LaTeX math + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize numbers + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer.strip() + + +# sympy might hang -- we don't care about trying to be lenient in these cases +BAD_SUBSTRINGS = ["^{", "^("] +BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"] +TUPLE_CHARS = "()[]" + + +def timeout(timeout_seconds: int = 8): + if os.name == "posix": + import signal + + def decorator(func): + def handler(signum, frame): + raise TimeoutError("Operation timed out!") + + def wrapper(*args, **kwargs): + old_handler = signal.getsignal(signal.SIGALRM) + signal.signal(signal.SIGALRM, handler) + signal.alarm(timeout_seconds) + + try: + return func(*args, **kwargs) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + return wrapper + + return decorator + else: + raise NotImplementedError(f"Unsupported OS: {os.name}") + + +def _sympy_parse(expr: str): + """Parses an expression with sympy.""" + py_expr = expr.replace("^", "**") + return sympy_parser.parse_expr( + py_expr, + transformations=( + sympy_parser.standard_transformations + + (sympy_parser.implicit_multiplication_application,) + ), + ) + + +def _parse_latex(expr: str) -> str: + """Attempts to parse latex to an expression sympy can read.""" + expr = expr.replace("\\tfrac", "\\frac") + expr = expr.replace("\\dfrac", "\\frac") + expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. + expr = latex2text.LatexNodes2Text().latex_to_text(expr) + + # Replace the specific characters that this parser uses. 
+ expr = expr.replace("√", "sqrt") + expr = expr.replace("π", "pi") + expr = expr.replace("∞", "inf") + expr = expr.replace("∪", "U") + expr = expr.replace("·", "*") + expr = expr.replace("×", "*") + + return expr.strip() + + +def _is_float(num: str) -> bool: + try: + float(num) + return True + except ValueError: + return False + + +def _is_int(x: float) -> bool: + try: + return abs(x - int(round(x))) <= 1e-7 + except Exception: + return False + + +def _is_frac(expr: str) -> bool: + return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) + + +def _str_is_int(x: str) -> bool: + try: + x = _strip_properly_formatted_commas(x) + x = float(x) + return abs(x - int(round(x))) <= 1e-7 # type: ignore + except Exception: + return False + + +def _str_to_int(x: str) -> int: + x = x.replace(",", "") + x = float(x) + return int(x) + + +def _inject_implicit_mixed_number(step: str): + """ + Automatically make a mixed number evalable + e.g. 7 3/4 => 7+3/4 + """ + p1 = re.compile("([0-9]) +([0-9])") + step = p1.sub("\\1+\\2", step) # implicit mults + return step + + +def _strip_properly_formatted_commas(expr: str): + # We want to be careful because we don't want to strip tuple commas + p1 = re.compile(r"(\d)(,)(\d\d\d)($|\D)") + while True: + next_expr = p1.sub("\\1\\3\\4", expr) + if next_expr == expr: + break + expr = next_expr + return next_expr + + +def _normalize(expr: str) -> str: + """Normalize answer expressions.""" + if expr is None: + return None + + # Remove enclosing `\text{}`. + m = re.search(r"^\\text\{(?P.+?)\}$", expr) + if m is not None: + expr = m.group("text") + + expr = expr.replace("\\%", "%") + expr = expr.replace("\\$", "$") + expr = expr.replace("$", "") + expr = expr.replace("%", "") + expr = expr.replace(" or ", " , ") + expr = expr.replace(" and ", " , ") + + expr = expr.replace("million", "*10^6") + expr = expr.replace("billion", "*10^9") + expr = expr.replace("trillion", "*10^12") + + for unit in [ + "degree", + "cm", + "centimeter", + "meter", + "mile", + "second", + "minute", + "hour", + "day", + "week", + "month", + "year", + "foot", + "feet", + "inch", + "yard", + "liter", + ]: + expr = re.sub(rf"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) + expr = re.sub("\\^ *\\\\circ", "", expr) + + if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": + expr = expr[1:-1] + + expr = re.sub(",\\\\! 
*", "", expr) + if _is_float(expr) and _is_int(float(expr)): + expr = str(int(round(float(expr)))) + if "\\" in expr: + try: + expr = _parse_latex(expr) + except Exception: + pass + + # edge case with mixed numbers and negative signs + expr = re.sub("- *", "-", expr) + + expr = _inject_implicit_mixed_number(expr) + + # don't be case sensitive for text answers + expr = expr.lower() + + if _str_is_int(expr): + expr = str(_str_to_int(expr)) + + return expr + + +def count_unknown_letters_in_expr(expr: str): + expr = expr.replace("sqrt", "") + expr = expr.replace("frac", "") + letters_in_expr = set([x for x in expr if x.isalpha()]) + return len(letters_in_expr) + + +def should_allow_eval(expr: str): + # we don't want to try parsing unknown text or functions of more than two variables + if count_unknown_letters_in_expr(expr) > 2: + return False + + for bad_string in BAD_SUBSTRINGS: + if bad_string in expr: + return False + + for bad_regex in BAD_REGEXES: + if re.search(bad_regex, expr) is not None: + return False + + return True + + +# @timeout(timeout_seconds=10) +def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): + @timeout(timeout_seconds=10) + def check_equal(): + memory_size = 1024**3 + resource.setrlimit(resource.RLIMIT_AS, (memory_size, memory_size)) + + expr = f"({ground_truth_normalized})-({given_normalized})" + if should_allow_eval(expr): + sympy_diff = _sympy_parse(expr) + simplified = sympy.simplify(sympy_diff) + if simplified == 0: + return True + return False + + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: + future = executor.submit(check_equal) + try: + return future.result(timeout=10) + except (concurrent.futures.TimeoutError, Exception): + future.cancel() + return False + + +def split_tuple(expr: str): + """ + Split the elements in a tuple/interval, while handling well-formatted commas in large numbers + """ + expr = _strip_properly_formatted_commas(expr) + if len(expr) == 0: + return [] + if ( + len(expr) > 2 + and expr[0] in TUPLE_CHARS + and expr[-1] in TUPLE_CHARS + and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) + ): + elems = [elem.strip() for elem in expr[1:-1].split(",")] + else: + elems = [expr] + return elems + + +def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: + """ + The answer will be considered correct if: + (a) it normalizes to the same string as the ground truth answer + OR + (b) sympy can simplify the difference between the expressions to 0 + """ + if given_answer is None: + return False + + ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth) + given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer) + + # be at least as lenient as mathd + if ground_truth_normalized_mathd == given_answer_normalized_mathd: + return True, given_answer_normalized_mathd + + ground_truth_normalized = _normalize(ground_truth) + given_normalized = _normalize(given_answer) + + if ground_truth_normalized is None: + return False, given_normalized + + if ground_truth_normalized == given_normalized: + return True, given_normalized + + if len(given_normalized) == 0: + return False, given_normalized + + ground_truth_elems = split_tuple(ground_truth_normalized) + given_elems = split_tuple(given_normalized) + + if len(ground_truth_elems) > 1 and ( + ground_truth_normalized[0] != given_normalized[0] + or ground_truth_normalized[-1] != given_normalized[-1] + ): + is_correct = False + elif len(ground_truth_elems) != len(given_elems): + is_correct = False + else: 
+        for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems):
+            if _is_frac(ground_truth_elem) and _is_frac(given_elem):
+                # if fractions aren't reduced, then shouldn't be marked as correct
+                # so, we don't want to allow sympy.simplify in this case
+                is_correct = ground_truth_elem == given_elem
+            elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):
+                # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify)
+                is_correct = False
+            else:
+                is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
+            if not is_correct:
+                break
+
+    return is_correct, given_normalized
+
+
+def _last_boxed_only_string(string):
+    idx = string.rfind("\\boxed")
+    if idx < 0:
+        idx = string.rfind("\\fbox")
+        if idx < 0:
+            return None
+
+    i = idx
+    left_brace_idx = None
+    right_brace_idx = None
+    num_left_braces_open = 0
+    while i < len(string):
+        if string[i] == "{":
+            num_left_braces_open += 1
+            if left_brace_idx is None:
+                left_brace_idx = i
+        elif string[i] == "}":
+            num_left_braces_open -= 1
+            if num_left_braces_open == 0:
+                right_brace_idx = i
+                break
+
+        i += 1
+
+    if left_brace_idx is None or right_brace_idx is None:
+        return None
+
+    return string[left_brace_idx + 1 : right_brace_idx].strip()
+
+
+def match_answer(response):
+    is_matched = False
+    response = response.split("</think>")[-1]
+
+    # Find boxed
+    ans_boxed = _last_boxed_only_string(response)
+    if ans_boxed:
+        is_matched = True
+        response = ans_boxed
+
+    return is_matched, response
+
+
+def compute_score(solution_str: str, ground_truth: str, extra_info: dict) -> dict:
+    """Compute the reward score for a solution. This draws heavily from the LLM-as-judge and PRIME reward functions
+
+    Args:
+        solution_str: The solution string
+        ground_truth: The ground truth answer
+        extra_info: dict with additional info for the score computation
+
+    Returns:
+        dict with "score" (1.0 if correct, 0.0 otherwise) and "acc" (bool correctness flag)
+    """
+    # First assert intended generation and gt type
+    model_output = str(solution_str)
+    ground_truth = str(ground_truth)
+
+    # Extract answer from generated output
+    is_matched, extracted_model_output = match_answer(model_output)
+
+    # NOTE: the response truncation from math_dapo.compute_score is intentionally removed here
+
+    # Verify the solution, first check simple comparisons.
+    correct, pred = grade_answer(extracted_model_output, ground_truth)
+
+    if not correct:
+        try:
+            if "\\pi" in extracted_model_output or "\\pi" in ground_truth:
+                equivs = []
+                for pi in [math.pi, 3.14]:
+                    equivs.append(
+                        math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi)
+                    )
+                correct = any(equivs)
+            else:
+                correct = math_equal(extracted_model_output, ground_truth, timeout=True)
+        except Exception:
+            correct = False
+
+    reward = 1.0 if correct else 0.0
+    acc = correct
+
+    return {
+        "score": reward,
+        "acc": acc,
+    }
diff --git a/benchmark/plugins/guru_math/reward.py b/benchmark/plugins/guru_math/reward.py
new file mode 100644
index 0000000000..9cc22d9c73
--- /dev/null
+++ b/benchmark/plugins/guru_math/reward.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+from trinity.common.rewards.math_reward import MathBoxedRewardFn
+from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS
+
+
+@REWARD_FUNCTIONS.register_module("math_boxed_reward_naive_dapo")
+class NaiveDapoRewardFn(MathBoxedRewardFn):
+    def __call__(  # type: ignore
+        self,
+        response: str,
+        truth: Optional[str] = None,
+        with_think: Optional[bool] = False,
+        format_score_coef: Optional[float] = 0.1,
+        **kwargs,
+    ) -> dict[str, float]:
+        from .naive_dapo import compute_score
+
+        ret = compute_score(response, truth, None)  # type: ignore
+        return {"accuracy": ret["score"], "format_score": 0}
+
+
+@REWARD_FUNCTIONS.register_module("math_boxed_reward_prime_math")
+class PrimeMathRewardFn(MathBoxedRewardFn):
+    def __call__(  # type: ignore
+        self,
+        response: str,
+        truth: Optional[str] = None,
+        with_think: Optional[bool] = False,
+        format_score_coef: Optional[float] = 0.1,
+        **kwargs,
+    ) -> dict[str, float]:
+        from verl.utils.reward_score.prime_math import compute_score
+
+        ret = compute_score(response, truth)
+        return {"accuracy": ret["score"], "format_score": 0}
diff --git a/benchmark/scripts/gen-countdown-data.py b/benchmark/scripts/gen-countdown-data.py
deleted file mode 100644
index ffaf41f00e..0000000000
--- a/benchmark/scripts/gen-countdown-data.py
+++ /dev/null
@@ -1,115 +0,0 @@
-"""
-Modified from https://github.com/Jiayi-Pan/TinyZero/blob/main/examples/data_preprocess/countdown.py
-Preprocess dataset for countdown task - given a target number and N numbers, generate equations to reach target
-"""
-
-import argparse
-import json
-import os
-from random import randint, seed
-from typing import List, Tuple
-
-from datasets import load_dataset
-from tqdm import tqdm
-from verl.utils.hdfs_io import copy, makedirs
-
-
-def gen_dataset(
-    num_samples: int,
-    num_operands: int = 6,
-    max_target: int = 1000,
-    min_number: int = 1,
-    max_number: int = 100,
-    operations: List[str] = ["+", "-", "*", "/"],
-    seed_value: int = 42,
-) -> List[Tuple]:
-    """Generate dataset for countdown task.
- - Args: - num_samples: Number of samples to generate - num_operands: Number of numbers provided in each sample - max_target: Maximum value for target number - min_number: Minimum value for provided numbers - max_number: Maximum value for provided numbers - operations: List of allowed operations - seed_value: Random seed for reproducibility - - Returns: - List of tuples containing (target, numbers, solution) - """ - seed(seed_value) - samples = [] - - for _ in tqdm(range(num_samples)): - # Generate random target - target = randint(1, max_target) - - # Generate random numbers - numbers = [randint(min_number, max_number) for _ in range(num_operands)] - - samples.append((target, numbers)) - - return samples - - -def make_prefix(dp): - target = dp["target"] - numbers = dp["nums"] - system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.""" - task_desc = f"""User: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n""" - final_prompt = f"{system_prompt}\n{task_desc}" - return final_prompt - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--local_dir", default="~/data/countdown") - parser.add_argument("--hdfs_dir", default=None) - parser.add_argument("--num_samples", type=int, default=100000) - parser.add_argument("--num_operands", type=int, default=6) - parser.add_argument("--max_target", type=int, default=1000) - parser.add_argument("--min_number", type=int, default=1) - parser.add_argument("--max_number", type=int, default=100) - parser.add_argument("--train_size", type=int, default=320000) - parser.add_argument("--test_size", type=int, default=7680) - - args = parser.parse_args() - - data_source = "countdown" - TRAIN_SIZE = args.train_size - TEST_SIZE = args.test_size - - raw_dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train") - - assert len(raw_dataset) > TRAIN_SIZE + TEST_SIZE - train_dataset = raw_dataset.select(range(TRAIN_SIZE)) - test_dataset = raw_dataset.select(range(TRAIN_SIZE, TRAIN_SIZE + TEST_SIZE)) - - def make_map_fn(split): - def process_fn(example, idx): - question = make_prefix(example) - data = { - "question": question, - "answer": json.dumps( - { - "numbers": example["nums"], - "target": example["target"], - } - ), - } - return data - - return process_fn - - train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) - test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) - - local_dir = args.local_dir - hdfs_dir = args.hdfs_dir - - train_dataset.to_json(os.path.join(local_dir, "train.jsonl")) - test_dataset.to_json(os.path.join(local_dir, "test.jsonl")) - - if hdfs_dir is not None: - makedirs(hdfs_dir) - copy(src=local_dir, dst=hdfs_dir) diff --git a/benchmark/scripts/gen_countdown_data.py b/benchmark/scripts/gen_countdown_data.py new file mode 100644 index 0000000000..922cea6d1d --- /dev/null +++ b/benchmark/scripts/gen_countdown_data.py @@ -0,0 +1,74 @@ +""" +Modified from https://github.com/Jiayi-Pan/TinyZero/blob/main/examples/data_preprocess/countdown.py +Preprocess dataset for countdown task - given a target number and N numbers, generate equations 
to reach target
+"""
+
+import argparse
+import json
+import os
+
+from datasets import load_dataset
+from verl.utils.hdfs_io import copy, makedirs
+
+DEFAULT_DATA_PATH = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)), "..", "data", "countdown"
+)
+
+
+def make_prefix(dp):
+    target = dp["target"]
+    numbers = dp["nums"]
+    system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer."""
+    task_desc = f"""User: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n"""
+    final_prompt = f"{system_prompt}\n{task_desc}"
+    return final_prompt
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--local_dir", default=DEFAULT_DATA_PATH)
+    parser.add_argument("--hdfs_dir", default=None)
+    parser.add_argument("--train_size", type=int, default=320000)
+    parser.add_argument("--test_size", type=int, default=7680)
+
+    args = parser.parse_args()
+
+    data_source = "countdown"
+    TRAIN_SIZE = args.train_size
+    TEST_SIZE = args.test_size
+
+    raw_dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train")
+
+    assert len(raw_dataset) > TRAIN_SIZE + TEST_SIZE
+    train_dataset = raw_dataset.select(range(TRAIN_SIZE))
+    test_dataset = raw_dataset.select(range(TRAIN_SIZE, TRAIN_SIZE + TEST_SIZE))
+
+    def process_fn(example, idx):
+        question = make_prefix(example)
+        data = {
+            "question": question,
+            "answer": json.dumps(
+                {
+                    "numbers": example["nums"],
+                    "target": example["target"],
+                }
+            ),
+        }
+        return data
+
+    train_dataset = train_dataset.map(
+        function=process_fn, with_indices=True, remove_columns=train_dataset.column_names
+    )
+    test_dataset = test_dataset.map(
+        function=process_fn, with_indices=True, remove_columns=test_dataset.column_names
+    )
+
+    local_dir = args.local_dir
+    hdfs_dir = args.hdfs_dir
+
+    train_dataset.to_json(os.path.join(local_dir, "train.jsonl"))
+    test_dataset.to_json(os.path.join(local_dir, "test.jsonl"))
+
+    if hdfs_dir is not None:
+        makedirs(hdfs_dir)
+        copy(src=local_dir, dst=hdfs_dir)
diff --git a/benchmark/scripts/gen_guru_math_data.py b/benchmark/scripts/gen_guru_math_data.py
new file mode 100644
index 0000000000..4b7c2e5c84
--- /dev/null
+++ b/benchmark/scripts/gen_guru_math_data.py
@@ -0,0 +1,35 @@
+import argparse
+import os
+
+from datasets import load_dataset
+from huggingface_hub import hf_hub_download
+
+DEFAULT_DATA_PATH = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)), "..", "data", "guru_math"
+)
+
+
+def process_fn(example, idx):
+    data = {
+        "question": example["prompt"][0]["content"],
+        "ground_truth": example["reward_model"]["ground_truth"],
+    }
+    return data
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--local_dir", default=DEFAULT_DATA_PATH)
+    args = parser.parse_args()
+
+    downloaded_file_path = hf_hub_download(
+        repo_id="LLM360/guru-RL-92k",
+        filename="train/math__combined_54.4k.parquet",
+        repo_type="dataset",
+    )
+    dataset = load_dataset("parquet", data_files=downloaded_file_path, split="train")
+    new_dataset = dataset.map(
+        function=process_fn, with_indices=True, remove_columns=dataset.column_names
+    ).shuffle()
+
+    os.makedirs(args.local_dir, exist_ok=True)
+    new_dataset.to_json(os.path.join(args.local_dir, "train.jsonl"))
diff --git a/docs/sphinx_doc/assets/guru-bench.png b/docs/sphinx_doc/assets/guru-bench.png
new file mode 100644
index 0000000000..aa22793b11
Binary files /dev/null and b/docs/sphinx_doc/assets/guru-bench.png differ
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index b157c343e7..24a8de0296 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -244,7 +244,7 @@ class Critic:
     ppo_max_token_len_per_gpu: Optional[int] = None
     forward_max_token_len_per_gpu: Optional[int] = None
    ulysses_sequence_parallel_size: Optional[int] = None
-    ppo_epochs: int = 0
+    ppo_epochs: int = 1
     shuffle: bool = False
     grad_clip: Optional[float] = None
     cliprange_value: float = 0.0
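
Reviewer note: `check_taskset_path` resolves a missing taskset by loading the generator script as a module, reading its `DEFAULT_DATA_PATH`, and then running the same script in a subprocess. The sketch below isolates that pattern outside `bench.py`; the helper name `resolve_default_path` and the commented paths are illustrative, not part of this patch.

```python
# Minimal sketch of the dynamic-import pattern used by check_taskset_path.
# Standard library only; the function name here is hypothetical.
import importlib.util
import os
import subprocess
import sys


def resolve_default_path(script_path: str) -> str:
    """Load a generator script as a module and return its DEFAULT_DATA_PATH."""
    spec = importlib.util.spec_from_file_location("gen_script", script_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load {script_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the script's top-level code (safe: main logic is guarded)
    return os.path.realpath(module.DEFAULT_DATA_PATH)


# Example (paths illustrative): generate countdown data into its default location.
# script = "benchmark/scripts/gen_countdown_data.py"
# subprocess.run([sys.executable, script, "--local_dir", resolve_default_path(script)], check=True)
```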
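Reviewer note: the new `math_boxed_reward_naive_dapo` reward ultimately calls `compute_score` from `benchmark/plugins/guru_math/naive_dapo.py`, which strips everything before the closing think tag, extracts the last `\boxed{...}` expression, and grades it against the ground truth. A rough usage sketch follows, assuming the repo root is on `PYTHONPATH` and `verl` is installed (both assumptions, since the plugin imports `verl.utils.reward_score`); the toy strings and expected outputs are illustrative.

```python
# Rough usage sketch for the new guru_math grading helpers (not part of the patch).
from benchmark.plugins.guru_math.naive_dapo import compute_score, match_answer

solution = (
    "<think>Half of the students passed, so the fraction is 1/2.</think>\n"
    "The final answer is \\boxed{\\frac{1}{2}}."
)

# match_answer should report a match and return the boxed payload "\frac{1}{2}".
matched, extracted = match_answer(solution)
print(matched, extracted)

# grade_answer/sympy should treat \frac{1}{2} and 1/2 as equivalent,
# so the expected result is {"score": 1.0, "acc": True}.
result = compute_score(solution, "1/2", extra_info={})
print(result)
```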
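Reviewer note on the `set_engine_num` change in `bench.py`: the explorer is now sized against `buffer.batch_size * algorithm.repeat_times`, i.e. the number of rollouts per step rather than the number of tasks. With the values from `guru_math-template.yaml` this works out as below; the per-engine split is an assumption (it presumes rollouts are spread evenly across engines).

```python
# Worked numbers taken from guru_math-template.yaml in this patch.
batch_size = 60        # buffer.batch_size: tasks per explorer step
repeat_times = 16      # algorithm.repeat_times: rollouts per task (GRPO group size)
engine_num = 3         # explorer.rollout_model.engine_num

rollouts_per_step = batch_size * repeat_times            # 960 experiences per step
rollouts_per_engine = rollouts_per_step / engine_num     # 320 if spread evenly (illustrative)
print(rollouts_per_step, rollouts_per_engine)
```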