Merged (changes from 5 commits)
6 changes: 4 additions & 2 deletions examples/grpo_math/math.yaml
@@ -23,10 +23,12 @@ buffer:
prompt_key: 'question'
response_key: 'gt_answer'
rollout_args:
n: 8
temperature: 1.0
logprobs: 0
default_workflow_type: 'math_workflow'
reward_fn_args:
reward_name: math_verify_reward
default_workflow_type: 'math_rm_workflow'
default_reward_fn_type: 'rm_gallery_reward'
trainer_input:
experience_buffer:
name: math_buffer
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -52,6 +52,9 @@ data = [
agent = [
"agentscope"
]
rm = [
"rm-gallery"
]
dev = [
"pre-commit>=2.17.0",
"black>=23.7.0",
2 changes: 2 additions & 0 deletions tests/explorer/workflow_test.py
@@ -2,6 +2,7 @@
"""Test for the workflow module"""
import unittest
from dataclasses import dataclass
from typing import Dict, Optional
from unittest.mock import MagicMock

from tests.tools import get_unittest_dataset_config
@@ -13,6 +14,7 @@
class MockResponse:
response_text: str
reward: float = 0.0
metrics: Optional[Dict[str, float]] = None


class DummyWorkflow(Workflow):
1 change: 1 addition & 0 deletions trinity/buffer/reader/file_reader.py
@@ -294,6 +294,7 @@ def read(
format_args=self.meta.format,
rollout_args=self.meta.rollout_args,
workflow_args=self.meta.workflow_args,
reward_fn_args=self.meta.reward_fn_args,
is_eval=self.meta.task_type == TaskType.EVAL,
reward_fn=reward_fn,
raw_task=sample,
1 change: 1 addition & 0 deletions trinity/common/config.py
@@ -94,6 +94,7 @@ class StorageConfig:
default_reward_fn_type: Optional[str] = None
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
workflow_args: dict = field(default_factory=dict)
reward_fn_args: dict = field(default_factory=dict)

# get storage from existing experiment
ray_namespace: Optional[str] = None
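
For orientation, a hedged sketch of populating the new field; it assumes the remaining StorageConfig fields have defaults or are supplied elsewhere (in this PR the values actually come from examples/grpo_math/math.yaml):

# Hedged sketch: setting the new reward_fn_args field on StorageConfig.
# Assumption: all other StorageConfig fields have defaults or are filled in
# from the YAML config; the values below mirror examples/grpo_math/math.yaml.
from trinity.common.config import StorageConfig

storage = StorageConfig(
    default_reward_fn_type="rm_gallery_reward",
    reward_fn_args={"reward_name": "math_verify_reward"},
)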
14 changes: 13 additions & 1 deletion trinity/common/rewards/__init__.py
@@ -1,11 +1,23 @@
# -*- coding: utf-8 -*-
"""Reward functions for RFT"""

from .reward_fn import REWARD_FUNCTIONS, AccuracyReward, FormatReward, RewardFn
# isort: off
from .reward_fn import REWARD_FUNCTIONS, RewardFn, RMGalleryFn

from .accuracy_reward import AccuracyReward
from .countdown_reward import CountDownRewardFn
from .format_reward import FormatReward
from .math_reward import MathBoxedRewardFn, MathRewardFn

# isort: on

__all__ = [
"RewardFn",
"RMGalleryFn",
"REWARD_FUNCTIONS",
"AccuracyReward",
"CountDownRewardFn",
"FormatReward",
"MathRewardFn",
"MathBoxedRewardFn",
]
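
For reference, a minimal import sketch of the public reward API after this change; only direct imports are shown, since the exact lookup interface of the REWARD_FUNCTIONS registry is not part of this diff:

# Minimal sketch (names taken from the __all__ above; direct imports only).
from trinity.common.rewards import (
    REWARD_FUNCTIONS,  # registry populated by the @REWARD_FUNCTIONS.register_module(...) decorators below
    AccuracyReward,
    CountDownRewardFn,
    FormatReward,
)

acc = AccuracyReward()  # reward functions can also be instantiated directly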
83 changes: 59 additions & 24 deletions trinity/common/rewards/accuracy_reward.py
@@ -1,33 +1,68 @@
from typing import Any, Callable, Dict, List
# -*- coding: utf-8 -*-
"""Accuracy Reward Function Class."""
from typing import Callable, Optional

from .base import RewardShapper
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify

from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.log import get_logger

class AccuracyRewardShapper(RewardShapper):
"""Shapper for accuracy-based rewards"""
logger = get_logger(__name__)

def __init__(
self,
answer_parser: Callable[[str], str],
correct_reward: float = 1.0,
incorrect_reward: float = 0.0,
kwargs: Dict[str, Any] = {},
):

@REWARD_FUNCTIONS.register_module("accuracy_reward")
class AccuracyReward(RewardFn):
"""A reward function that rewards correct answers.
Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
"""

def __init__(self, answer_parser: Optional[Callable[[str], str]] = None):
self.answer_parser = answer_parser
self.correct_reward = correct_reward
self.incorrect_reward = incorrect_reward
self.response_key = kwargs.get("response", "response")
self.truth_key = kwargs.get("ground_truth", "ground_truth")

def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
response = sample[self.response_key]
truth = sample[self.truth_key]
def __call__( # type: ignore
self,
response: str,
prompt: Optional[str] = None,
truth: Optional[str] = None,
) -> dict[str, float]:
if self.answer_parser:
answer_parsed = self.answer_parser(response)
truth_parsed = self.answer_parser(truth) # type: ignore [arg-type]

parsed_response = self.answer_parser(response)
reward = self.correct_reward if parsed_response == truth else self.incorrect_reward
else:
truth_parsed = parse(
truth,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(truth_parsed) == 0:
truth_parsed = truth

sample["accuracy_reward"] = reward
return sample
answer_parsed = parse(
response,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)

def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
return [self.shape(sample) for sample in samples]
# Reward 1 if the content is the same as the ground truth, 0 otherwise
try:
reward = float(verify(answer_parsed, truth_parsed))
except Exception as e:
logger.info(f"verify failed: {e}, answer: {answer_parsed}, gold: {truth_parsed}")
reward = 0.0
return {"accuracy": reward}
24 changes: 0 additions & 24 deletions trinity/common/rewards/base.py

This file was deleted.

24 changes: 0 additions & 24 deletions trinity/common/rewards/composite_reward.py

This file was deleted.

58 changes: 58 additions & 0 deletions trinity/common/rewards/countdown_reward.py
@@ -0,0 +1,58 @@
"""Base Reward Function Class."""
import json
from typing import Optional

from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.eval_utils import (
evaluate_equation,
extract_solution,
validate_equation,
)
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@REWARD_FUNCTIONS.register_module("countdown_reward")
class CountDownRewardFn(RewardFn):
"""A reward function that rewards for countdown task.
Ref: Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py
"""

def __init__(self):
pass

def __call__( # type: ignore
self,
response: str,
prompt: Optional[str] = None,
truth: Optional[str] = None,
) -> dict[str, float]:
truth = json.loads(truth) # type: ignore
target = truth["target"] # type: ignore
numbers = truth["numbers"] # type: ignore

solution_str = response
equation = extract_solution(solution_str=solution_str)
format_score = 0.1
score = 1.0

if equation is None:
return {"score": 0}

# Validate equation uses correct numbers
if not validate_equation(equation, numbers):
return {"score": format_score}

# Evaluate equation
try:
result = evaluate_equation(equation)
if result is None:
return {"score": format_score}

if abs(result - target) < 1e-5: # Account for floating point precision
return {"score": score}
else:
return {"score": format_score}
except Exception as e: # noqa: F841
return {"score": format_score}
43 changes: 21 additions & 22 deletions trinity/common/rewards/format_reward.py
@@ -1,29 +1,28 @@
import re
from typing import Any, Dict, List
"""Base Reward Function Class."""

from .base import RewardShapper
import re
from typing import Optional

from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.log import get_logger

class FormatRewardShapper(RewardShapper):
"""Shapper for format-based rewards"""
logger = get_logger(__name__)

def __init__(
self, pattern: str, correct_format_reward: float = 1.0, incorrect_format_reward: float = 0.0
):
self.pattern = re.compile(pattern, re.DOTALL | re.MULTILINE)
self.correct_format_reward = correct_format_reward
self.incorrect_format_reward = incorrect_format_reward

def shape(self, sample: Dict[str, Any]) -> Dict[str, Any]:
response = sample["response"]
reward = (
self.correct_format_reward
if self.pattern.match(response)
else self.incorrect_format_reward
)
@REWARD_FUNCTIONS.register_module("format_reward")
class FormatReward(RewardFn):
"""A reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags.
Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
"""

sample["format_reward"] = reward
return sample
def __init__(self, pattern: Optional[str] = None):
self.pattern = pattern if pattern else r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"

def batch_shape(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
return [self.shape(sample) for sample in samples]
def __call__( # type: ignore
self,
response,
) -> dict[str, float]:
if re.match(self.pattern, response, re.DOTALL | re.MULTILINE):
return {"format_score": 0.1}
else:
return {"format_score": -0.1}