diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml
index 5d3b16c2cc..b660c2eaaa 100644
--- a/examples/grpo_math/math.yaml
+++ b/examples/grpo_math/math.yaml
@@ -23,10 +23,12 @@ buffer:
prompt_key: 'question'
response_key: 'gt_answer'
rollout_args:
- n: 8
temperature: 1.0
logprobs: 0
- default_workflow_type: 'math_workflow'
+ reward_fn_args:
+ reward_name: math_verify_reward
+ default_workflow_type: 'math_rm_workflow'
+ default_reward_fn_type: 'rm_gallery_reward'
trainer_input:
experience_buffer:
name: math_buffer
diff --git a/pyproject.toml b/pyproject.toml
index 654e54827c..61059b8c4e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,6 +53,9 @@ data = [
agent = [
"agentscope"
]
+rm_gallery = [
+ "rm-gallery"
+]
dev = [
"pre-commit>=2.17.0",
"black>=23.7.0",
diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index da4bc3b54d..944d9573a0 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -2,10 +2,17 @@
"""Test for the workflow module"""
import unittest
from dataclasses import dataclass
+from typing import Dict, Optional
from unittest.mock import MagicMock
from tests.tools import get_unittest_dataset_config
-from trinity.common.workflows import MathBoxedWorkflow, MathWorkflow, Workflow
+from trinity.common.rewards import RMGalleryFn
+from trinity.common.workflows import (
+ MathBoxedWorkflow,
+ MathRMWorkflow,
+ MathWorkflow,
+ Workflow,
+)
from trinity.common.workflows.workflow import Task
@@ -13,6 +20,8 @@
class MockResponse:
response_text: str
reward: float = 0.0
+ metrics: Optional[Dict[str, float]] = None
+ info: Optional[Dict] = None
class DummyWorkflow(Workflow):
@@ -206,7 +215,6 @@ def test_gsm8k_workflow(self) -> None:
)
workflow = task.to_workflow(model=model)
experiences = workflow.run()
- # self.assertEqual(len(experiences), 1)
self.assertEqual(experiences[0].reward, 1.1)
self.assertEqual(experiences[1].reward, 0.9)
self.assertEqual(experiences[2].reward, 0.9)
@@ -229,6 +237,37 @@ def test_gsm8k_workflow(self) -> None:
self.assertEqual(experiences[2].reward, -0.1)
self.assertEqual(experiences[3].reward, 1.1)
+    @unittest.skip("Skipped for now: RM-Gallery import issues need to be fixed")
+ def test_rm_gallery_workflow(self) -> None:
+ model = MagicMock()
+ model.chat.return_value = [
+ MockResponse(" balabalabala 99 \n \\boxed{36}"),
+ MockResponse("answer is \\boxed{36 }"),
+ MockResponse("Kim's total points are 6 + 30 =\\boxed{36}"),
+ MockResponse(" balalaba \\boxed{35.00}"),
+ ]
+ taskset_config = get_unittest_dataset_config("countdown")
+ task = Task(
+ workflow=MathRMWorkflow,
+ reward_fn=RMGalleryFn,
+ format_args=taskset_config.format,
+ rollout_args=taskset_config.rollout_args,
+ reward_fn_args={
+ "reward_name": "math_verify_reward",
+ },
+ is_eval=False,
+ raw_task={
+ taskset_config.format.prompt_key: "",
+ taskset_config.format.response_key: r"36",
+ },
+ )
+ workflow = task.to_workflow(model=model)
+ experiences = workflow.run()
+ self.assertEqual(experiences[0].reward, 1.0)
+ self.assertEqual(experiences[1].reward, 1.0)
+ self.assertEqual(experiences[2].reward, 1.0)
+ self.assertEqual(experiences[3].reward, 0.0)
+
def test_workflow_resettable(self) -> None:
model = MagicMock()
json_task = Task(
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index 91ca4bc030..bc49b871d3 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -294,6 +294,7 @@ def read(
format_args=self.meta.format,
rollout_args=self.meta.rollout_args,
workflow_args=self.meta.workflow_args,
+ reward_fn_args=self.meta.reward_fn_args,
is_eval=self.meta.task_type == TaskType.EVAL,
reward_fn=reward_fn,
raw_task=sample,
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 87280cdbc8..1e0bcc5e9d 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -99,6 +99,7 @@ class StorageConfig:
default_reward_fn_type: Optional[str] = None
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
workflow_args: dict = field(default_factory=dict)
+ reward_fn_args: dict = field(default_factory=dict)
# get storage from existing experiment
ray_namespace: Optional[str] = None
diff --git a/trinity/common/rewards/__init__.py b/trinity/common/rewards/__init__.py
index f20abc2748..05b752dfc6 100644
--- a/trinity/common/rewards/__init__.py
+++ b/trinity/common/rewards/__init__.py
@@ -1,11 +1,23 @@
# -*- coding: utf-8 -*-
"""Reward functions for RFT"""
-from .reward_fn import REWARD_FUNCTIONS, AccuracyReward, FormatReward, RewardFn
+# isort: off
+from .reward_fn import REWARD_FUNCTIONS, RewardFn, RMGalleryFn
+
+from .accuracy_reward import AccuracyReward
+from .countdown_reward import CountDownRewardFn
+from .format_reward import FormatReward
+from .math_reward import MathBoxedRewardFn, MathRewardFn
+
+# isort: on
__all__ = [
"RewardFn",
+ "RMGalleryFn",
"REWARD_FUNCTIONS",
"AccuracyReward",
+ "CountDownRewardFn",
"FormatReward",
+ "MathRewardFn",
+ "MathBoxedRewardFn",
]
diff --git a/trinity/common/rewards/accuracy_reward.py b/trinity/common/rewards/accuracy_reward.py
index 76534b8973..98132030dd 100644
--- a/trinity/common/rewards/accuracy_reward.py
+++ b/trinity/common/rewards/accuracy_reward.py
@@ -1,33 +1,68 @@
-from typing import Any, Callable, Dict, List
+# -*- coding: utf-8 -*-
+"""Accuracy Reward Function Class."""
+from typing import Callable, Optional
-from .base import RewardShapper
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import LatexExtractionConfig, parse, verify
+from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
+from trinity.utils.log import get_logger
-class AccuracyRewardShapper(RewardShapper):
- """Shapper for accuracy-based rewards"""
+logger = get_logger(__name__)
- def __init__(
- self,
- answer_parser: Callable[[str], str],
- correct_reward: float = 1.0,
- incorrect_reward: float = 0.0,
- kwargs: Dict[str, Any] = {},
- ):
+
+@REWARD_FUNCTIONS.register_module("accuracy_reward")
+class AccuracyReward(RewardFn):
+ """A reward function that rewards correct answers.
+ Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
+ """
+
+ def __init__(self, answer_parser: Optional[Callable[[str], str]] = None):
self.answer_parser = answer_parser
- self.correct_reward = correct_reward
- self.incorrect_reward = incorrect_reward
- self.response_key = kwargs.get("response", "response")
- self.truth_key = kwargs.get("ground_truth", "ground_truth")
- def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
- response = sample[self.response_key]
- truth = sample[self.truth_key]
+ def __call__( # type: ignore
+ self,
+ response: str,
+ prompt: Optional[str] = None,
+ truth: Optional[str] = None,
+ ) -> dict[str, float]:
+ if self.answer_parser:
+ answer_parsed = self.answer_parser(response)
+ truth_parsed = self.answer_parser(truth) # type: ignore [arg-type]
- parsed_response = self.answer_parser(response)
- reward = self.correct_reward if parsed_response == truth else self.incorrect_reward
+ else:
+ truth_parsed = parse(
+ truth,
+ extraction_mode="first_match",
+ extraction_config=[LatexExtractionConfig()],
+ )
+ if len(truth_parsed) == 0:
+ truth_parsed = truth
- sample["accuracy_reward"] = reward
- return sample
+ answer_parsed = parse(
+ response,
+ extraction_config=[
+ LatexExtractionConfig(
+ normalization_config=NormalizationConfig(
+ nits=False,
+ malformed_operators=False,
+ basic_latex=True,
+ equations=True,
+ boxed="all",
+ units=True,
+ ),
+ # Ensures that boxed is tried first
+ boxed_match_priority=0,
+ try_extract_without_anchor=False,
+ )
+ ],
+ extraction_mode="first_match",
+ )
- def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- return [self.shape(sample) for sample in samples]
+ # Reward 1 if the content is the same as the ground truth, 0 otherwise
+ try:
+ reward = float(verify(answer_parsed, truth_parsed))
+ except Exception as e:
+ logger.info(f"verify failed: {e}, answer: {answer_parsed}, gold: {truth_parsed}")
+ reward = 0.0
+ return {"accuracy": reward}
diff --git a/trinity/common/rewards/base.py b/trinity/common/rewards/base.py
deleted file mode 100644
index 00efe88d23..0000000000
--- a/trinity/common/rewards/base.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Any, Dict, List
-
-
-class RewardShapper(ABC):
- """Abstract base class for reward shapper
-
- Supports:
- 1. Rule-based shaping
- 2. Model-based shaping
- 3. Tool-based shaping
- 4. Agent-based shaping
- 5. Human-in-the-loop shaping
- """
-
- @abstractmethod
- def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
- """Shape a sample with rewards"""
- pass
-
- @abstractmethod
- def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- """Shape a batch of samples"""
- pass
diff --git a/trinity/common/rewards/composite_reward.py b/trinity/common/rewards/composite_reward.py
deleted file mode 100644
index 0d2c8375e8..0000000000
--- a/trinity/common/rewards/composite_reward.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from typing import Any, Dict, List, Tuple
-
-from .base import RewardShapper
-
-
-class CompositeRewardShapper(RewardShapper):
- """Combines multiple shappers with weights"""
-
- def __init__(self, shappers: List[Tuple[RewardShapper, float]]):
- self.shappers = shappers
-
- def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
- total_reward = 0.0
- shapped_sample = sample.copy()
-
- for shapper, weight in self.shappers:
- shapeged = shapper.shape(sample)
- for key, value in shapeged.items():
- if key.endswith("_reward"):
- shapped_sample[key] = value
- total_reward += value * weight
-
- shapped_sample["total_reward"] = total_reward
- return shapped_sample
diff --git a/trinity/common/rewards/countdown_reward.py b/trinity/common/rewards/countdown_reward.py
new file mode 100644
index 0000000000..07417431ad
--- /dev/null
+++ b/trinity/common/rewards/countdown_reward.py
@@ -0,0 +1,58 @@
+"""Countdown Reward Function Class."""
+import json
+from typing import Optional
+
+from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
+from trinity.utils.eval_utils import (
+ evaluate_equation,
+ extract_solution,
+ validate_equation,
+)
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+
+@REWARD_FUNCTIONS.register_module("countdown_reward")
+class CountDownRewardFn(RewardFn):
+    """A reward function for the countdown task.
+ Ref: Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__( # type: ignore
+ self,
+ response: str,
+ prompt: Optional[str] = None,
+ truth: Optional[str] = None,
+ ) -> dict[str, float]:
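+        # The ground truth is a JSON string holding the target value and the list of usable numbers.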
+ truth = json.loads(truth) # type: ignore
+ target = truth["target"] # type: ignore
+ numbers = truth["numbers"] # type: ignore
+
+ solution_str = response
+ equation = extract_solution(solution_str=solution_str)
+ format_score = 0.1
+ score = 1.0
+
+ if equation is None:
+ return {"score": 0}
+
+ # Validate equation uses correct numbers
+ if not validate_equation(equation, numbers):
+ return {"score": format_score}
+
+ # Evaluate equation
+ try:
+ result = evaluate_equation(equation)
+ if result is None:
+ return {"score": format_score}
+
+ if abs(result - target) < 1e-5: # Account for floating point precision
+ return {"score": score}
+ else:
+ return {"score": format_score}
+ except Exception as e: # noqa: F841
+ return {"score": format_score}
diff --git a/trinity/common/rewards/format_reward.py b/trinity/common/rewards/format_reward.py
index e464103f76..0ffe8637ec 100644
--- a/trinity/common/rewards/format_reward.py
+++ b/trinity/common/rewards/format_reward.py
@@ -1,29 +1,28 @@
-import re
-from typing import Any, Dict, List
+"""Format Reward Function Class."""
-from .base import RewardShapper
+import re
+from typing import Optional
+from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
+from trinity.utils.log import get_logger
-class FormatRewardShapper(RewardShapper):
- """Shapper for format-based rewards"""
+logger = get_logger(__name__)
- def __init__(
- self, pattern: str, correct_format_reward: float = 1.0, incorrect_format_reward: float = 0.0
- ):
- self.pattern = re.compile(pattern, re.DOTALL | re.MULTILINE)
- self.correct_format_reward = correct_format_reward
- self.incorrect_format_reward = incorrect_format_reward
- def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
- response = sample["response"]
- reward = (
- self.correct_format_reward
- if self.pattern.match(response)
- else self.incorrect_format_reward
- )
+@REWARD_FUNCTIONS.register_module("format_reward")
+class FormatReward(RewardFn):
+    """A reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags.
+ Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
+ """
- sample["format_reward"] = reward
- return sample
+ def __init__(self, pattern: Optional[str] = None):
+        self.pattern = pattern if pattern else r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
- def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- return [self.shape(sample) for sample in samples]
+ def __call__( # type: ignore
+ self,
+ response,
+ ) -> dict[str, float]:
+ if re.match(self.pattern, response, re.DOTALL | re.MULTILINE):
+ return {"format_score": 0.1}
+ else:
+ return {"format_score": -0.1}
diff --git a/trinity/common/rewards/math_reward.py b/trinity/common/rewards/math_reward.py
new file mode 100644
index 0000000000..09a5cd7428
--- /dev/null
+++ b/trinity/common/rewards/math_reward.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+"""Math Reward Function Class."""
+from typing import Optional
+
+from trinity.common.rewards.accuracy_reward import AccuracyReward
+from trinity.common.rewards.format_reward import FormatReward
+from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
+from trinity.utils.eval_utils import (
+ compute_score,
+ simple_answer_parser,
+ validate_think_pattern,
+)
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+
+@REWARD_FUNCTIONS.register_module("math_reward")
+class MathRewardFn(RewardFn):
+    """A reward function for math tasks."""
+
+    DEFAULT_FORMAT_PATTERN = r".*?<think>.*?</think>\s*<answer>.*?</answer>\s*$"
+ DEFAULT_ANSWER_PARSER = simple_answer_parser
+
+ def __init__(
+ self,
+ answer_parser=DEFAULT_ANSWER_PARSER,
+ pattern=DEFAULT_FORMAT_PATTERN,
+ ) -> None:
+ self.accuracy_reward = AccuracyReward(answer_parser)
+ self.format_reward = FormatReward(pattern)
+
+ def __call__( # type: ignore
+ self,
+ response: str,
+ prompt: Optional[str] = None,
+ truth: Optional[str] = None,
+ ) -> dict[str, float]:
+ accuracy_score = self.accuracy_reward(response, prompt, truth)
+
+ format_score = self.format_reward(response)
+
+ return {**accuracy_score, **format_score}
+
+
+@REWARD_FUNCTIONS.register_module("math_boxed_reward")
+class MathBoxedRewardFn(RewardFn):
+    """A reward function for math tasks."""
+
+ def __init__(
+ self,
+ ) -> None:
+ pass
+
+ def __call__( # type: ignore
+ self,
+ response: str,
+ prompt: Optional[str] = None,
+ truth: Optional[str] = None,
+ with_think: Optional[bool] = False,
+ format_score_coef: Optional[float] = 0.1,
+ ) -> dict[str, float]:
+ accuracy_score = compute_score(response, truth)
+
+ format_score = 0.0
+ if with_think and not validate_think_pattern(response):
+ format_score = (format_score_coef or 0.1) * -1.0
+
+ return {"accuracy": accuracy_score, "format_score": format_score}
diff --git a/trinity/common/rewards/reward_fn.py b/trinity/common/rewards/reward_fn.py
index 822636482e..ee443ed085 100644
--- a/trinity/common/rewards/reward_fn.py
+++ b/trinity/common/rewards/reward_fn.py
@@ -1,21 +1,10 @@
# -*- coding: utf-8 -*-
"""Base Reward Function Class."""
-import json
-import re
from abc import ABC, abstractmethod
-from typing import Callable, Optional, Union
-
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
-
-from trinity.utils.eval_utils import (
- compute_score,
- evaluate_equation,
- extract_solution,
- simple_answer_parser,
- validate_equation,
- validate_think_pattern,
-)
+from typing import Any, Dict, List
+
+from trinity.common.experience import Experience
+from trinity.common.rewards.utils import to_rm_gallery_messages
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry
@@ -28,202 +17,91 @@
class RewardFn(ABC):
"""Base Reward Function Class."""
- # TODO: add a batch version
-
@abstractmethod
- def __call__(
- self,
- response: str,
- prompt: Optional[str] = None,
- truth: Optional[str] = None,
- return_dict: Optional[bool] = False,
- ) -> Union[float, dict]:
- """Call the reward function."""
-
-
-@REWARD_FUNCTIONS.register_module("accuracy_reward")
-class AccuracyReward(RewardFn):
- """A reward function that rewards correct answers.
- Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
- """
-
- def __init__(self, answer_parser: Optional[Callable[[str], str]] = None):
- self.answer_parser = answer_parser
-
- def __call__(
- self,
- response: str,
- prompt: Optional[str] = None,
- truth: Optional[str] = None,
- return_dict: Optional[bool] = False,
- ):
- if self.answer_parser:
- answer_parsed = self.answer_parser(response)
- truth_parsed = self.answer_parser(truth) # type: ignore [arg-type]
-
- else:
- truth_parsed = parse(
- truth,
- extraction_mode="first_match",
- extraction_config=[LatexExtractionConfig()],
- )
- if len(truth_parsed) == 0:
- truth_parsed = truth
-
- answer_parsed = parse(
- response,
- extraction_config=[
- LatexExtractionConfig(
- normalization_config=NormalizationConfig(
- nits=False,
- malformed_operators=False,
- basic_latex=True,
- equations=True,
- boxed="all",
- units=True,
- ),
- # Ensures that boxed is tried first
- boxed_match_priority=0,
- try_extract_without_anchor=False,
- )
- ],
- extraction_mode="first_match",
- )
+ def __init__(self, **kwargs) -> None:
+ pass
- # Reward 1 if the content is the same as the ground truth, 0 otherwise
- try:
- reward = float(verify(answer_parsed, truth_parsed))
- except Exception as e:
- logger.info(f"verify failed: {e}, answer: {answer_parsed}, gold: {truth_parsed}")
- reward = 0.0
- return reward
+ @abstractmethod
+ def __call__(self, **kwargs) -> Dict[str, float]:
+ pass
-@REWARD_FUNCTIONS.register_module("format_reward")
-class FormatReward(RewardFn):
-    """A reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags.
- Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
+@REWARD_FUNCTIONS.register_module("rm_gallery_reward")
+class RMGalleryFn(RewardFn):
+    """Reward function from RM-Gallery.
+ https://github.com/modelscope/RM-Gallery
"""
- def __init__(self, pattern: Optional[str] = None):
-        self.pattern = pattern if pattern else r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
-
- def __call__(
- self,
- response,
- prompt: Optional[str] = None,
- truth: Optional[str] = None,
- return_dict: Optional[bool] = False,
- ) -> float:
- if re.match(self.pattern, response, re.DOTALL | re.MULTILINE):
- return 0.1
- else:
- return -0.1
-
-
-@REWARD_FUNCTIONS.register_module("math_reward")
-class MathRewardFn(RewardFn):
- """A reward function that rewards for math task."""
-
-    # DEFAULT_FORMAT_PATTERN = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
-    DEFAULT_FORMAT_PATTERN = r".*?<think>.*?</think>\s*<answer>.*?</answer>\s*$"
- DEFAULT_ANSWER_PARSER = simple_answer_parser
-
def __init__(
self,
- answer_parser=DEFAULT_ANSWER_PARSER,
- pattern=DEFAULT_FORMAT_PATTERN,
- ) -> None:
- self.accuracy_reward = AccuracyReward(answer_parser)
- self.format_reward = FormatReward(pattern)
+ reward_name,
+ **kwargs,
+ ):
+ from rm_gallery.core.reward.registry import RewardRegistry
- def __call__( # type: ignore
- self,
- response: str,
- prompt: Optional[str] = None,
- truth: Optional[str] = None,
- return_dict: Optional[bool] = False,
- ) -> Union[float, dict]:
- accuracy_score = self.accuracy_reward(response, prompt, truth)
+ self.reward_model = RewardRegistry.get(reward_name)(**kwargs)
- format_score = self.format_reward(response, prompt, truth)
+ def __call__(self, experience: Experience, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, float]: # type: ignore
+ """Call the reward function."""
- if return_dict:
- return {"accuracy": accuracy_score, "format_score": format_score}
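+        # Build an RM-Gallery DataSample from the experience, score it, and flatten the result into a dict.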
+ sample = self._build_sample_from_experience(experience, messages, **kwargs)
- return accuracy_score + format_score
+ sample_with_reward = self.reward_model.evaluate(sample, **kwargs)
+ return self._extract_reward(sample_with_reward)
-@REWARD_FUNCTIONS.register_module("countdown_reward")
-class CountDownRewardFn(RewardFn):
- """A reward function that rewards for countdown task."""
+ def _build_sample_from_experience(
+ self, experience: Experience, messages: List[Dict[str, Any]], **kwargs
+ ) -> Any:
+ """Convert experience to sample.
+ Ref: https://github.com/modelscope/RM-Gallery/blob/main/rm_gallery/core/data/schema.py
+ """
+ from rm_gallery.core.data.schema import DataOutput, DataSample, Step
- def __init__(self):
- pass
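+        # Wrap the single model response as a DataOutput step; the ground truth is attached as a reference label.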
+ output = [
+ DataOutput(
+ answer=Step(
+ role="assistant",
+ content=str(experience.response_text),
+ label={"reference": kwargs.get("ground_truth", "")},
+ ),
+ )
+ ]
+
+ sample = DataSample(
+ unique_id="0", # TODO: Generate unique ID
+ input=to_rm_gallery_messages(messages),
+ output=output,
+ metadata=experience.info,
+ )
+ return sample
+
+ def _extract_reward(self, sample: Any) -> Dict[str, float]:
+ """
+        Extract the reward dict from an RM-Gallery DataSample.
+ """
+ reward_dict = {}
- def __call__(
- self,
- response: str,
- prompt: Optional[str] = None,
- truth: Optional[str] = None,
- return_dict: Optional[bool] = False,
- ) -> float:
- # Copy from Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py
- truth = json.loads(truth) # type: ignore
- target = truth["target"] # type: ignore
- numbers = truth["numbers"] # type: ignore
-
- solution_str = response
- equation = extract_solution(solution_str=solution_str)
- format_score = 0.1
- score = 1.0
-
- if equation is None:
- return 0
-
- # Validate equation uses correct numbers
- if not validate_equation(equation, numbers):
- return format_score
-
- # Evaluate equation
try:
- result = evaluate_equation(equation)
- if result is None:
- return format_score
-
- if abs(result - target) < 1e-5: # Account for floating point precision
- return score
- else:
- return format_score
- except Exception as e: # noqa: F841
- return format_score
-
-
-@REWARD_FUNCTIONS.register_module("math_boxed_reward")
-class MathBoxedRewardFn(RewardFn):
- """A reward function that rewards for math task."""
-
- def __init__(
- self,
- ) -> None:
- pass
+ reward_obj = sample.output[0].answer.reward
+ except Exception as e:
+            raise ValueError(f"No reward found in sample: {e}")
+
+ from rm_gallery.core.reward.schema import (
+ RewardDimensionWithRank,
+ RewardDimensionWithScore,
+ )
+
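+        # Prefer per-dimension reward details when present; otherwise fall back to the aggregate score.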
+ if reward_obj.details:
+ for detail in reward_obj.details:
+ if isinstance(detail, RewardDimensionWithScore):
+ reward_dict[detail.name] = detail.score
+ elif isinstance(detail, RewardDimensionWithRank):
+ # TODO: support multi-ranked dimension
+ if detail:
+ top_ranked_item = detail[0]
+ reward_dict[top_ranked_item.name] = top_ranked_item.score
+ else:
+ reward_dict["reward"] = reward_obj.score
- def __call__( # type: ignore
- self,
- response: str,
- prompt: Optional[str] = None,
- truth: Optional[str] = None,
- return_dict: Optional[bool] = False,
- with_think: Optional[bool] = False,
- format_score_coef: Optional[float] = 0.1,
- ) -> Union[float, dict]:
- accuracy_score = compute_score(response, truth)
-
- format_score = 0.0
- if with_think and not validate_think_pattern(response):
- format_score = (format_score_coef or 0.1) * -1.0
-
- if return_dict:
- return {"accuracy": accuracy_score, "format_score": format_score}
-
- return accuracy_score + format_score
+ return reward_dict
diff --git a/trinity/common/rewards/utils.py b/trinity/common/rewards/utils.py
new file mode 100644
index 0000000000..0b66e6700c
--- /dev/null
+++ b/trinity/common/rewards/utils.py
@@ -0,0 +1,22 @@
+from typing import Any, Dict, List
+
+
+def to_rm_gallery_messages(messages: List[Dict[str, Any]]) -> Any:
+ """
+    Converts a list of role/content message dicts into structured RM-Gallery ChatMessage objects.
+
+ Args:
+        messages: List of message dicts, each with "role" and "content" keys
+
+ Returns:
+ List of structured ChatMessage objects
+ """
+ from rm_gallery.core.model.message import ChatMessage, MessageRole
+
+ role_map = {
+ "system": MessageRole.SYSTEM,
+ "user": MessageRole.USER,
+ "assistant": MessageRole.ASSISTANT,
+ }
+
+ return [ChatMessage(role=role_map[msg["role"]], content=msg["content"]) for msg in messages]
diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index 9d54f108d0..496996a05d 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -4,6 +4,7 @@
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
from .envs.webshop.webshop_workflow import WebShopWorkflow
+from .math_rm_workflow import MathRMWorkflow
from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow
__all__ = [
@@ -16,4 +17,5 @@
"AlfworldWorkflow",
"SciWorldWorkflow",
"MathBoxedWorkflow",
+ "MathRMWorkflow",
]
diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py
index d71a5d2fb1..1e825e4e54 100644
--- a/trinity/common/workflows/customized_math_workflows.py
+++ b/trinity/common/workflows/customized_math_workflows.py
@@ -5,7 +5,7 @@
from typing import List
from trinity.common.experience import Experience
-from trinity.common.rewards.reward_fn import MathBoxedRewardFn
+from trinity.common.rewards.math_reward import MathBoxedRewardFn
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
from trinity.utils.log import get_logger
@@ -73,20 +73,20 @@ def run(self) -> List[Experience]:
responses = self.model.generate([prompt_text], **self.rollout_args)
for response in responses:
- reward = MathBoxedRewardFn()( # type: ignore [misc]
+ reward_dict = MathBoxedRewardFn()( # type: ignore [misc]
response=response.response_text, # type: ignore [arg-type]
truth=self.truth,
- return_dict=self.is_eval,
with_think=self.with_think,
format_score_coef=self.format_score_coef,
)
+
+ if response.metrics is None:
+ response.metrics = {}
+ response.metrics.update(reward_dict)
+ reward = sum(reward_dict.values())
+ response.reward = reward
+
logger.debug(
f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
)
- if isinstance(reward, dict):
- if response.metrics is None:
- response.metrics = {}
- response.metrics.update(reward)
- reward = sum(reward.values())
- response.reward = reward
return responses
diff --git a/trinity/common/workflows/math_rm_workflow.py b/trinity/common/workflows/math_rm_workflow.py
new file mode 100644
index 0000000000..45940fdfae
--- /dev/null
+++ b/trinity/common/workflows/math_rm_workflow.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+"""We include the math workflow with rm-gallery reward in this file."""
+
+from typing import List, Optional
+
+import openai
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+
+@WORKFLOWS.register_module("math_rm_workflow")
+class MathRMWorkflow(SimpleWorkflow):
+    """A workflow for math tasks that scores responses with an RM-Gallery reward function."""
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List[openai.OpenAI]] = None,
+ ):
+ self.reset(task)
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+
+ def run(self) -> List[Experience]:
+ messages = self.format_messages()
+
+ logger.debug("start chat")
+ responses = self.model.chat(messages, **self.rollout_args)
+ for response in responses:
+ reward_dict = self.reward_fn( # type: ignore
+ response,
+ messages,
+ ground_truth=self.truth,
+ )
+
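+            # Record each reward dimension as a metric and use their sum as the scalar response reward.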
+ if response.metrics is None:
+ response.metrics = {}
+ response.metrics.update(reward_dict)
+ reward = sum(reward_dict.values())
+ response.reward = reward
+
+ logger.debug(
+ f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
+ )
+ return responses
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 2bd0038435..0a2483788b 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -13,7 +13,8 @@
from trinity.common.config import FormatConfig, GenerationConfig
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
-from trinity.common.rewards.reward_fn import MathRewardFn, RewardFn
+from trinity.common.rewards.math_reward import MathRewardFn
+from trinity.common.rewards.reward_fn import RewardFn
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry
@@ -31,6 +32,7 @@ class Task:
format_args: FormatConfig = field(default_factory=FormatConfig)
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
workflow_args: dict = field(default_factory=dict)
+ reward_fn_args: dict = field(default_factory=dict)
is_eval: bool = False
reward_fn: Optional[Type[RewardFn]] = None
raw_task: Optional[dict] = None # The raw data sample
@@ -178,6 +180,7 @@ def reset(self, task: Task):
self.format_args = task.format_args
self.system_prompt = task.format_args.system_prompt
self.reply_prefix = task.format_args.reply_prefix
+ self.reward_fn_args = task.reward_fn_args
self.raw_task = task.raw_task
self.task_desc = task.task_desc
@@ -185,7 +188,7 @@ def reset(self, task: Task):
reward_fn = task.reward_fn
if isinstance(reward_fn, type) and issubclass(reward_fn, RewardFn):
- self.reward_fn: RewardFn = reward_fn()
+ self.reward_fn: RewardFn = reward_fn(**self.reward_fn_args)
else:
raise ValueError("`reward_fn` must be a subclass of `RewardFn`")
# Rollout args
@@ -209,20 +212,20 @@ def run(self) -> List[Experience]:
logger.debug("start chat")
responses = self.model.chat(messages, **self.rollout_args)
for response in responses:
- reward = self.reward_fn( # type: ignore [misc]
+ reward_dict = self.reward_fn( # type: ignore [misc]
response=response.response_text, # type: ignore [arg-type]
truth=self.truth,
- return_dict=self.is_eval,
)
+
+ if response.metrics is None:
+ response.metrics = {}
+ response.metrics.update(reward_dict)
+ reward = sum(reward_dict.values())
+ response.reward = reward
+
logger.debug(
f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
)
- if isinstance(reward, dict):
- if response.metrics is None:
- response.metrics = {}
- response.metrics.update(reward)
- reward = sum(reward.values())
- response.reward = reward
return responses