
Commit 04a586b

feature(nyz): add basic math reward model interfaces
1 parent 6c2ca2f commit 04a586b

4 files changed: +190 -0 lines changed

ding/reward_model/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -13,3 +13,6 @@
 from .guided_cost_reward_model import GuidedCostRewardModel
 from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel
 from .icm_reward_model import ICMRewardModel
+# LLM/VLM reward model and verifier
+from .math_reward_model import MathRewardModel
+from .math_rule_reward_model import MathRuleRewardModel
ding/reward_model/math_reward_model.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
from typing import Tuple, Optional, List, Dict
from easydict import EasyDict
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer
import re

from ding.utils import REWARD_MODEL_REGISTRY
from .base_reward_model import BaseRewardModel


@REWARD_MODEL_REGISTRY.register('math')
class MathRewardModel(BaseRewardModel):
    config = dict(
        # (str) The type of the reward model.
        type='math',
        # (str) The name of the tokenizer, usually the huggingface tokenizer name.
        tokenizer_name='Qwen/Qwen2.5-Math-PRM-7B',
    )

    def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWriter') -> None:  # noqa
        self.cfg = config
        self.device = device
        self.logger = logger
        self.tb_logger = tb_logger

    def estimate(self, data: List[str]) -> List[Dict]:
        """
        Arguments:
            - data (:obj:`List[str]`): The list of data queries used for estimation, each query is a string \
                of the \
                form "1 + 1 = ?"
        Returns:
            - reward (:obj:`List[Dict]`): The estimated reward.
        """
        pass

    # this reward model does not need training, thus the following methods are empty
    def train(self):
        pass

    def collect_data(self, data: list) -> None:
        pass

    def clear_data(self) -> None:
        pass
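In this commit `MathRewardModel` only declares the interface: `estimate` is a stub and the Qwen2.5-Math-PRM-7B tokenizer named in the config is not loaded yet. The snippet below is a minimal usage sketch of that declared interface, not part of the commit; the `device='cpu'` value and the `None` logger arguments are placeholder assumptions, and `estimate` will return `None` until the method is implemented.

# Minimal usage sketch (assumption: defaults taken from MathRewardModel.config).
from easydict import EasyDict

from ding.reward_model import MathRewardModel

cfg = EasyDict(MathRewardModel.config)  # type='math', tokenizer_name='Qwen/Qwen2.5-Math-PRM-7B'
reward_model = MathRewardModel(cfg, device='cpu', logger=None, tb_logger=None)

# Queries are plain strings; once estimate() is implemented it is expected to
# return one Dict of reward information per query.
rewards = reward_model.estimate(["1 + 1 = ?"])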
ding/reward_model/math_rule_reward_model.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
from typing import Tuple, Optional, List, Dict
from easydict import EasyDict
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer
import re

from ding.utils import REWARD_MODEL_REGISTRY
from .base_reward_model import BaseRewardModel


@REWARD_MODEL_REGISTRY.register('math_rule')
class MathRuleRewardModel(BaseRewardModel):
    config = dict(
        # (str) The type of the reward model.
        type='math_rule',
        # (str) The name of the dataset, usually the huggingface dataset name.
        dataset_name='',
        # (str) The name of the tokenizer, usually the huggingface tokenizer name.
        tokenizer_name='',
        # (float) The reward for a format error.
        format_error_reward=-2,
        # (float) The reward for a wrong answer.
        answer_error_reward=-1,
        # (float) The reward for a correct answer.
        correct_reward=1,
    )

    def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWriter') -> None:  # noqa
        self.cfg = config
        self.device = device
        self.logger = logger
        self.tb_logger = tb_logger

    def estimate(self, data: List[str]) -> List[Dict]:
        """
        Arguments:
            - data (:obj:`List[str]`): The list of data queries used for estimation, each query is a string of the \
                form "1 + 1 = ?"
        Returns:
            - reward (:obj:`List[Dict]`): The estimated reward.
        """
        # 1. parse the query to get the question and the predicted answer
        # 2. get the ground truth answer according to the question
        # 3. calculate the reward from the predicted and ground truth answers
        #    (format error: -2, answer error: -1, correct: 1)
        pass

    # rule-based reward model does not need training, thus the following methods are empty
    def train(self):
        pass

    def collect_data(self, data: list) -> None:
        pass

    def clear_data(self) -> None:
        pass


def strip_sequence(text: str, pad_token: str, eos_token: str) -> str:
    """
    Overview:
        Remove leading and trailing sequences of padding/eos tokens from a text.

    .. note::
        This function uses regular expressions to strip all consecutive occurrences
        of the specified padding and end-of-sequence tokens from both the beginning
        and end of the input text. Tokens in the middle of the text are preserved.

    Arguments:
        - text (str): The input text to be processed.
        - pad_token (str): The padding token to be stripped (e.g., "<PAD>").
        - eos_token (str): The end-of-sequence token to be stripped (e.g., "<EOS>").

    Returns:
        - cleaned_text (str): The cleaned text with leading/trailing padding/eos tokens removed.

    Examples:
        >>> strip_sequence("<PAD><EOS>Hello<EOS><PAD>", "<PAD>", "<EOS>")
        'Hello'

        >>> strip_sequence("Test<EOS>Middle<PAD>Keep", "<PAD>", "<EOS>")
        'Test<EOS>Middle<PAD>Keep'

        >>> strip_sequence("<EOS><EOS><PAD>Full removal<PAD><EOS>", "<PAD>", "<EOS>")
        'Full removal'

        >>> strip_sequence("No tokens here", "<PAD>", "<EOS>")
        'No tokens here'

        >>> strip_sequence("<PAD><PAD>", "<PAD>", "<EOS>")
        ''
    """
    pad_token_escaped = re.escape(pad_token)
    eos_token_escaped = re.escape(eos_token)

    # Remove leading pad/eos tokens
    pattern = f"^({eos_token_escaped}|{pad_token_escaped})+"
    text = re.sub(pattern, "", text)

    # Remove trailing pad/eos tokens
    pattern = f"({eos_token_escaped}|{pad_token_escaped})+$"
    text = re.sub(pattern, "", text)
    return text


def normalize_text(text: str) -> str:
    """
    Overview:
        This function is designed to standardize text by:
        - Converting all text to lowercase
        - Replacing various punctuation marks and special characters with spaces
        - Removing import statements
        - Normalizing whitespace by replacing multiple spaces with a single space
        - Stripping leading and trailing whitespace
    Arguments:
        - text (str): The input text to be processed.
    Returns:
        - normalized_text (str): The normalized text.
    """
    # raw strings avoid invalid escape-sequence warnings in the regex patterns
    text = re.sub(r"[,.:\"'\[\]\-=\+\\|!@#$%^&*();<>?/!¥…()—\{\}:”“《》?]", " ", text.lower())
    # the "as" clause is optional so that plain import statements are also removed
    text = re.sub(r"import\s[a-zA-Z\.]+(\sas\s[a-zA-Z\.]+)?\n", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()
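The numbered steps in `MathRuleRewardModel.estimate` are left unimplemented in this commit. Purely as an illustrative sketch of step 3, the hypothetical helper below scores a single response against a known ground-truth answer using the default reward values from the config and the `normalize_text` helper defined above. The function name `compute_rule_reward`, the assumption that a well-formed response wraps its final answer in `\boxed{...}`, and the explicit `ground_truth` argument (step 2, the dataset lookup, is skipped) are not part of the commit.

# Hypothetical sketch, not the committed implementation; assumes normalize_text (above) is in scope.
import re


def compute_rule_reward(response: str, ground_truth: str) -> float:
    # Step 1 (format check): assume a well-formed response puts its final answer in \boxed{...}.
    match = re.search(r"\\boxed\{([^}]*)\}", response)
    if match is None:
        return -2.0  # format_error_reward
    predicted = match.group(1)
    # Step 3 (comparison): compare normalized answers, reusing normalize_text from this file.
    if normalize_text(predicted) == normalize_text(ground_truth):
        return 1.0  # correct_reward
    return -1.0  # answer_error_reward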
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
import pytest
from easydict import EasyDict

from ding.reward_model import MathRuleRewardModel


@pytest.mark.envtest
def test_math_rule_reward_model():
    reward_model = MathRuleRewardModel(
        config=EasyDict(
            dataset_name='RUC-AIBOX/STILL-3-Preview-RL-Data',
            tokenizer_name='unsloth/Meta-Llama-3.1-8B',
        )
    )

    data = [
        "The school now introduces a new color, silver, for the flag design. Crestview's school colors are now purple, gold, and silver. The students are designing a flag using three solid-colored horizontal stripes. Using one, two, or all three of the school colors, how many different flags are possible if adjacent stripes may be the same color?",  # noqa
    ]
    rewards = reward_model.estimate(data)
    assert len(rewards) == len(data)
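Since `estimate` is still a stub in this commit, the final assertion only becomes meaningful once the rule-based scoring is implemented; the `envtest` marker allows this test to be selected or deselected explicitly via pytest's `-m` filter.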
