6 changes: 4 additions & 2 deletions examples/grpo_math/math.yaml
@@ -23,10 +23,12 @@ buffer:
prompt_key: 'question'
response_key: 'gt_answer'
rollout_args:
n: 8
temperature: 1.0
logprobs: 0
default_workflow_type: 'math_workflow'
reward_fn_args:
reward_name: math_verify_reward
default_workflow_type: 'math_rm_workflow'
default_reward_fn_type: 'rm_gallery_reward'
trainer_input:
experience_buffer:
name: math_buffer
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -53,6 +53,9 @@ data = [
agent = [
"agentscope"
]
rm_gallery = [
"rm-gallery"
]
dev = [
"pre-commit>=2.17.0",
"black>=23.7.0",
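With this extra in place, RM-Gallery can be pulled in through the optional dependency group, e.g. pip install -e ".[rm_gallery]" from the repository root (an editable install is assumed here; the published distribution name is not shown in this diff).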
43 changes: 41 additions & 2 deletions tests/explorer/workflow_test.py
@@ -2,17 +2,26 @@
"""Test for the workflow module"""
import unittest
from dataclasses import dataclass
from typing import Dict, Optional
from unittest.mock import MagicMock

from tests.tools import get_unittest_dataset_config
from trinity.common.workflows import MathBoxedWorkflow, MathWorkflow, Workflow
from trinity.common.rewards import RMGalleryFn
from trinity.common.workflows import (
MathBoxedWorkflow,
MathRMWorkflow,
MathWorkflow,
Workflow,
)
from trinity.common.workflows.workflow import Task


@dataclass
class MockResponse:
response_text: str
reward: float = 0.0
metrics: Optional[Dict[str, float]] = None
info: Optional[Dict] = None


class DummyWorkflow(Workflow):
@@ -206,7 +215,6 @@ def test_gsm8k_workflow(self) -> None:
)
workflow = task.to_workflow(model=model)
experiences = workflow.run()
# self.assertEqual(len(experiences), 1)
self.assertEqual(experiences[0].reward, 1.1)
self.assertEqual(experiences[1].reward, 0.9)
self.assertEqual(experiences[2].reward, 0.9)
@@ -229,6 +237,37 @@ def test_gsm8k_workflow(self) -> None:
self.assertEqual(experiences[2].reward, -0.1)
self.assertEqual(experiences[3].reward, 1.1)

@unittest.skip("Skip for now, need to fix import issues of RM-Gallery")
def test_rm_gallery_workflow(self) -> None:
model = MagicMock()
model.chat.return_value = [
MockResponse("<think> balabalabala 99 </think>\n \\boxed{36}"),
MockResponse("answer is \\boxed{36 }"),
MockResponse("Kim's total points are 6 + 30 =\\boxed{36}"),
MockResponse("<think> balalaba </think> \\boxed{35.00}"),
]
taskset_config = get_unittest_dataset_config("countdown")
task = Task(
workflow=MathRMWorkflow,
reward_fn=RMGalleryFn,
format_args=taskset_config.format,
rollout_args=taskset_config.rollout_args,
reward_fn_args={
"reward_name": "math_verify_reward",
},
is_eval=False,
raw_task={
taskset_config.format.prompt_key: "",
taskset_config.format.response_key: r"36",
},
)
workflow = task.to_workflow(model=model)
experiences = workflow.run()
self.assertEqual(experiences[0].reward, 1.0)
self.assertEqual(experiences[1].reward, 1.0)
self.assertEqual(experiences[2].reward, 1.0)
self.assertEqual(experiences[3].reward, 0.0)

def test_workflow_resettable(self) -> None:
model = MagicMock()
json_task = Task(
1 change: 1 addition & 0 deletions trinity/buffer/reader/file_reader.py
@@ -294,6 +294,7 @@ def read(
format_args=self.meta.format,
rollout_args=self.meta.rollout_args,
workflow_args=self.meta.workflow_args,
reward_fn_args=self.meta.reward_fn_args,
is_eval=self.meta.task_type == TaskType.EVAL,
reward_fn=reward_fn,
raw_task=sample,
1 change: 1 addition & 0 deletions trinity/common/config.py
@@ -99,6 +99,7 @@ class StorageConfig:
default_reward_fn_type: Optional[str] = None
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
workflow_args: dict = field(default_factory=dict)
reward_fn_args: dict = field(default_factory=dict)

# get storage from existing experiment
ray_namespace: Optional[str] = None
14 changes: 13 additions & 1 deletion trinity/common/rewards/__init__.py
@@ -1,11 +1,23 @@
# -*- coding: utf-8 -*-
"""Reward functions for RFT"""

from .reward_fn import REWARD_FUNCTIONS, AccuracyReward, FormatReward, RewardFn
# isort: off
from .reward_fn import REWARD_FUNCTIONS, RewardFn, RMGalleryFn

from .accuracy_reward import AccuracyReward
from .countdown_reward import CountDownRewardFn
from .format_reward import FormatReward
from .math_reward import MathBoxedRewardFn, MathRewardFn

# isort: on

__all__ = [
"RewardFn",
"RMGalleryFn",
"REWARD_FUNCTIONS",
"AccuracyReward",
"CountDownRewardFn",
"FormatReward",
"MathRewardFn",
"MathBoxedRewardFn",
]
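Because each reward class returns a dict of named scores, a workflow can merge several of them and sum the values. A minimal sketch, not part of this PR, that combines the accuracy and format rewards exported above (the 1.0 + 0.1 total is consistent with the 1.1 reward asserted in test_gsm8k_workflow earlier, assuming math_verify accepts the boxed answer):

# Sketch only: merge the score dicts returned by two reward functions.
from trinity.common.rewards import AccuracyReward, FormatReward

response = "<think>\nKim scores 6 + 30 points\n</think>\n<answer>\n\\boxed{36}\n</answer>"
scores = {}
scores.update(AccuracyReward()(response=response, truth=r"$36$"))  # expected {'accuracy': 1.0}
scores.update(FormatReward()(response))                            # {'format_score': 0.1} for well-formed tags
total = sum(scores.values())                                        # expected 1.1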
83 changes: 59 additions & 24 deletions trinity/common/rewards/accuracy_reward.py
@@ -1,33 +1,68 @@
from typing import Any, Callable, Dict, List
# -*- coding: utf-8 -*-
"""Accuracy Reward Function Class."""
from typing import Callable, Optional

from .base import RewardShapper
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify

from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.log import get_logger

class AccuracyRewardShapper(RewardShapper):
"""Shapper for accuracy-based rewards"""
logger = get_logger(__name__)

def __init__(
self,
answer_parser: Callable[[str], str],
correct_reward: float = 1.0,
incorrect_reward: float = 0.0,
kwargs: Dict[str, Any] = {},
):

@REWARD_FUNCTIONS.register_module("accuracy_reward")
class AccuracyReward(RewardFn):
"""A reward function that rewards correct answers.
Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
"""

def __init__(self, answer_parser: Optional[Callable[[str], str]] = None):
self.answer_parser = answer_parser
self.correct_reward = correct_reward
self.incorrect_reward = incorrect_reward
self.response_key = kwargs.get("response", "response")
self.truth_key = kwargs.get("ground_truth", "ground_truth")

def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
response = sample[self.response_key]
truth = sample[self.truth_key]
def __call__( # type: ignore
self,
response: str,
prompt: Optional[str] = None,
truth: Optional[str] = None,
) -> dict[str, float]:
if self.answer_parser:
answer_parsed = self.answer_parser(response)
truth_parsed = self.answer_parser(truth) # type: ignore [arg-type]

parsed_response = self.answer_parser(response)
reward = self.correct_reward if parsed_response == truth else self.incorrect_reward
else:
truth_parsed = parse(
truth,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(truth_parsed) == 0:
truth_parsed = truth

sample["accuracy_reward"] = reward
return sample
answer_parsed = parse(
response,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)

def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
return [self.shape(sample) for sample in samples]
# Reward 1 if the content is the same as the ground truth, 0 otherwise
try:
reward = float(verify(answer_parsed, truth_parsed))
except Exception as e:
logger.info(f"verify failed: {e}, answer: {answer_parsed}, gold: {truth_parsed}")
reward = 0.0
return {"accuracy": reward}
24 changes: 0 additions & 24 deletions trinity/common/rewards/base.py

This file was deleted.

24 changes: 0 additions & 24 deletions trinity/common/rewards/composite_reward.py

This file was deleted.

58 changes: 58 additions & 0 deletions trinity/common/rewards/countdown_reward.py
@@ -0,0 +1,58 @@
"""Base Reward Function Class."""
import json
from typing import Optional

from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.eval_utils import (
evaluate_equation,
extract_solution,
validate_equation,
)
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@REWARD_FUNCTIONS.register_module("countdown_reward")
class CountDownRewardFn(RewardFn):
"""A reward function that rewards for countdown task.
Ref: Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py
"""

def __init__(self):
pass

def __call__( # type: ignore
self,
response: str,
prompt: Optional[str] = None,
truth: Optional[str] = None,
) -> dict[str, float]:
truth = json.loads(truth) # type: ignore
target = truth["target"] # type: ignore
numbers = truth["numbers"] # type: ignore

solution_str = response
equation = extract_solution(solution_str=solution_str)
format_score = 0.1
score = 1.0

if equation is None:
return {"score": 0}

# Validate equation uses correct numbers
if not validate_equation(equation, numbers):
return {"score": format_score}

# Evaluate equation
try:
result = evaluate_equation(equation)
if result is None:
return {"score": format_score}

if abs(result - target) < 1e-5: # Account for floating point precision
return {"score": score}
else:
return {"score": format_score}
except Exception as e: # noqa: F841
return {"score": format_score}
43 changes: 21 additions & 22 deletions trinity/common/rewards/format_reward.py
@@ -1,29 +1,28 @@
import re
from typing import Any, Dict, List
"""Base Reward Function Class."""

from .base import RewardShapper
import re
from typing import Optional

from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.log import get_logger

class FormatRewardShapper(RewardShapper):
"""Shapper for format-based rewards"""
logger = get_logger(__name__)

def __init__(
self, pattern: str, correct_format_reward: float = 1.0, incorrect_format_reward: float = 0.0
):
self.pattern = re.compile(pattern, re.DOTALL | re.MULTILINE)
self.correct_format_reward = correct_format_reward
self.incorrect_format_reward = incorrect_format_reward

def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
response = sample["response"]
reward = (
self.correct_format_reward
if self.pattern.match(response)
else self.incorrect_format_reward
)
@REWARD_FUNCTIONS.register_module("format_reward")
class FormatReward(RewardFn):
"""A reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags.
Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
"""

sample["format_reward"] = reward
return sample
def __init__(self, pattern: Optional[str] = None):
self.pattern = pattern if pattern else r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"

def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
return [self.shape(sample) for sample in samples]
def __call__( # type: ignore
self,
response,
) -> dict[str, float]:
if re.match(self.pattern, response, re.DOTALL | re.MULTILINE):
return {"format_score": 0.1}
else:
return {"format_score": -0.1}