diff --git a/pyproject.toml b/pyproject.toml
index 4abb70446a..f55f9118a8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
     "openai",
     "jsonlines",
     "sortedcontainers",
+    "word2number",
 ]
 
 [project.scripts]
@@ -66,7 +67,7 @@ dev = [
     "pytest>=8.0.0",
     "pytest-json-ctrf",
     "parameterized",
-    "matplotlib"
+    "matplotlib",
 ]
 
 doc = [
diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index 4f965c8a2e..293a9fc0c8 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -12,6 +12,7 @@ from trinity.common.rewards import RMGalleryFn
 from trinity.common.workflows import (
     MathBoxedWorkflow,
+    MathEvalWorkflow,
     MathRMWorkflow,
     MathWorkflow,
     Workflow,
 )
@@ -274,6 +275,36 @@ def test_rm_gallery_workflow(self) -> None:
         self.assertEqual(experiences[2].reward, 1.0)
         self.assertEqual(experiences[3].reward, 0.0)
 
+    def test_math_eval_workflow(self) -> None:
+        model = MagicMock()
+        model.chat.return_value = [
+            MockResponse("My step-by-step reasoning leads to the answer \\boxed{36}"),
+            MockResponse("Here is the answer of \\boxed{36.0}"),
+            MockResponse("I made a mistake, the answer is \\boxed{42}"),
+            MockResponse("The answer is 36, but I forgot the box."),
+        ]
+
+        taskset_config = get_unittest_dataset_config("countdown")
+        task = Task(
+            workflow=MathEvalWorkflow,
+            is_eval=True,
+            format_args=taskset_config.format,
+            raw_task={
+                taskset_config.format.prompt_key: "",
+                taskset_config.format.response_key: "36",
+            },
+        )
+
+        workflow = task.to_workflow(model=model)
+        experiences = workflow.run()
+        self.assertEqual(len(experiences), 4)
+        expected_accuracies = [1.0, 1.0, 0.0, 0.0]
+        for i, (exp, expected_acc) in enumerate(zip(experiences, expected_accuracies)):
+            with self.subTest(f"Response {i}"):
+                self.assertEqual(exp.reward, 0.0)
+                assert exp.metrics is not None, f"Metrics for response {i} should not be None"
+                self.assertEqual(exp.metrics["accuracy"], expected_acc)
+
     def test_workflow_resettable(self) -> None:
         model = MagicMock()
         json_task = Task(
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index aaca7ff0a8..21b0f63c2b 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -34,6 +34,7 @@ buffer:
       path: 'placeholder'
       split: 'train'
     default_workflow_type: ''
+    default_eval_workflow_type: ''
     default_reward_fn_type: ''
 explorer:
   eval_interval: 100
diff --git a/tests/test_data/template.yaml b/tests/test_data/template.yaml
index 018bc6ccea..bf474d3748 100644
--- a/tests/test_data/template.yaml
+++ b/tests/test_data/template.yaml
@@ -11,6 +11,7 @@ buffer:
     storage_type: file
     path: ''
     default_workflow_type: ''
+    default_eval_workflow_type: ''
    default_reward_fn_type: ''
 explorer:
   runner_num: 8
diff --git a/tests/utils/eval_utils_test.py b/tests/utils/eval_utils_test.py
new file mode 100644
index 0000000000..8105b692ce
--- /dev/null
+++ b/tests/utils/eval_utils_test.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+"""Tests for the evaluation utils modules."""
+
+import unittest
+
+from trinity.utils.eval_utils import is_equiv
+from trinity.utils.math_eval_utils import extract_answer, verify_math_answer
+
+
+class TestMathEvalUtils(unittest.TestCase):
+    def test_extract_answer(self):
+        test_cases = [
+            ("The answer is \\boxed{42}", "42", "Basic boxed extraction"),
+            ("The result is \\boxed{\\frac{1}{2}}", "\\frac{1}{2}", "Boxed with LaTeX"),
+            ("Therefore, the final answer is 100.", "100", "English 'answer is' extraction"),
+            ("My final answer is: 3.14", "3.14", "English 'answer is' with colon"),
("所以,答案是x^2", "x^2", "Chinese 'answer is' extraction"), + ( + "The cost is 10 dollars and the profit is 20 dollars.", + "20", + "Extract the last number", + ), + ( + "There are 1,000 apples and 2,000 oranges.", + "2000", + "Extract the last number with commas", + ), + ("The probability is 0.75.", "0.75", "Extract the last decimal"), + ("This sentence has no answer.", None, "No answer case"), + ("The box is empty \\boxed{}", None, "Empty boxed"), + (12345, None, "Input is not a string"), + ] + + for i, (input_str, expected_output, description) in enumerate(test_cases): + with self.subTest(f"Case {i+1}: {description}"): + actual_output = extract_answer(input_str) + self.assertEqual( + actual_output, + expected_output, + f"Failed on input: '{input_str}'\nExpected: '{expected_output}', Got: '{actual_output}'", + ) + + def test_verify_math_answer(self): + test_cases = [ + ("The answer is \\boxed{42}", "42", True, "Simple integer equality"), + ("The result is 1,000.", "1000", True, "Number with commas"), + ("The answer is -50.", "-50", True, "Negative number equality"), + ("The solution is 5", "x=5", True, "Equivalence of value and equation"), + ("The answer is \\boxed{42}", "43", False, "Simple numerical inequality"), + ("The answer is \\boxed{x+1}", "x-1", False, "Symbolic expression inequality"), + ( + "The matrix is \\boxed{\\begin{pmatrix}1 & 1 \\\\ 0 & 1\\end{pmatrix}}", + "\\begin{pmatrix}1&0\\\\0&1\\end{pmatrix}", + False, + "Matrix inequality", + ), + ("The speed is 50 km/h", "50", True, "Judgment after stripping units"), + ] + + for i, (response, ground_truth, expected_correct, description) in enumerate(test_cases): + with self.subTest(f"Case {i+1}: {description}"): + accuracy, details = verify_math_answer(response, ground_truth) + is_correct = accuracy == 1.0 + self.assertEqual( + is_correct, + expected_correct, + f"Failed on response: '{response}' with truth: '{ground_truth}'\n" + f"Expected correct: {expected_correct}, Got: {is_correct}\nDetails: {details}", + ) + + +if __name__ == "__main__": + unittest.main() + + +class TestEvalUtils(unittest.TestCase): + def test_is_equiv(self): + test_cases = [ + # str1, str2, expected_output, description + (" 123 ", "123", True, "Equivalence with whitespace"), + ("50%", "50", True, "Equivalence with percentage sign"), + ("$50", "50", True, "Equivalence with dollar sign"), + ("hello", "world", False, "Basic inequality"), + ("123", "1234", False, "Numerical inequality"), + (None, None, True, "Both inputs are None"), + ("Some string", None, False, "One input is None (str1)"), + (None, "Some string", False, "One input is None (str2)"), + ] + + for i, (str1, str2, expected_output, description) in enumerate(test_cases): + with self.subTest(f"Case {i+1}: {description}"): + actual_output = is_equiv(str1, str2) + self.assertEqual( + actual_output, + expected_output, + f"Failed on inputs: ('{str1}', '{str2}')\nExpected: {expected_output}, Got: {actual_output}", + ) diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index a30f58ee8c..a03baa7f7c 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -288,6 +288,9 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.task_type = meta.task_type self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type) # type: ignore + self.default_eval_workflow_cls = None + if getattr(meta, "default_eval_workflow_type", None): + self.default_eval_workflow_cls = WORKFLOWS.get(meta.default_eval_workflow_type) 
         self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type)  # type: ignore
 
     def read(
@@ -297,11 +300,14 @@ def read(
         tasks = []
         samples = self.dataset.read_batch(batch_size)
         for sample in samples:
-            workflow_class = (
-                WORKFLOWS.get(sample[self.workflow_key])
-                if self.workflow_key in sample
-                else self.default_workflow_cls
-            )
+            if self.task_type == TaskType.EVAL and self.default_eval_workflow_cls:
+                workflow_class = self.default_eval_workflow_cls
+            else:
+                workflow_class = (
+                    WORKFLOWS.get(sample[self.workflow_key])
+                    if self.workflow_key in sample
+                    else self.default_workflow_cls
+                )
             reward_fn = (
                 REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
                 if self.reward_fn_key in sample
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 9714b521ab..a63ce54eaf 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -98,6 +98,7 @@ class StorageConfig:
 
     # used for rollout tasks
     default_workflow_type: Optional[str] = None
+    default_eval_workflow_type: Optional[str] = None
     default_reward_fn_type: Optional[str] = None
     rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
     workflow_args: dict = field(default_factory=dict)
@@ -276,6 +277,7 @@ class ExplorerInput:
     eval_tasksets: List[StorageConfig] = field(default_factory=list)
     # The following args provide default values for the corresponding args in `taskset` and `eval_tasksets`
     default_workflow_type: Optional[str] = None
+    default_eval_workflow_type: Optional[str] = None
     default_reward_fn_type: Optional[str] = None
     system_prompt: Optional[str] = None
     reply_prefix: Optional[str] = None
@@ -479,6 +481,10 @@ def _check_buffer(self) -> None:  # noqa: C901
             self.buffer.explorer_input.taskset.default_workflow_type = (
                 self.buffer.explorer_input.default_workflow_type
             )
+        if self.buffer.explorer_input.taskset.default_eval_workflow_type is None:
+            self.buffer.explorer_input.taskset.default_eval_workflow_type = (
+                self.buffer.explorer_input.default_eval_workflow_type
+            )
         if self.buffer.explorer_input.taskset.default_reward_fn_type is None:
             self.buffer.explorer_input.taskset.default_reward_fn_type = (
                 self.buffer.explorer_input.default_reward_fn_type
@@ -504,6 +510,10 @@ def _check_buffer(self) -> None:  # noqa: C901
                 dataset.name = f"eval_taskset_{idx}"
             if dataset.default_workflow_type is None:
                 dataset.default_workflow_type = self.buffer.explorer_input.default_workflow_type
+            if dataset.default_eval_workflow_type is None:
+                dataset.default_eval_workflow_type = (
+                    self.buffer.explorer_input.default_eval_workflow_type
+                )
             if dataset.default_reward_fn_type is None:
                 dataset.default_reward_fn_type = self.buffer.explorer_input.default_reward_fn_type
             if dataset.format.system_prompt is None:
diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index b2d418126c..5a598f506f 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -5,6 +5,7 @@
 from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
 from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
 from .envs.webshop.webshop_workflow import WebShopWorkflow
+from .eval_workflow import MathEvalWorkflow
 from .math_rm_workflow import MathRMWorkflow
 from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow
 
@@ -20,4 +21,5 @@
     "MathBoxedWorkflow",
     "MathRMWorkflow",
     "ToolCallWorkflow",
+    "MathEvalWorkflow",
 ]
diff --git a/trinity/common/workflows/eval_workflow.py b/trinity/common/workflows/eval_workflow.py
new file mode 100644
index 0000000000..15fc9a047d
--- /dev/null
+++ b/trinity/common/workflows/eval_workflow.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+"""Evaluation Workflow Class"""
+
+from dataclasses import asdict
+from typing import List, Optional
+
+import openai
+
+from trinity.common.config import GenerationConfig
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+from trinity.utils.log import get_logger
+from trinity.utils.math_eval_utils import verify_math_answer
+
+logger = get_logger(__name__)
+
+
+@WORKFLOWS.register_module("math_eval_workflow")
+class MathEvalWorkflow(Workflow):
+    """
+    A workflow for standard math evaluation.
+
+    The evaluation standard and prompting style follow the Qwen2.5-Math
+    model's evaluation methodology. For more details on that approach, see:
+    https://github.com/QwenLM/Qwen2.5-Math
+    """
+
+    def __init__(
+        self,
+        *,
+        task: Task,
+        model: ModelWrapper,
+        auxiliary_models: Optional[List[openai.OpenAI]] = None,
+    ):
+        super().__init__(
+            task=task,
+            model=model,
+            auxiliary_models=auxiliary_models,
+        )
+
+        self.raw_task = task.raw_task
+        self.truth = task.truth
+
+        # TODO: customize the config in the yaml
+        self.eval_gen_args = asdict(GenerationConfig(temperature=0.6, top_p=0.8, logprobs=0, n=1))
+
+    @property
+    def resettable(self):
+        return False
+
+    def format_messages(self):
+        """Format messages in the qwen_boxed evaluation style."""
+        if not self.raw_task or "question" not in self.raw_task:
+            raise ValueError("Raw task data must contain a 'question' field for MathEvalWorkflow.")
+
+        problem_input = self.raw_task["question"]
+
+        system_prompt = "You are a helpful assistant."
+        user_prompt = f"{problem_input}\nPlease reason step by step, and put your final answer within \\boxed{{}}."
+
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+        return messages
+
+    def run(self) -> List[Experience]:
+        messages = self.format_messages()
+
+        responses: List[Experience] = self.model.chat(messages, **self.eval_gen_args)
+
+        for response in responses:
+            if response.response_text is None or self.truth is None:
+                continue
+
+            accuracy, _ = verify_math_answer(
+                response_text=response.response_text, ground_truth=self.truth
+            )
+
+            acc_metrics = {"accuracy": accuracy}
+            if response.metrics is None:
+                response.metrics = {}
+            response.metrics.update(acc_metrics)
+
+        return responses
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 4f6428f1ba..cd963f1c5c 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -276,6 +276,12 @@ def eval(self):
             self.logger.warning("No evaluation data samples. Skip evaluation.")
             return
         self.logger.info(f"Evaluation at step {self.explore_step_num} started.")
+
+        if self.config.buffer.explorer_input.default_eval_workflow_type:
+            self.logger.info(
+                f"Use '{self.config.buffer.explorer_input.default_eval_workflow_type}' for evaluation."
+            )
+
         for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
             self.logger.info(
                 f"Evaluation on {eval_taskset_config.name} at step {self.explore_step_num} started."
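With the pieces above in place, an eval taskset routes to `MathEvalWorkflow` whenever `buffer.explorer_input.default_eval_workflow_type` is set to its registered name, `math_eval_workflow`. The sketch below (not part of the patch) drives the workflow directly, mirroring `test_math_eval_workflow` in `tests/explorer/workflow_test.py`; `get_unittest_dataset_config` is the existing test helper (its import path is not shown in this diff), and `FakeResponse` is a hypothetical stand-in for the test suite's `MockResponse`.

```python
from typing import Optional
from unittest.mock import MagicMock

from trinity.common.workflows import MathEvalWorkflow, Task


class FakeResponse:
    """Stand-in for MockResponse: one sampled completion from the model."""

    def __init__(self, text: str):
        self.response_text = text
        self.metrics: Optional[dict] = None


model = MagicMock()
model.chat.return_value = [
    FakeResponse("The answer is \\boxed{36}"),
    FakeResponse("The answer is \\boxed{42}"),
]

# get_unittest_dataset_config comes from the test helpers used by workflow_test.py.
taskset_config = get_unittest_dataset_config("countdown")
task = Task(
    workflow=MathEvalWorkflow,
    is_eval=True,
    format_args=taskset_config.format,
    raw_task={
        taskset_config.format.prompt_key: "What is 6 * 6?",
        taskset_config.format.response_key: "36",
    },
)

for exp in task.to_workflow(model=model).run():
    # Correctness is reported via metrics; eval experiences keep reward == 0.0.
    print(exp.metrics["accuracy"])  # 1.0 for the first response, 0.0 for the second
```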
diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index 44adba0ab9..7a11e9b241 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -142,7 +142,9 @@ def beginner_mode(self):
         if st.session_state["sft_warmup_steps"] > 0:
             self.get_configs("sft_warmup_dataset_args")
 
-        self.get_configs("default_workflow_type", "default_reward_fn_type")
+        self.get_configs(
+            "default_workflow_type", "default_eval_workflow_type", "default_reward_fn_type"
+        )
 
         self.get_configs(
             "actor_ppo_micro_batch_size_per_gpu",
@@ -166,7 +168,9 @@ def _expert_model_part(self):
 
     def _expert_buffer_part(self):
         self.get_configs("total_epochs", "train_batch_size")
-        self.get_configs("default_workflow_type", "default_reward_fn_type")
+        self.get_configs(
+            "default_workflow_type", "default_eval_workflow_type", "default_reward_fn_type"
+        )
         self.get_configs("system_prompt")
         self.get_configs("reply_prefix")
 
@@ -544,6 +548,7 @@ def _gen_buffer_config(self):
             },
             "eval_tasksets": [],
             "default_workflow_type": st.session_state["default_workflow_type"],
+            "default_eval_workflow_type": st.session_state["default_eval_workflow_type"],
             "default_reward_fn_type": st.session_state["default_reward_fn_type"],
             "system_prompt": st.session_state["system_prompt"],
             "reply_prefix": st.session_state["reply_prefix"],
diff --git a/trinity/utils/eval_utils.py b/trinity/utils/eval_utils.py
index c7ba647b50..3308825e81 100644
--- a/trinity/utils/eval_utils.py
+++ b/trinity/utils/eval_utils.py
@@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-
 import regex as re
 
+from trinity.utils.math_eval_utils import strip_string
+
 ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
 INVALID_ANS = "[invalid]"
 
@@ -177,55 +179,6 @@ def last_boxed_only_string(string):
     return retval
 
 
-# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
-def fix_fracs(string):
-    substrs = string.split("\\frac")
-    new_str = substrs[0]
-    if len(substrs) > 1:
-        substrs = substrs[1:]
-        for substr in substrs:
-            new_str += "\\frac"
-            if substr[0] == "{":
-                new_str += substr
-            else:
-                try:
-                    assert len(substr) >= 2
-                except:  # noqa: E722
-                    return string
-                a = substr[0]
-                b = substr[1]
-                if b != "{":
-                    if len(substr) > 2:
-                        post_substr = substr[2:]
-                        new_str += "{" + a + "}{" + b + "}" + post_substr
-                    else:
-                        new_str += "{" + a + "}{" + b + "}"
-                else:
-                    if len(substr) > 2:
-                        post_substr = substr[2:]
-                        new_str += "{" + a + "}" + b + post_substr
-                    else:
-                        new_str += "{" + a + "}" + b
-    string = new_str
-    return string
-
-
-# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
-def fix_a_slash_b(string):
-    if len(string.split("/")) != 2:
-        return string
-    a = string.split("/")[0]
-    b = string.split("/")[1]
-    try:
-        a = int(a)
-        b = int(b)
-        assert string == "{}/{}".format(a, b)
-        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
-        return new_string
-    except:  # noqa: E722
-        return string
-
-
 # Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
 def remove_right_units(string):
     # "\\text{ " only ever occurs (at least in the val set) when describing units
@@ -235,84 +188,3 @@ def remove_right_units(string):
     if "\\text{ " in string:
         splits = string.split("\\text{ ")
         assert len(splits) == 2
         return splits[0]
     else:
         return string
-
-
-# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
-def fix_sqrt(string):
-    if "\\sqrt" not in string:
-        return string
-    splits = string.split("\\sqrt")
-    new_string = splits[0]
-    for split in splits[1:]:
-        if split[0] != "{":
-            a = split[0]
-            new_substr = "\\sqrt{" + a + "}" + split[1:]
-        else:
-            new_substr = "\\sqrt" + split
-        new_string += new_substr
-    return new_string
-
-
-# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
-def strip_string(string):
-    # linebreaks
-    string = string.replace("\n", "")
-
-    # remove inverse spaces
-    string = string.replace("\\!", "")
-
-    # replace \\ with \
-    string = string.replace("\\\\", "\\")
-
-    # replace tfrac and dfrac with frac
-    string = string.replace("tfrac", "frac")
-    string = string.replace("dfrac", "frac")
-
-    # remove \left and \right
-    string = string.replace("\\left", "")
-    string = string.replace("\\right", "")
-
-    # Remove circ (degrees)
-    string = string.replace("^{\\circ}", "")
-    string = string.replace("^\\circ", "")
-
-    # remove dollar signs
-    string = string.replace("\\$", "")
-
-    # remove units (on the right)
-    string = remove_right_units(string)
-
-    # remove percentage
-    string = string.replace("\\%", "")
-    string = string.replace("\%", "")  # noqa: W605
-
-    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
-    string = string.replace(" .", " 0.")
-    string = string.replace("{.", "{0.")
-    # if empty, return empty string
-    if len(string) == 0:
-        return string
-    if string[0] == ".":
-        string = "0" + string
-
-    # to consider: get rid of e.g. "k = " or "q = " at beginning
-    if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2:
-        string = string.split("=")[1]
-
-    # fix sqrt3 --> sqrt{3}
-    string = fix_sqrt(string)
-
-    # remove spaces
-    string = string.replace(" ", "")
-
-    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
-    string = fix_fracs(string)
-
-    # manually change 0.5 --> \frac{1}{2}
-    if string == "0.5":
-        string = "\\frac{1}{2}"
-
-    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
-    string = fix_a_slash_b(string)
-
-    return string
diff --git a/trinity/utils/math_eval_utils.py b/trinity/utils/math_eval_utils.py
new file mode 100644
index 0000000000..17de53fffd
--- /dev/null
+++ b/trinity/utils/math_eval_utils.py
@@ -0,0 +1,558 @@
+# -*- coding: utf-8 -*-
+"""
+Utility functions for strictly parsing and evaluating mathematical answers.
+
+This module is a modified and simplified version of the official evaluation code
+for Qwen2.5-Math, designed for easier standalone use.
+
+Original source: https://github.com/QwenLM/Qwen2.5-Math
+
+Key modifications include:
+1. Retained only the core parsing logic for the common `qwen_boxed` prompt format.
+2. Consolidated essential parsing and evaluation functions from multiple files
+   into this single module.
+3. Simplified benchmark handling and conditional logic for broader applicability.
+4. Simplified or removed calls to external tools like TIR.
+""" + +import re +from math import isclose +from typing import Any, Dict, Optional, Tuple + +import sympy +from sympy import Matrix, N, simplify, sympify +from sympy.parsing.latex import parse_latex +from word2number import w2n + + +def verify_math_answer(response_text: str, ground_truth: str) -> Tuple[float, Dict[str, Any]]: + """Strictly compare the equality of response and groundtruth.""" + # Parse the response + parsed_prediction = extract_answer(response_text) + + # Parse the ground truth + parsed_truth = extract_answer(str(ground_truth)) + + is_correct = math_equal(prediction=parsed_prediction, reference=parsed_truth) + + accuracy = 1.0 if is_correct else 0.0 + + eval_details = { + "parsed_prediction": parsed_prediction, + "ground_truth": parsed_truth, + "is_correct": is_correct, + } + + return accuracy, eval_details + + +def extract_answer(response_text: str) -> Optional[str]: + """Extract the equation from the string.""" + if not isinstance(response_text, str): + return None + + # Extract '\boxed{...}' + if "boxed" in response_text: + ans_part = response_text.split("boxed")[-1] + if not ans_part: + return None + + if ans_part.startswith("{"): + stack = 1 + extracted_ans = "" + for char in ans_part[1:]: + if char == "{": + stack += 1 + extracted_ans += char + elif char == "}": + stack -= 1 + if stack == 0: + break + extracted_ans += char + else: + extracted_ans += char + + if stack == 0: + return strip_string(extracted_ans) + + match = re.search(r"\{?([^$}]+)\}?", ans_part) + if match: + return strip_string(match.group(1)) + + # Extract 'answer is ...' + search_patterns = [r"(?:final|the)\s+answer\s+is\s*:?\s*(.+)", r"答案是\s*:?\s*(.+)"] + for pattern in search_patterns: + match = re.search(pattern, response_text, re.IGNORECASE) + if match: + pred = strip_string(match.group(1)) + if pred and pred.endswith("."): + pred = pred[:-1] + return pred + + # Extract the last number + text_no_commas = response_text.replace(",", "") + numeric_finds = re.findall(r"[-+]?\d*\.?\d+", text_no_commas) + if numeric_finds: + return numeric_finds[-1] + + return None + + +# units mainly from MathQA +unit_texts = [ + "east", + "degree", + "mph", + "kmph", + "ft", + "m sqaure", + " m east", + "sq m", + "deg", + "mile", + "q .", + "monkey", + "prime", + "ratio", + "profit of rs", + "rd", + "o", + "gm", + "p . m", + "lb", + "tile", + "per", + "dm", + "lt", + "gain", + "ab", + "way", + "west", + "a .", + "b .", + "c .", + "d .", + "e .", + "f .", + "g .", + "h .", + "t", + "a", + "h", + "no change", + "men", + "soldier", + "pie", + "bc", + "excess", + "st", + "inches", + "noon", + "percent", + "by", + "gal", + "kmh", + "c", + "acre", + "rise", + "a . m", + "th", + "π r 2", + "sq", + "mark", + "l", + "toy", + "coin", + "sq . m", + "gallon", + "° f", + "profit", + "minw", + "yr", + "women", + "feet", + "am", + "pm", + "hr", + "cu cm", + "square", + "v â € ™", + "are", + "rupee", + "rounds", + "cubic", + "cc", + "mtr", + "s", + "ohm", + "number", + "kmph", + "day", + "hour", + "minute", + "min", + "second", + "man", + "woman", + "sec", + "cube", + "mt", + "sq inch", + "mp", + "∏ cm ³", + "hectare", + "more", + "sec", + "unit", + "cu . 
m", + "cm 2", + "rs .", + "rs", + "kg", + "g", + "month", + "km", + "m", + "cm", + "mm", + "apple", + "liter", + "loss", + "yard", + "pure", + "year", + "increase", + "decrease", + "d", + "less", + "Surface", + "litre", + "pi sq m", + "s .", + "metre", + "meter", + "inch", +] + +unit_texts.extend([t + "s" for t in unit_texts if not t.endswith("s")]) + + +def strip_string(input_str: Optional[str]) -> Optional[str]: + """Clean and normalize math answer strings.""" + if input_str is None: + return None + + string = str(input_str).strip() + + # Basic cleaning and formatting + string = string.replace("\n", "") + string = string.rstrip(".") + string = string.replace("\\!", "") + string = string.replace("\\$", "").replace("$", "") + string = string.replace("%", "").replace("\\%", "") + + # Normalization of LaTeX format + string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string) + string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string) + string = string.replace("bmatrix", "pmatrix") + string = string.replace("tfrac", "frac").replace("dfrac", "frac") + string = string.replace("\\neq", "\\ne").replace("\\leq", "\\le").replace("\\geq", "\\ge") + string = string.replace("\\left", "").replace("\\right", "") + string = string.replace("\\{", "{").replace("\\}", "}") + string = re.sub(r"\\mbox{.*?}", "", string) + string = string.replace("\\mathbf", "") + + string = string.replace("^{\\circ}", "").replace("^\\circ", "") + + # Remove text and units + string = re.sub(r"\\text\{(.*?)\}", r"\1", string) + + for _ in range(2): + for unit_text in unit_texts: + _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string) + if _string != "": + string = _string + + # Clean numerical values + try: + string = str(w2n.word_to_num(string)) + except ValueError: + pass + + string = re.sub(r"^[a-zA-Z]\s*=\s*", "", string) + + string = string.replace("infinity", "\\infty").replace("inf", "\\infty") + + string = re.sub(r"(\d+)\.0+([^\d]|$)", r"\1\2", string) + + string = string.replace(" .", " 0.").replace("{.", "{0.") + if string.startswith("."): + string = "0" + string + + # Fix the structure and final cleanup + string = string.replace(" ", "") + + string = fix_sqrt(string) + string = fix_fracs(string) + string = fix_a_slash_b(string) + + if (string.startswith("{") and string.endswith("}")) or ( + string.startswith("(") and string.endswith(")") + ): + string = string[1:-1] + + if not string: + return None + + return string.strip() + + +def fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if len(substr) > 0 and substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except AssertionError: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + if "sqrt" not in a: + a = int(a) + if "sqrt" not in b: + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except Exception: + return string + + +def 
+def fix_sqrt(string):
+    if "\\sqrt" not in string:
+        return string
+    pattern = r"\\sqrt\s*(\w+)"
+    replacement = r"\\sqrt{\1}"
+
+    return re.sub(pattern, replacement, string)
+
+
+def convert_word_number(text: str) -> str:
+    try:
+        text = str(w2n.word_to_num(text))
+    except Exception:
+        pass
+    return text
+
+
+def _compare_numerical(pred: str, ref: str) -> Optional[bool]:
+    """Helper for numerical comparison in math_equal."""
+    try:
+        if _is_digit(pred) and _is_digit(ref):
+            pred_num = _parse_digits(pred)
+            ref_num = _parse_digits(ref)
+            # Allow for percentage conversions
+            possible_ref_values = {ref_num, ref_num / 100, ref_num * 100}
+            for val in possible_ref_values:
+                if numeric_equal(pred_num, val):
+                    return True
+            return False
+    except (ValueError, TypeError):
+        pass
+    return None
+
+
+def _compare_structures(pred: str, ref: str) -> Optional[bool]:
+    """Helper for structural comparison (intervals, matrices) in math_equal."""
+    is_pred_interval = pred.startswith(("(", "[")) and pred.endswith((")", "]"))
+    is_ref_interval = ref.startswith(("(", "[")) and ref.endswith((")", "]"))
+    if is_pred_interval and is_ref_interval:
+        pred_parts = pred[1:-1].split(",")
+        ref_parts = ref[1:-1].split(",")
+        if len(pred_parts) == len(ref_parts) and all(
+            math_equal(p.strip(), r.strip()) for p, r in zip(pred_parts, ref_parts)
+        ):
+            return True
+
+    is_pred_matrix = pred.startswith("\\begin{pmatrix}")
+    is_ref_matrix = ref.startswith("\\begin{pmatrix}")
+    if is_pred_matrix and is_ref_matrix:
+        pred_mat_str = pred[len("\\begin{pmatrix}") : -len("\\end{pmatrix}")]
+        ref_mat_str = ref[len("\\begin{pmatrix}") : -len("\\end{pmatrix}")]
+        pred_rows = [row.split("&") for row in pred_mat_str.split("\\\\")]
+        ref_rows = [row.split("&") for row in ref_mat_str.split("\\\\")]
+        if (
+            len(pred_rows) == len(ref_rows)
+            and len(pred_rows[0]) == len(ref_rows[0])
+            and all(
+                math_equal(p_elem.strip(), r_elem.strip())
+                for p_row, r_row in zip(pred_rows, ref_rows)
+                for p_elem, r_elem in zip(p_row, r_row)
+            )
+        ):
+            return True
+    return None
+
+
+def _compare_equations(pred: str, ref: str) -> Optional[bool]:
+    """Helper for equation comparison in math_equal."""
+    if pred.count("=") == 1 and ref.count("=") == 1:
+        pred_lhs, pred_rhs = (p.strip() for p in pred.split("="))
+        ref_lhs, ref_rhs = (r.strip() for r in ref.split("="))
+        if symbolic_equal(f"({pred_lhs})-({pred_rhs})", f"({ref_lhs})-({ref_rhs})"):
+            return True
+
+    if pred.count("=") == 1 and ref.count("=") == 0:
+        var, val = (p.strip() for p in pred.split("="))
+        if len(var) <= 2 and math_equal(val, ref):
+            return True
+
+    if ref.count("=") == 1 and pred.count("=") == 0:
+        var, val = (r.strip() for r in ref.split("="))
+        if len(var) <= 2 and math_equal(pred, val):
+            return True
+    return None
+
+
+def math_equal(prediction: Optional[str], reference: Optional[str]) -> bool:
+    """Checks the mathematical equality of two strings by trying different methods."""
+    if prediction is None or reference is None:
+        return False
+    if prediction == reference:
+        return True
+
+    comparisons = [
+        _compare_numerical,
+        _compare_structures,
+        _compare_equations,
+    ]
+    for func in comparisons:
+        result = func(prediction, reference)
+        if result is not None:
+            return result
+
+    return symbolic_equal(prediction, reference)
+
+
+def numeric_equal(prediction: float, reference: float) -> bool:
+    return isclose(reference, prediction, rel_tol=1e-4)
+
+
+def _is_digit(s: str) -> bool:
+    try:
+        float(s.replace(",", ""))
+        return True
+    except (ValueError, TypeError):
+        return False
+
+
+def _parse_digits(s: str) -> float:
+    return float(s.replace(",", ""))
+
+
+def _parse_symbolic(s: str) -> Any:
+    """Parse a string into a sympy expression, trying different methods."""
+    s_cleaned = s.replace("\\\\", "\\")
+    try:
+        return parse_latex(s_cleaned)
+    except Exception:  # Broad exception is okay here as we have fallbacks
+        try:
+            # `locals` lets sympify resolve names like exp to sympy functions
+            return sympify(s_cleaned, locals={"exp": sympy.exp}, evaluate=True)
+        except (sympy.SympifyError, TypeError, SyntaxError):
+            return s  # Return original string if all parsing fails
+
+
+def _check_direct_or_simplified_equality(a_sym: Any, b_sym: Any) -> Optional[bool]:
+    """Check for direct or simplified symbolic equality."""
+    try:
+        if a_sym == b_sym:
+            return True
+    except (TypeError, ValueError):
+        pass
+    try:
+        if simplify(a_sym - b_sym) == 0:
+            return True
+    except (AttributeError, TypeError, ValueError, sympy.SympifyError):
+        pass
+    return None
+
+
+def _check_equation_equality(a_sym: Any, b_sym: Any) -> Optional[bool]:
+    """Check for symbolic equality of two equations."""
+    try:
+        if isinstance(a_sym, sympy.Eq) and isinstance(b_sym, sympy.Eq):
+            if (a_sym.lhs - a_sym.rhs).equals(b_sym.lhs - b_sym.rhs):
+                return True
+    except (AttributeError, TypeError):
+        pass
+    return None
+
+
+def _check_numeric_value_equality(a_sym: Any, b_sym: Any) -> Optional[bool]:
+    """Check for equality of the numerical values of two symbolic expressions."""
+    try:
+        if numeric_equal(float(N(a_sym)), float(N(b_sym))):
+            return True
+    except (TypeError, ValueError, AttributeError):
+        pass
+    return None
+
+
+def _check_matrix_equality(a_sym: Any, b_sym: Any) -> Optional[bool]:
+    """Check for symbolic equality of two matrices."""
+    try:
+        if isinstance(a_sym, Matrix) and isinstance(b_sym, Matrix):
+            if a_sym.shape == b_sym.shape and simplify(a_sym - b_sym).is_zero_matrix:
+                return True
+    except (AttributeError, TypeError):
+        pass
+    return None
+
+
+def symbolic_equal(a: str, b: str) -> bool:
+    """Compares two strings for symbolic equivalence using sympy."""
+    a_sym = _parse_symbolic(a)
+    b_sym = _parse_symbolic(b)
+
+    equality_checks = [
+        _check_direct_or_simplified_equality,
+        _check_equation_equality,
+        _check_numeric_value_equality,
+        _check_matrix_equality,
+    ]
+
+    for check_func in equality_checks:
+        result = check_func(a_sym, b_sym)
+        if result is True:
+            return True
+
+    return False
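As a quick sanity check, the utilities compose as in the snippet below (not part of the patch); it simply replays a few cases from `tests/utils/eval_utils_test.py`. Note that symbolic comparison uses sympy's LaTeX parser inside `_parse_symbolic`, which needs the optional antlr4 runtime; without it, the code silently falls back to `sympify`.

```python
from trinity.utils.math_eval_utils import extract_answer, math_equal, verify_math_answer

# Boxed extraction returns the normalized answer string.
assert extract_answer("The result is \\boxed{\\frac{1}{2}}") == "\\frac{1}{2}"

# Units and thousands separators are stripped before comparison.
accuracy, details = verify_math_answer("The speed is 50 km/h", "50")
assert accuracy == 1.0 and details["is_correct"]

# A bare value matches "x=5" via the equation/value equivalence rule.
assert math_equal("5", "x=5")

# Structurally different matrices are rejected.
accuracy, _ = verify_math_answer(
    "The matrix is \\boxed{\\begin{pmatrix}1 & 1 \\\\ 0 & 1\\end{pmatrix}}",
    "\\begin{pmatrix}1&0\\\\0&1\\end{pmatrix}",
)
assert accuracy == 0.0
```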