diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml index 5d3b16c2cc..b660c2eaaa 100644 --- a/examples/grpo_math/math.yaml +++ b/examples/grpo_math/math.yaml @@ -23,10 +23,12 @@ buffer: prompt_key: 'question' response_key: 'gt_answer' rollout_args: - n: 8 temperature: 1.0 logprobs: 0 - default_workflow_type: 'math_workflow' + reward_fn_args: + reward_name: math_verify_reward + default_workflow_type: 'math_rm_workflow' + default_reward_fn_type: 'rm_gallery_reward' trainer_input: experience_buffer: name: math_buffer diff --git a/pyproject.toml b/pyproject.toml index 654e54827c..61059b8c4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,9 @@ data = [ agent = [ "agentscope" ] +rm_gallery = [ + "rm-gallery" +] dev = [ "pre-commit>=2.17.0", "black>=23.7.0", diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index da4bc3b54d..944d9573a0 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -2,10 +2,17 @@ """Test for the workflow module""" import unittest from dataclasses import dataclass +from typing import Dict, Optional from unittest.mock import MagicMock from tests.tools import get_unittest_dataset_config -from trinity.common.workflows import MathBoxedWorkflow, MathWorkflow, Workflow +from trinity.common.rewards import RMGalleryFn +from trinity.common.workflows import ( + MathBoxedWorkflow, + MathRMWorkflow, + MathWorkflow, + Workflow, +) from trinity.common.workflows.workflow import Task @@ -13,6 +20,8 @@ class MockResponse: response_text: str reward: float = 0.0 + metrics: Optional[Dict[str, float]] = None + info: Optional[Dict] = None class DummyWorkflow(Workflow): @@ -206,7 +215,6 @@ def test_gsm8k_workflow(self) -> None: ) workflow = task.to_workflow(model=model) experiences = workflow.run() - # self.assertEqual(len(experiences), 1) self.assertEqual(experiences[0].reward, 1.1) self.assertEqual(experiences[1].reward, 0.9) self.assertEqual(experiences[2].reward, 0.9) @@ -229,6 +237,37 @@ def test_gsm8k_workflow(self) -> None: self.assertEqual(experiences[2].reward, -0.1) self.assertEqual(experiences[3].reward, 1.1) + @unittest.skip("Skip for now, need to fix import issues of RM-Gallery") + def test_rm_gallery_workflow(self) -> None: + model = MagicMock() + model.chat.return_value = [ + MockResponse(" balabalabala 99 \n \\boxed{36}"), + MockResponse("answer is \\boxed{36 }"), + MockResponse("Kim's total points are 6 + 30 =\\boxed{36}"), + MockResponse(" balalaba \\boxed{35.00}"), + ] + taskset_config = get_unittest_dataset_config("countdown") + task = Task( + workflow=MathRMWorkflow, + reward_fn=RMGalleryFn, + format_args=taskset_config.format, + rollout_args=taskset_config.rollout_args, + reward_fn_args={ + "reward_name": "math_verify_reward", + }, + is_eval=False, + raw_task={ + taskset_config.format.prompt_key: "", + taskset_config.format.response_key: r"36", + }, + ) + workflow = task.to_workflow(model=model) + experiences = workflow.run() + self.assertEqual(experiences[0].reward, 1.0) + self.assertEqual(experiences[1].reward, 1.0) + self.assertEqual(experiences[2].reward, 1.0) + self.assertEqual(experiences[3].reward, 0.0) + def test_workflow_resettable(self) -> None: model = MagicMock() json_task = Task( diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 91ca4bc030..bc49b871d3 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -294,6 +294,7 @@ def read( format_args=self.meta.format, 
rollout_args=self.meta.rollout_args, workflow_args=self.meta.workflow_args, + reward_fn_args=self.meta.reward_fn_args, is_eval=self.meta.task_type == TaskType.EVAL, reward_fn=reward_fn, raw_task=sample, diff --git a/trinity/common/config.py b/trinity/common/config.py index 87280cdbc8..1e0bcc5e9d 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -99,6 +99,7 @@ class StorageConfig: default_reward_fn_type: Optional[str] = None rollout_args: GenerationConfig = field(default_factory=GenerationConfig) workflow_args: dict = field(default_factory=dict) + reward_fn_args: dict = field(default_factory=dict) # get storage from existing experiment ray_namespace: Optional[str] = None diff --git a/trinity/common/rewards/__init__.py b/trinity/common/rewards/__init__.py index f20abc2748..05b752dfc6 100644 --- a/trinity/common/rewards/__init__.py +++ b/trinity/common/rewards/__init__.py @@ -1,11 +1,23 @@ # -*- coding: utf-8 -*- """Reward functions for RFT""" -from .reward_fn import REWARD_FUNCTIONS, AccuracyReward, FormatReward, RewardFn +# isort: off +from .reward_fn import REWARD_FUNCTIONS, RewardFn, RMGalleryFn + +from .accuracy_reward import AccuracyReward +from .countdown_reward import CountDownRewardFn +from .format_reward import FormatReward +from .math_reward import MathBoxedRewardFn, MathRewardFn + +# isort: on __all__ = [ "RewardFn", + "RMGalleryFn", "REWARD_FUNCTIONS", "AccuracyReward", + "CountDownRewardFn", "FormatReward", + "MathRewardFn", + "MathBoxedRewardFn", ] diff --git a/trinity/common/rewards/accuracy_reward.py b/trinity/common/rewards/accuracy_reward.py index 76534b8973..98132030dd 100644 --- a/trinity/common/rewards/accuracy_reward.py +++ b/trinity/common/rewards/accuracy_reward.py @@ -1,33 +1,68 @@ -from typing import Any, Callable, Dict, List +# -*- coding: utf-8 -*- +"""Accuracy Reward Function Class.""" +from typing import Callable, Optional -from .base import RewardShapper +from latex2sympy2_extended import NormalizationConfig +from math_verify import LatexExtractionConfig, parse, verify +from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn +from trinity.utils.log import get_logger -class AccuracyRewardShapper(RewardShapper): - """Shapper for accuracy-based rewards""" +logger = get_logger(__name__) - def __init__( - self, - answer_parser: Callable[[str], str], - correct_reward: float = 1.0, - incorrect_reward: float = 0.0, - kwargs: Dict[str, Any] = {}, - ): + +@REWARD_FUNCTIONS.register_module("accuracy_reward") +class AccuracyReward(RewardFn): + """A reward function that rewards correct answers. 
+ Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py + """ + + def __init__(self, answer_parser: Optional[Callable[[str], str]] = None): self.answer_parser = answer_parser - self.correct_reward = correct_reward - self.incorrect_reward = incorrect_reward - self.response_key = kwargs.get("response", "response") - self.truth_key = kwargs.get("ground_truth", "ground_truth") - def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]: - response = sample[self.response_key] - truth = sample[self.truth_key] + def __call__( # type: ignore + self, + response: str, + prompt: Optional[str] = None, + truth: Optional[str] = None, + ) -> dict[str, float]: + if self.answer_parser: + answer_parsed = self.answer_parser(response) + truth_parsed = self.answer_parser(truth) # type: ignore [arg-type] - parsed_response = self.answer_parser(response) - reward = self.correct_reward if parsed_response == truth else self.incorrect_reward + else: + truth_parsed = parse( + truth, + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + if len(truth_parsed) == 0: + truth_parsed = truth - sample["accuracy_reward"] = reward - return sample + answer_parsed = parse( + response, + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed="all", + units=True, + ), + # Ensures that boxed is tried first + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) - def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - return [self.shape(sample) for sample in samples] + # Reward 1 if the content is the same as the ground truth, 0 otherwise + try: + reward = float(verify(answer_parsed, truth_parsed)) + except Exception as e: + logger.info(f"verify failed: {e}, answer: {answer_parsed}, gold: {truth_parsed}") + reward = 0.0 + return {"accuracy": reward} diff --git a/trinity/common/rewards/base.py b/trinity/common/rewards/base.py deleted file mode 100644 index 00efe88d23..0000000000 --- a/trinity/common/rewards/base.py +++ /dev/null @@ -1,24 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, List - - -class RewardShapper(ABC): - """Abstract base class for reward shapper - - Supports: - 1. Rule-based shaping - 2. Model-based shaping - 3. Tool-based shaping - 4. Agent-based shaping - 5. 
Human-in-the-loop shaping - """ - - @abstractmethod - def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Shape a sample with rewards""" - pass - - @abstractmethod - def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Shape a batch of samples""" - pass diff --git a/trinity/common/rewards/composite_reward.py b/trinity/common/rewards/composite_reward.py deleted file mode 100644 index 0d2c8375e8..0000000000 --- a/trinity/common/rewards/composite_reward.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Any, Dict, List, Tuple - -from .base import RewardShapper - - -class CompositeRewardShapper(RewardShapper): - """Combines multiple shappers with weights""" - - def __init__(self, shappers: List[Tuple[RewardShapper, float]]): - self.shappers = shappers - - def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]: - total_reward = 0.0 - shapped_sample = sample.copy() - - for shapper, weight in self.shappers: - shapeged = shapper.shape(sample) - for key, value in shapeged.items(): - if key.endswith("_reward"): - shapped_sample[key] = value - total_reward += value * weight - - shapped_sample["total_reward"] = total_reward - return shapped_sample diff --git a/trinity/common/rewards/countdown_reward.py b/trinity/common/rewards/countdown_reward.py new file mode 100644 index 0000000000..07417431ad --- /dev/null +++ b/trinity/common/rewards/countdown_reward.py @@ -0,0 +1,58 @@ +"""Countdown Reward Function Class.""" +import json +from typing import Optional + +from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn +from trinity.utils.eval_utils import ( + evaluate_equation, + extract_solution, + validate_equation, +) +from trinity.utils.log import get_logger + +logger = get_logger(__name__) + + +@REWARD_FUNCTIONS.register_module("countdown_reward") +class CountDownRewardFn(RewardFn): + """A reward function for the countdown task. 
+ Ref: Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py + """ + + def __init__(self): + pass + + def __call__( # type: ignore + self, + response: str, + prompt: Optional[str] = None, + truth: Optional[str] = None, + ) -> dict[str, float]: + truth = json.loads(truth) # type: ignore + target = truth["target"] # type: ignore + numbers = truth["numbers"] # type: ignore + + solution_str = response + equation = extract_solution(solution_str=solution_str) + format_score = 0.1 + score = 1.0 + + if equation is None: + return {"score": 0} + + # Validate equation uses correct numbers + if not validate_equation(equation, numbers): + return {"score": format_score} + + # Evaluate equation + try: + result = evaluate_equation(equation) + if result is None: + return {"score": format_score} + + if abs(result - target) < 1e-5: # Account for floating point precision + return {"score": score} + else: + return {"score": format_score} + except Exception as e: # noqa: F841 + return {"score": format_score} diff --git a/trinity/common/rewards/format_reward.py b/trinity/common/rewards/format_reward.py index e464103f76..0ffe8637ec 100644 --- a/trinity/common/rewards/format_reward.py +++ b/trinity/common/rewards/format_reward.py @@ -1,29 +1,28 @@ -import re -from typing import Any, Dict, List +"""Format Reward Function Class.""" -from .base import RewardShapper +import re +from typing import Optional +from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn +from trinity.utils.log import get_logger -class FormatRewardShapper(RewardShapper): - """Shapper for format-based rewards""" +logger = get_logger(__name__) - def __init__( - self, pattern: str, correct_format_reward: float = 1.0, incorrect_format_reward: float = 0.0 - ): - self.pattern = re.compile(pattern, re.DOTALL | re.MULTILINE) - self.correct_format_reward = correct_format_reward - self.incorrect_format_reward = incorrect_format_reward - def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]: - response = sample["response"] - reward = ( - self.correct_format_reward - if self.pattern.match(response) - else self.incorrect_format_reward - ) +@REWARD_FUNCTIONS.register_module("format_reward") +class FormatReward(RewardFn): + """A reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags. 
+ Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py + """ - sample["format_reward"] = reward - return sample + def __init__(self, pattern: Optional[str] = None): + self.pattern = pattern if pattern else r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$" - def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - return [self.shape(sample) for sample in samples] + def __call__( # type: ignore + self, + response, + ) -> dict[str, float]: + if re.match(self.pattern, response, re.DOTALL | re.MULTILINE): + return {"format_score": 0.1} + else: + return {"format_score": -0.1} diff --git a/trinity/common/rewards/math_reward.py b/trinity/common/rewards/math_reward.py new file mode 100644 index 0000000000..09a5cd7428 --- /dev/null +++ b/trinity/common/rewards/math_reward.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +"""Math Reward Function Class.""" +from typing import Optional + +from trinity.common.rewards.accuracy_reward import AccuracyReward +from trinity.common.rewards.format_reward import FormatReward +from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn +from trinity.utils.eval_utils import ( + compute_score, + simple_answer_parser, + validate_think_pattern, +) +from trinity.utils.log import get_logger + +logger = get_logger(__name__) + + +@REWARD_FUNCTIONS.register_module("math_reward") +class MathRewardFn(RewardFn): + """A reward function for math tasks.""" + + DEFAULT_FORMAT_PATTERN = r".*?<think>.*?</think>\s*<answer>.*?</answer>\s*$" + DEFAULT_ANSWER_PARSER = simple_answer_parser + + def __init__( + self, + answer_parser=DEFAULT_ANSWER_PARSER, + pattern=DEFAULT_FORMAT_PATTERN, + ) -> None: + self.accuracy_reward = AccuracyReward(answer_parser) + self.format_reward = FormatReward(pattern) + + def __call__( # type: ignore + self, + response: str, + prompt: Optional[str] = None, + truth: Optional[str] = None, + ) -> dict[str, float]: + accuracy_score = self.accuracy_reward(response, prompt, truth) + + format_score = self.format_reward(response) + + return {**accuracy_score, **format_score} + + +@REWARD_FUNCTIONS.register_module("math_boxed_reward") +class MathBoxedRewardFn(RewardFn): + """A reward function for math tasks with boxed answers.""" + + def __init__( + self, + ) -> None: + pass + + def __call__( # type: ignore + self, + response: str, + prompt: Optional[str] = None, + truth: Optional[str] = None, + with_think: Optional[bool] = False, + format_score_coef: Optional[float] = 0.1, + ) -> dict[str, float]: + accuracy_score = compute_score(response, truth) + + format_score = 0.0 + if with_think and not validate_think_pattern(response): + format_score = (format_score_coef or 0.1) * -1.0 + + return {"accuracy": accuracy_score, "format_score": format_score} diff --git a/trinity/common/rewards/reward_fn.py b/trinity/common/rewards/reward_fn.py index 822636482e..ee443ed085 100644 --- a/trinity/common/rewards/reward_fn.py +++ b/trinity/common/rewards/reward_fn.py @@ -1,21 +1,10 @@ # -*- coding: utf-8 -*- """Base Reward Function Class.""" -import json -import re from abc import ABC, abstractmethod -from typing import Callable, Optional, Union - -from latex2sympy2_extended import NormalizationConfig -from math_verify import LatexExtractionConfig, parse, verify - -from trinity.utils.eval_utils import ( - compute_score, - evaluate_equation, - extract_solution, - simple_answer_parser, - validate_equation, - validate_think_pattern, -) +from typing import Any, Dict, List + +from trinity.common.experience import Experience +from trinity.common.rewards.utils import to_rm_gallery_messages 
from trinity.utils.log import get_logger from trinity.utils.registry import Registry @@ -28,202 +17,91 @@ class RewardFn(ABC): """Base Reward Function Class.""" - # TODO: add a batch version - @abstractmethod - def __call__( - self, - response: str, - prompt: Optional[str] = None, - truth: Optional[str] = None, - return_dict: Optional[bool] = False, - ) -> Union[float, dict]: - """Call the reward function.""" - - -@REWARD_FUNCTIONS.register_module("accuracy_reward") -class AccuracyReward(RewardFn): - """A reward function that rewards correct answers. - Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py - """ - - def __init__(self, answer_parser: Optional[Callable[[str], str]] = None): - self.answer_parser = answer_parser - - def __call__( - self, - response: str, - prompt: Optional[str] = None, - truth: Optional[str] = None, - return_dict: Optional[bool] = False, - ): - if self.answer_parser: - answer_parsed = self.answer_parser(response) - truth_parsed = self.answer_parser(truth) # type: ignore [arg-type] - - else: - truth_parsed = parse( - truth, - extraction_mode="first_match", - extraction_config=[LatexExtractionConfig()], - ) - if len(truth_parsed) == 0: - truth_parsed = truth - - answer_parsed = parse( - response, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - equations=True, - boxed="all", - units=True, - ), - # Ensures that boxed is tried first - boxed_match_priority=0, - try_extract_without_anchor=False, - ) - ], - extraction_mode="first_match", - ) + def __init__(self, **kwargs) -> None: + pass - # Reward 1 if the content is the same as the ground truth, 0 otherwise - try: - reward = float(verify(answer_parsed, truth_parsed)) - except Exception as e: - logger.info(f"verify failed: {e}, answer: {answer_parsed}, gold: {truth_parsed}") - reward = 0.0 - return reward + @abstractmethod + def __call__(self, **kwargs) -> Dict[str, float]: + pass -@REWARD_FUNCTIONS.register_module("format_reward") -class FormatReward(RewardFn): - """A reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags. - Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py +@REWARD_FUNCTIONS.register_module("rm_gallery_reward") +class RMGalleryFn(RewardFn): + """Reward Function from RM-Gallery. 
+ https://github.com/modelscope/RM-Gallery + """ - def __init__(self, pattern: Optional[str] = None): - self.pattern = pattern if pattern else r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$" - - def __call__( - self, - response, - prompt: Optional[str] = None, - truth: Optional[str] = None, - return_dict: Optional[bool] = False, - ) -> float: - if re.match(self.pattern, response, re.DOTALL | re.MULTILINE): - return 0.1 - else: - return -0.1 - - -@REWARD_FUNCTIONS.register_module("math_reward") -class MathRewardFn(RewardFn): - """A reward function that rewards for math task.""" - - # DEFAULT_FORMAT_PATTERN = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$" - DEFAULT_FORMAT_PATTERN = r".*?<think>.*?</think>\s*<answer>.*?</answer>\s*$" - DEFAULT_ANSWER_PARSER = simple_answer_parser - def __init__( - self, - answer_parser=DEFAULT_ANSWER_PARSER, - pattern=DEFAULT_FORMAT_PATTERN, - ) -> None: - self.accuracy_reward = AccuracyReward(answer_parser) - self.format_reward = FormatReward(pattern) + def __init__( + self, + reward_name, + **kwargs, + ): + from rm_gallery.core.reward.registry import RewardRegistry - def __call__( # type: ignore - self, - response: str, - prompt: Optional[str] = None, - truth: Optional[str] = None, - return_dict: Optional[bool] = False, - ) -> Union[float, dict]: - accuracy_score = self.accuracy_reward(response, prompt, truth) + self.reward_model = RewardRegistry.get(reward_name)(**kwargs) - format_score = self.format_reward(response, prompt, truth) + def __call__(self, experience: Experience, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, float]: # type: ignore + """Call the reward function.""" - if return_dict: - return {"accuracy": accuracy_score, "format_score": format_score} + sample = self._build_sample_from_experience(experience, messages, **kwargs) - return accuracy_score + format_score + sample_with_reward = self.reward_model.evaluate(sample, **kwargs) + return self._extract_reward(sample_with_reward) -@REWARD_FUNCTIONS.register_module("countdown_reward") -class CountDownRewardFn(RewardFn): - """A reward function that rewards for countdown task.""" + def _build_sample_from_experience( + self, experience: Experience, messages: List[Dict[str, Any]], **kwargs + ) -> Any: + """Convert experience to sample. 
+ Ref: https://github.com/modelscope/RM-Gallery/blob/main/rm_gallery/core/data/schema.py + """ + from rm_gallery.core.data.schema import DataOutput, DataSample, Step - def __init__(self): - pass + output = [ + DataOutput( + answer=Step( + role="assistant", + content=str(experience.response_text), + label={"reference": kwargs.get("ground_truth", "")}, + ), + ) + ] + + sample = DataSample( + unique_id="0", # TODO: Generate unique ID + input=to_rm_gallery_messages(messages), + output=output, + metadata=experience.info, + ) + return sample + + def _extract_reward(self, sample: Any) -> Dict[str, float]: + """ + Extract reward from DataSample in rm-gallery + """ + reward_dict = {} - def __call__( - self, - response: str, - prompt: Optional[str] = None, - truth: Optional[str] = None, - return_dict: Optional[bool] = False, - ) -> float: - # Copy from Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py - truth = json.loads(truth) # type: ignore - target = truth["target"] # type: ignore - numbers = truth["numbers"] # type: ignore - - solution_str = response - equation = extract_solution(solution_str=solution_str) - format_score = 0.1 - score = 1.0 - - if equation is None: - return 0 - - # Validate equation uses correct numbers - if not validate_equation(equation, numbers): - return format_score - - # Evaluate equation try: - result = evaluate_equation(equation) - if result is None: - return format_score - - if abs(result - target) < 1e-5: # Account for floating point precision - return score - else: - return format_score - except Exception as e: # noqa: F841 - return format_score - - -@REWARD_FUNCTIONS.register_module("math_boxed_reward") -class MathBoxedRewardFn(RewardFn): - """A reward function that rewards for math task.""" - - def __init__( - self, - ) -> None: - pass + try: + reward_obj = sample.output[0].answer.reward + except Exception as e: + raise ValueError(f"No reward is found in sample: {e}") + + from rm_gallery.core.reward.schema import ( + RewardDimensionWithRank, + RewardDimensionWithScore, + ) + + if reward_obj.details: + for detail in reward_obj.details: + if isinstance(detail, RewardDimensionWithScore): + reward_dict[detail.name] = detail.score + elif isinstance(detail, RewardDimensionWithRank): + # TODO: support multi-ranked dimension + if detail: + top_ranked_item = detail[0] + reward_dict[top_ranked_item.name] = top_ranked_item.score + else: + reward_dict["reward"] = reward_obj.score - def __call__( # type: ignore - self, - response: str, - prompt: Optional[str] = None, - truth: Optional[str] = None, - return_dict: Optional[bool] = False, - with_think: Optional[bool] = False, - format_score_coef: Optional[float] = 0.1, - ) -> Union[float, dict]: - accuracy_score = compute_score(response, truth) - - format_score = 0.0 - if with_think and not validate_think_pattern(response): - format_score = (format_score_coef or 0.1) * -1.0 - - if return_dict: - return {"accuracy": accuracy_score, "format_score": format_score} - - return accuracy_score + format_score + return reward_dict diff --git a/trinity/common/rewards/utils.py b/trinity/common/rewards/utils.py new file mode 100644 index 0000000000..0b66e6700c --- /dev/null +++ b/trinity/common/rewards/utils.py @@ -0,0 +1,22 @@ +from typing import Any, Dict, List + + +def to_rm_gallery_messages(messages: List[Dict[str, Any]]) -> Any: + """ + Converts a list of message dicts to a structured ChatMessage list. 
+ + Args: + messages: List of alternating user/assistant messages + + Returns: + List of structured ChatMessage objects + """ + from rm_gallery.core.model.message import ChatMessage, MessageRole + + role_map = { + "system": MessageRole.SYSTEM, + "user": MessageRole.USER, + "assistant": MessageRole.ASSISTANT, + } + + return [ChatMessage(role=role_map[msg["role"]], content=msg["content"]) for msg in messages] diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index 9d54f108d0..496996a05d 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -4,6 +4,7 @@ from .envs.alfworld.alfworld_workflow import AlfworldWorkflow from .envs.sciworld.sciworld_workflow import SciWorldWorkflow from .envs.webshop.webshop_workflow import WebShopWorkflow +from .math_rm_workflow import MathRMWorkflow from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow __all__ = [ @@ -16,4 +17,5 @@ "AlfworldWorkflow", "SciWorldWorkflow", "MathBoxedWorkflow", + "MathRMWorkflow", ] diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py index d71a5d2fb1..1e825e4e54 100644 --- a/trinity/common/workflows/customized_math_workflows.py +++ b/trinity/common/workflows/customized_math_workflows.py @@ -5,7 +5,7 @@ from typing import List from trinity.common.experience import Experience -from trinity.common.rewards.reward_fn import MathBoxedRewardFn +from trinity.common.rewards.math_reward import MathBoxedRewardFn from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task from trinity.utils.log import get_logger @@ -73,20 +73,20 @@ def run(self) -> List[Experience]: responses = self.model.generate([prompt_text], **self.rollout_args) for response in responses: - reward = MathBoxedRewardFn()( # type: ignore [misc] + reward_dict = MathBoxedRewardFn()( # type: ignore [misc] response=response.response_text, # type: ignore [arg-type] truth=self.truth, - return_dict=self.is_eval, with_think=self.with_think, format_score_coef=self.format_score_coef, ) + + if response.metrics is None: + response.metrics = {} + response.metrics.update(reward_dict) + reward = sum(reward_dict.values()) + response.reward = reward + logger.debug( f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" ) - if isinstance(reward, dict): - if response.metrics is None: - response.metrics = {} - response.metrics.update(reward) - reward = sum(reward.values()) - response.reward = reward return responses diff --git a/trinity/common/workflows/math_rm_workflow.py b/trinity/common/workflows/math_rm_workflow.py new file mode 100644 index 0000000000..45940fdfae --- /dev/null +++ b/trinity/common/workflows/math_rm_workflow.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +"""We include the math workflow with rm-gallery reward in this file.""" + +from typing import List, Optional + +import openai + +from trinity.common.experience import Experience +from trinity.common.models.model import ModelWrapper +from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task +from trinity.utils.log import get_logger + +logger = get_logger(__name__) + + +@WORKFLOWS.register_module("math_rm_workflow") +class MathRMWorkflow(SimpleWorkflow): + """A workflow for math tasks as introduced in DeepSeek-R1.""" + + def __init__( + self, + model: ModelWrapper, + task: Task, + auxiliary_models: Optional[List[openai.OpenAI]] = None, + ): + self.reset(task) + 
super().__init__( + model=model, + task=task, + auxiliary_models=auxiliary_models, + ) + + def run(self) -> List[Experience]: + messages = self.format_messages() + + logger.debug("start chat") + responses = self.model.chat(messages, **self.rollout_args) + for response in responses: + reward_dict = self.reward_fn( # type: ignore + response, + messages, + ground_truth=self.truth, + ) + + if response.metrics is None: + response.metrics = {} + response.metrics.update(reward_dict) + reward = sum(reward_dict.values()) + response.reward = reward + + logger.debug( + f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" + ) + return responses diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 2bd0038435..0a2483788b 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -13,7 +13,8 @@ from trinity.common.config import FormatConfig, GenerationConfig from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper -from trinity.common.rewards.reward_fn import MathRewardFn, RewardFn +from trinity.common.rewards.math_reward import MathRewardFn +from trinity.common.rewards.reward_fn import RewardFn from trinity.utils.log import get_logger from trinity.utils.registry import Registry @@ -31,6 +32,7 @@ class Task: format_args: FormatConfig = field(default_factory=FormatConfig) rollout_args: GenerationConfig = field(default_factory=GenerationConfig) workflow_args: dict = field(default_factory=dict) + reward_fn_args: dict = field(default_factory=dict) is_eval: bool = False reward_fn: Optional[Type[RewardFn]] = None raw_task: Optional[dict] = None # The raw data sample @@ -178,6 +180,7 @@ def reset(self, task: Task): self.format_args = task.format_args self.system_prompt = task.format_args.system_prompt self.reply_prefix = task.format_args.reply_prefix + self.reward_fn_args = task.reward_fn_args self.raw_task = task.raw_task self.task_desc = task.task_desc @@ -185,7 +188,7 @@ def reset(self, task: Task): reward_fn = task.reward_fn if isinstance(reward_fn, type) and issubclass(reward_fn, RewardFn): - self.reward_fn: RewardFn = reward_fn() + self.reward_fn: RewardFn = reward_fn(**self.reward_fn_args) else: raise ValueError("`reward_fn` must be a subclass of `RewardFn`") # Rollout args @@ -209,20 +212,20 @@ def run(self) -> List[Experience]: logger.debug("start chat") responses = self.model.chat(messages, **self.rollout_args) for response in responses: - reward = self.reward_fn( # type: ignore [misc] + reward_dict = self.reward_fn( # type: ignore [misc] response=response.response_text, # type: ignore [arg-type] truth=self.truth, - return_dict=self.is_eval, ) + + if response.metrics is None: + response.metrics = {} + response.metrics.update(reward_dict) + reward = sum(reward_dict.values()) + response.reward = reward + logger.debug( f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" ) - if isinstance(reward, dict): - if response.metrics is None: - response.metrics = {} - response.metrics.update(reward) - reward = sum(reward.values()) - response.reward = reward return responses