diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index 0812fb5e6e..da4bc3b54d 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -5,7 +5,7 @@
 from unittest.mock import MagicMock
 
 from tests.tools import get_unittest_dataset_config
-from trinity.common.workflows import MathWorkflow, Workflow
+from trinity.common.workflows import MathBoxedWorkflow, MathWorkflow, Workflow
 from trinity.common.workflows.workflow import Task
 
 
@@ -134,6 +134,57 @@ def test_math_complex_workflow(self) -> None:
         self.assertEqual(len(experiences), 1)
         self.assertEqual(experiences[0].reward, 0.9)
 
+    def test_math_boxed_workflow(self) -> None:
+        model = MagicMock()
+        model.chat.return_value = [
+            MockResponse("<think> balabalabala </think> 99 \n \\boxed{36}"),
+            MockResponse("answer is \\boxed{36 }"),
+            MockResponse("Kim's total points are 6 + 30 =\\boxed{36}"),
+            MockResponse("<think> balalaba </think> \\boxed{35.00}"),
+        ]
+        taskset_config = get_unittest_dataset_config("countdown")
+        task = Task(
+            workflow=MathBoxedWorkflow,
+            format_args=taskset_config.format,
+            rollout_args=taskset_config.rollout_args,
+            workflow_args={
+                "with_think": False,
+                "format_score_coef": 0.2,
+            },
+            is_eval=False,
+            raw_task={
+                taskset_config.format.prompt_key: "",
+                taskset_config.format.response_key: r"36",
+            },
+        )
+        workflow = task.to_workflow(model=model)
+        experiences = workflow.run()
+        self.assertEqual(experiences[0].reward, 1.0)
+        self.assertEqual(experiences[1].reward, 1.0)
+        self.assertEqual(experiences[2].reward, 1.0)
+        self.assertEqual(experiences[3].reward, 0.0)
+        task_new = Task(
+            workflow=MathBoxedWorkflow,
+            format_args=taskset_config.format,
+            rollout_args=taskset_config.rollout_args,
+            workflow_args={
+                "with_think": True,
+                "format_score_coef": 0.2,
+            },
+            is_eval=False,
+            raw_task={
+                taskset_config.format.prompt_key: "",
+                taskset_config.format.response_key: r"36",
+            },
+        )
+        workflow.reset(task_new)
+        workflow_new = task_new.to_workflow(model=model)
+        experiences = workflow_new.run()
+        self.assertEqual(experiences[0].reward, 1.0)
+        self.assertEqual(experiences[1].reward, 0.8)
+        self.assertEqual(experiences[2].reward, 0.8)
+        self.assertEqual(experiences[3].reward, 0.0)
+
     def test_gsm8k_workflow(self) -> None:
         model = MagicMock()
         model.chat.return_value = [
diff --git a/trinity/common/rewards/reward_fn.py b/trinity/common/rewards/reward_fn.py
index 1b5906b8be..822636482e 100644
--- a/trinity/common/rewards/reward_fn.py
+++ b/trinity/common/rewards/reward_fn.py
@@ -9,10 +9,12 @@
 from math_verify import LatexExtractionConfig, parse, verify
 
 from trinity.utils.eval_utils import (
+    compute_score,
     evaluate_equation,
     extract_solution,
     simple_answer_parser,
     validate_equation,
+    validate_think_pattern,
 )
 from trinity.utils.log import get_logger
 from trinity.utils.registry import Registry
@@ -195,3 +197,33 @@
             return format_score
         except Exception as e:  # noqa: F841
             return format_score
+
+
+@REWARD_FUNCTIONS.register_module("math_boxed_reward")
+class MathBoxedRewardFn(RewardFn):
+    """A reward function for math tasks whose answers are wrapped in \\boxed{}."""
+
+    def __init__(
+        self,
+    ) -> None:
+        pass
+
+    def __call__(  # type: ignore
+        self,
+        response: str,
+        prompt: Optional[str] = None,
+        truth: Optional[str] = None,
+        return_dict: Optional[bool] = False,
+        with_think: Optional[bool] = False,
+        format_score_coef: Optional[float] = 0.1,
+    ) -> Union[float, dict]:
+        accuracy_score = compute_score(response, truth)
+
+        format_score = 0.0
+        if with_think and not validate_think_pattern(response):
+            format_score = (format_score_coef or 0.1) * -1.0
+
+        if return_dict:
+            return {"accuracy": accuracy_score, "format_score": format_score}
+
+        return accuracy_score + format_score
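
Reviewer note: a minimal sketch of how MathBoxedRewardFn combines its two score
components, assuming only the signature added above (the truth value "36" and the
0.2 coefficient are illustrative, not defaults):

    from trinity.common.rewards.reward_fn import MathBoxedRewardFn

    reward_fn = MathBoxedRewardFn()

    # Correct boxed answer with a well-formed think block: accuracy 1.0, no penalty.
    reward_fn(response="<think> work </think> \\boxed{36}", truth="36", with_think=True)  # 1.0

    # Correct answer but the <think>...</think> block is missing: 1.0 - 0.2 = 0.8.
    reward_fn(response="\\boxed{36}", truth="36", with_think=True, format_score_coef=0.2)  # 0.8

    # return_dict=True (the eval path) keeps the components separate.
    reward_fn(response="\\boxed{36}", truth="36", return_dict=True, with_think=True)
    # -> {"accuracy": 1.0, "format_score": -0.1}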
diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index f5b1c9a7b9..9d54f108d0 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 """Workflow module"""
+from .customized_math_workflows import MathBoxedWorkflow
 from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
 from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
 from .envs.webshop.webshop_workflow import WebShopWorkflow
@@ -14,4 +15,5 @@
     "WebShopWorkflow",
     "AlfworldWorkflow",
     "SciWorldWorkflow",
+    "MathBoxedWorkflow",
 ]
diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py
new file mode 100644
index 0000000000..d71a5d2fb1
--- /dev/null
+++ b/trinity/common/workflows/customized_math_workflows.py
@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+"""We include the customized math workflows in this file."""
+
+from dataclasses import asdict
+from typing import List
+
+from trinity.common.experience import Experience
+from trinity.common.rewards.reward_fn import MathBoxedRewardFn
+from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+
+@WORKFLOWS.register_module("math_boxed_workflow")
+class MathBoxedWorkflow(SimpleWorkflow):
+    """A workflow for math tasks whose answers are expected in \\boxed{} format."""
+
+    def reset(self, task: Task):
+        self.format_args = task.format_args
+        self.system_prompt = task.format_args.system_prompt
+        self.reply_prefix = task.format_args.reply_prefix
+
+        self.raw_task = task.raw_task
+        self.task_desc = task.task_desc
+        self.truth = task.truth
+
+        # Rollout args
+        rollout_args = asdict(task.rollout_args)
+        self.rollout_args = rollout_args
+        self.is_eval = task.is_eval
+
+        self.workflow_args = task.workflow_args
+
+        self.use_base = self.workflow_args.get("use_base", False)
+        self.with_think = self.workflow_args.get("with_think", False)
+        self.format_score_coef = self.workflow_args.get("format_score_coef", 0.1)
+
+        default_prompt = (
+            """Please reason step by step, and put your final answer within \\boxed{}."""
+        )
+
+        default_prompt_with_think = """You are a helpful assistant that solves MATH problems. You should first think about the reasoning process in mind and then provide the user with the answer. You should present your reasoning process using the format: <think>\n ...your reasoning process here... \n</think> first. You should always include your final answer in \\boxed{} as closed-form results."""
+
+        if self.system_prompt is None:
+            if self.with_think:
+                self.system_prompt = default_prompt_with_think
+            else:
+                self.system_prompt = default_prompt
+
+        self.reward_fn = MathBoxedRewardFn()
+
+    def format_prompt(self):
+        prompt_text = ""
+        if self.system_prompt:
+            prompt_text += "System:" + self.system_prompt
+            prompt_text += "\nUser:\n" + self.task_desc + "\nAssistant:\n"
+        else:
+            prompt_text += "User:\n" + self.task_desc + "\nAssistant:\n"
+        return prompt_text
+
+    def run(self) -> List[Experience]:
+        # TODO: Optimize the generate function
+        if not self.use_base:
+            prompt = self.format_messages()
+        else:
+            prompt = self.format_prompt()
+
+        logger.debug("start chat")
+        if not self.use_base:
+            responses = self.model.chat(prompt, **self.rollout_args)
+        else:
+            responses = self.model.generate([prompt], **self.rollout_args)
+
+        for response in responses:
+            reward = self.reward_fn(  # type: ignore [misc]
+                response=response.response_text,  # type: ignore [arg-type]
+                truth=self.truth,
+                return_dict=self.is_eval,
+                with_think=self.with_think,
+                format_score_coef=self.format_score_coef,
+            )
+            logger.debug(
+                f"self.task_desc: {self.task_desc}, prompt: {prompt}, response: {response.response_text}, reward: {reward}"
+            )
+            if isinstance(reward, dict):
+                if response.metrics is None:
+                    response.metrics = {}
+                response.metrics.update(reward)
+                reward = sum(reward.values())
+            response.reward = reward
+        return responses
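
Reviewer note: a configuration sketch showing how a task opts into this workflow.
It reuses the taskset_config and mock model from the unit test above; the question
text is made up for illustration:

    from trinity.common.workflows import MathBoxedWorkflow
    from trinity.common.workflows.workflow import Task

    task = Task(
        workflow=MathBoxedWorkflow,
        format_args=taskset_config.format,
        rollout_args=taskset_config.rollout_args,
        workflow_args={
            "use_base": False,         # True routes through model.generate with a plain text prompt
            "with_think": True,        # ask for (and score) a <think>...</think> block
            "format_score_coef": 0.1,  # penalty applied when the block is missing or malformed
        },
        raw_task={
            taskset_config.format.prompt_key: "Kim scored 6 and then 30 points. What is her total?",
            taskset_config.format.response_key: "36",
        },
    )
    experiences = task.to_workflow(model=model).run()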
diff --git a/trinity/utils/eval_utils.py b/trinity/utils/eval_utils.py
index e80afaf59b..c7ba647b50 100644
--- a/trinity/utils/eval_utils.py
+++ b/trinity/utils/eval_utils.py
@@ -83,3 +83,236 @@
         return result
     except Exception as e:  # noqa: F841
         return None
+
+
+def validate_think_pattern(text):
+    """Validate whether the <think>...</think> tag pair is properly formatted."""
+    start_tag = "<think>"
+    end_tag = "</think>"
+
+    start_count = text.count(start_tag)
+    end_count = text.count(end_tag)
+
+    if start_count == 1 and end_count == 1:
+        start_pos = text.find(start_tag)
+        end_pos = text.find(end_tag)
+        if start_pos < end_pos:
+            return True
+    return False
+
+
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
+def compute_score(solution_str, ground_truth) -> float:
+    retval = 0.0
+    try:
+        string_in_last_boxed = last_boxed_only_string(solution_str)
+        if string_in_last_boxed is not None:
+            answer = remove_boxed(string_in_last_boxed)
+            if is_equiv(answer, ground_truth):
+                retval = 1.0
+    except Exception as e:
+        print(e)
+
+    return retval
+
+
+# String normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
+def is_equiv(str1, str2, verbose=False):
+    if str1 is None and str2 is None:
+        print("WARNING: Both None")
+        return True
+    if str1 is None or str2 is None:
+        return False
+
+    try:
+        ss1 = strip_string(str1)
+        ss2 = strip_string(str2)
+        if verbose:
+            print(ss1, ss2)
+        return ss1 == ss2
+    except Exception:
+        return str1 == str2
+
+
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
+def remove_boxed(s):
+    if "\\boxed " in s:
+        left = "\\boxed "
+        assert s[: len(left)] == left
+        return s[len(left) :]
+
+    left = "\\boxed{"
+
+    assert s[: len(left)] == left
+    assert s[-1] == "}"
+
+    return s[len(left) : -1]
+
+
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
+def last_boxed_only_string(string):
+    idx = string.rfind("\\boxed")
+    if "\\boxed " in string:
+        return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
+    if idx < 0:
+        idx = string.rfind("\\fbox")
+        if idx < 0:
+            return None
+
+    i = idx
+    right_brace_idx = None
+    num_left_braces_open = 0
+    while i < len(string):
+        if string[i] == "{":
+            num_left_braces_open += 1
+        if string[i] == "}":
+            num_left_braces_open -= 1
+            if num_left_braces_open == 0:
+                right_brace_idx = i
+                break
+        i += 1
+
+    retval = None if right_brace_idx is None else string[idx : right_brace_idx + 1]
+
+    return retval
+
+
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
+def fix_fracs(string):
+    substrs = string.split("\\frac")
+    new_str = substrs[0]
+    if len(substrs) > 1:
+        substrs = substrs[1:]
+        for substr in substrs:
+            new_str += "\\frac"
+            if substr[0] == "{":
+                new_str += substr
+            else:
+                try:
+                    assert len(substr) >= 2
+                except:  # noqa: E722
+                    return string
+                a = substr[0]
+                b = substr[1]
+                if b != "{":
+                    if len(substr) > 2:
+                        post_substr = substr[2:]
+                        new_str += "{" + a + "}{" + b + "}" + post_substr
+                    else:
+                        new_str += "{" + a + "}{" + b + "}"
+                else:
+                    if len(substr) > 2:
+                        post_substr = substr[2:]
+                        new_str += "{" + a + "}" + b + post_substr
+                    else:
+                        new_str += "{" + a + "}" + b
+    string = new_str
+    return string
+
+
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
+def fix_a_slash_b(string):
+    if len(string.split("/")) != 2:
+        return string
+    a = string.split("/")[0]
+    b = string.split("/")[1]
+    try:
+        a = int(a)
+        b = int(b)
+        assert string == "{}/{}".format(a, b)
+        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
+        return new_string
+    except:  # noqa: E722
+        return string
+
+
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
+def remove_right_units(string):
+    # "\\text{ " only ever occurs (at least in the val set) when describing units
+    if "\\text{ " in string:
+        splits = string.split("\\text{ ")
+        assert len(splits) == 2
+        return splits[0]
+    else:
+        return string
+
+
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
+def fix_sqrt(string):
+    if "\\sqrt" not in string:
+        return string
+    splits = string.split("\\sqrt")
+    new_string = splits[0]
+    for split in splits[1:]:
+        if split[0] != "{":
+            a = split[0]
+            new_substr = "\\sqrt{" + a + "}" + split[1:]
+        else:
+            new_substr = "\\sqrt" + split
+        new_string += new_substr
+    return new_string
+
+
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
+def strip_string(string):
+    # linebreaks
+    string = string.replace("\n", "")
+
+    # remove inverse spaces
+    string = string.replace("\\!", "")
+
+    # replace \\ with \
+    string = string.replace("\\\\", "\\")
+
+    # replace tfrac and dfrac with frac
+    string = string.replace("tfrac", "frac")
+    string = string.replace("dfrac", "frac")
+
+    # remove \left and \right
+    string = string.replace("\\left", "")
+    string = string.replace("\\right", "")
+
+    # Remove circ (degrees)
+    string = string.replace("^{\\circ}", "")
+    string = string.replace("^\\circ", "")
+
+    # remove dollar signs
+    string = string.replace("\\$", "")
+
+    # remove units (on the right)
+    string = remove_right_units(string)
+
+    # remove percentage
+    string = string.replace("\\%", "")
+    string = string.replace("\%", "")  # noqa: W605
+
+    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
+    string = string.replace(" .", " 0.")
+    string = string.replace("{.", "{0.")
+    # if empty, return empty string
+    if len(string) == 0:
+        return string
+    if string[0] == ".":
+        string = "0" + string
+
+    # to consider: get rid of e.g. "k = " or "q = " at beginning
+    if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2:
+        string = string.split("=")[1]
+
+    # fix sqrt3 --> sqrt{3}
+    string = fix_sqrt(string)
+
+    # remove spaces
+    string = string.replace(" ", "")
+
+    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
+    string = fix_fracs(string)
+
+    # manually change 0.5 --> \frac{1}{2}
+    if string == "0.5":
+        string = "\\frac{1}{2}"
+
+    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
+    string = fix_a_slash_b(string)
+
+    return string
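
Reviewer note: a few spot checks of the normalization pipeline; each expected
value follows directly from the functions added in this diff:

    from trinity.utils.eval_utils import compute_score, is_equiv, strip_string

    strip_string("1/2")              # '\\frac{1}{2}'  (fix_a_slash_b)
    strip_string("0.5")              # '\\frac{1}{2}'  (manual 0.5 rewrite)
    is_equiv("\\frac{1}{2}", "1/2")  # True
    is_equiv("36.00", "36")          # False: no general decimal normalization
    compute_score("the answer is \\boxed{\\frac{1}{2}}", "0.5")  # 1.0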