
Commit 60d88f9

(dcy) add math_reward_model and its test file
1 parent ab5f6e7 commit 60d88f9

2 files changed: 203 additions & 10 deletions
Lines changed: 116 additions & 10 deletions
@@ -1,7 +1,9 @@
 from typing import Tuple, Optional, List, Dict
 from easydict import EasyDict
 from torch.utils.tensorboard import SummaryWriter
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModel
+import torch
+import torch.nn.functional as F
 import re

 from ding.utils import REWARD_MODEL_REGISTRY
@@ -13,8 +15,8 @@ class MathRewardModel(BaseRewardModel):
     config = dict(
         # (str) The type of the reward model.
         type='math',
-        # (str) The name of the tokenizer, usually the huggingface tokenizer name.
-        tokenizer_name='Qwen/Qwen2.5-Math-PRM-7B',
+        # (str) The name of the tokenizer and model, usually the Hugging Face model name.
+        model_name='Qwen/Qwen2.5-Math-PRM-7B',
     )

     def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWriter') -> None:  # noqa
@@ -23,23 +25,127 @@ def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWri
         self.logger = logger
         self.tb_logger = tb_logger

-    def estimate(self, data: List[str]) -> List[Dict]:
+        # Initialize the tokenizer and the process reward model
+        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name, trust_remote_code=True)
+        self.model = AutoModel.from_pretrained(
+            self.cfg.model_name, device_map=self.device, torch_dtype=torch.bfloat16, trust_remote_code=True
+        )
+        self.model.eval()
+
+    def make_step_rewards(self, logits: torch.Tensor, token_masks: torch.Tensor) -> List[List[float]]:
+        """Calculate step-wise rewards from the model outputs."""
+        probabilities = F.softmax(logits, dim=-1)
+        probabilities = probabilities * token_masks.unsqueeze(-1)  # bs, seq_len, num_labels
+
+        all_scores_res = []
+        for i in range(probabilities.size(0)):
+            sample = probabilities[i]  # seq_len, num_labels
+            positive_probs = sample[sample != 0].view(-1, 2)[:, 1]  # valid_tokens, num_labels
+            non_zero_elements_list = positive_probs.cpu().tolist()
+            all_scores_res.append(non_zero_elements_list)
+        return all_scores_res
+
+    def estimate(self, data: List[Dict]) -> List[Dict]:
         """
+        Overview:
+            Estimate rewards for mathematical reasoning steps using the Qwen2.5-Math-PRM-7B model.
         Arguments:
-            - data (:obj:`List[str]`): The list of data queries used for estimation, each query is a string \
-                of the \
-                form "1 + 1 = ?"
+            - data (:obj:`List[Dict]`): List of dictionaries containing:
+                - system (:obj:`str`): System prompt for the model
+                - query (:obj:`str`): The mathematical query to be evaluated
+                - response (:obj:`List[str]`): List of reasoning steps
         Returns:
-            - reward (:obj:`List[Dict]`): The estimated reward.
+            - reward (:obj:`List[Dict]`): List of dictionaries containing:
+                - reward (:obj:`float`): Final reward (the reward of the last step)
+                - metadata (:obj:`Dict`): Additional information including:
+                    - query (:obj:`str`): Original query
+                    - step_rewards (:obj:`List[float]`): Rewards for each reasoning step
+                    - num_steps (:obj:`int`): Number of reasoning steps
+        Shapes:
+            - input_ids (:obj:`torch.LongTensor`): :math:`(B, L)`, where B is batch size and L is sequence length
+            - outputs (:obj:`torch.FloatTensor`): :math:`(B, L, H)`, where H is hidden size
+            - token_masks (:obj:`torch.BoolTensor`): :math:`(B, L)`
+            - step_rewards (:obj:`List[List[float]]`): List of length B, each containing S rewards where S is the number of steps
+        Examples:
+            >>> data = [{
+            >>>     "system": "Please reason step by step...",
+            >>>     "query": "What is 1 + 1?",
+            >>>     "response": ["First, we have 1", "Then add 1", "Therefore, 1 + 1 = 2"]
+            >>> }]
+            >>> results = model.estimate(data)
+            >>> print(results[0]["reward"])  # 1.0
+            >>> print(results[0]["metadata"]["step_rewards"])  # [0.8, 0.9, 1.0]
         """
-        pass
+        # Build the chat messages for every sample in the batch
+        all_messages = []
+        for item in data:
+            messages = [
+                {
+                    "role": "system",
+                    "content": item['system']
+                },
+                {
+                    "role": "user",
+                    "content": item['query']
+                },
+                {
+                    "role": "assistant",
+                    "content": "<extra_0>".join(item['response']) + "<extra_0>"
+                },
+            ]
+            all_messages.append(messages)
+
+        # Convert all conversations into the model's chat-template format
+        conversation_strs = [
+            self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
+            for messages in all_messages
+        ]
+
+        # Encode all inputs as one padded batch
+        input_ids = self.tokenizer(
+            conversation_strs, return_tensors="pt", padding=True, truncation=True
+        )["input_ids"].to(self.model.device)
+
+        # Run the model on the whole batch
+        with torch.no_grad():
+            outputs = self.model(input_ids=input_ids)
+
+        # Compute the step rewards of each sample
+        step_sep_id = self.tokenizer.encode("<extra_0>")[0]
+        token_masks = (input_ids == step_sep_id)
+        batch_rewards = self.make_step_rewards(outputs[0], token_masks)
+
+        # Build the detailed result dictionaries
+        results = []
+        for item, step_rewards in zip(data, batch_rewards):
+            results.append(
+                {
+                    "reward": step_rewards[-1] if step_rewards else 0.0,  # last-step reward as the overall reward
+                    "metadata": {
+                        "query": item['query'],
+                        "step_rewards": step_rewards,  # reward of each step
+                        "num_steps": len(item['response']),
+                    }
+                }
+            )
+
+        return results

-    # rule-based reward model does not need training, thus the following methods are empty
     def train(self):
+        """
+        Training is not implemented for this reward model because it uses a pre-trained model.
+        """
+        self.logger.warning("Training is not implemented for this reward model")
         pass

     def collect_data(self, data: list) -> None:
+        """
+        Data collection is not needed for this reward model.
+        """
         pass

     def clear_data(self) -> None:
+        """
+        Data clearing is not needed for this reward model.
+        """
         pass
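For reference, here is a minimal, self-contained sketch of what make_step_rewards does, run on a hand-made tensor instead of real model outputs. It copies the function from the diff above and assumes the PRM head emits two logits per token, with index 1 taken as the probability that the step is correct (as implied by the [:, 1] slice); the tensors and the printed values are purely illustrative.

# Standalone sketch of the step-reward extraction; toy tensors, not real PRM outputs.
import torch
import torch.nn.functional as F

def make_step_rewards(logits: torch.Tensor, token_masks: torch.Tensor):
    probabilities = F.softmax(logits, dim=-1)
    probabilities = probabilities * token_masks.unsqueeze(-1)  # zero out non-separator tokens
    all_scores = []
    for sample in probabilities:  # iterate over the batch
        positive_probs = sample[sample != 0].view(-1, 2)[:, 1]  # keep P(correct step) per <extra_0>
        all_scores.append(positive_probs.cpu().tolist())
    return all_scores

# Toy batch: 1 sample, 5 tokens, 2 labels; <extra_0> separators at positions 2 and 4.
logits = torch.tensor([[[0.0, 0.0], [1.0, -1.0], [0.5, 2.0], [0.0, 0.0], [-1.0, 3.0]]])
token_masks = torch.tensor([[False, False, True, False, True]])
print(make_step_rewards(logits, token_masks))  # ~[[0.8176, 0.9820]], one score per separator

The zeroed probabilities at non-separator positions are what let the boolean indexing recover exactly two values per reasoning step, in order.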
Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+import pytest
+from easydict import EasyDict
+import torch
+from unittest.mock import MagicMock
+
+from ding.reward_model import MathRewardModel
+
+
+@pytest.mark.envtest
+def test_math_reward_model():
+    # Create configuration
+    cfg = EasyDict(dict(
+        type='math',
+        model_name='Qwen/Qwen2.5-Math-PRM-7B',
+    ))
+
+    # Create mock logger and tb_logger
+    logger = MagicMock()
+    tb_logger = MagicMock()
+
+    # Initialize reward model
+    model = MathRewardModel(cfg, "cuda" if torch.cuda.is_available() else "cpu", logger, tb_logger)
+
+    # Test case 1: Simple math problem
+    data_simple = [
+        {
+            "system": "Please reason step by step...",
+            "query": "What is 1 + 1?",
+            "response": ["First, we have 1", "Then add 1", "Therefore, 1 + 1 = 2"]
+        }
+    ]
+
+    # Test case 2: Complex word problem
+    data_complex = [
+        {
+            "system": "Please reason step by step, and put your final answer within \\boxed{}.",
+            "query": "Sue lives in a fun neighborhood...",
+            "response": [
+                "To find out how many more pink plastic flamingos...",
+                "On Saturday, they take back one third of the flamingos...",
+                "On Sunday, the neighbors add another 18 pink plastic flamingos...",
+                "To find the difference, subtract the number of white flamingos..."
+            ]
+        }
+    ]
+
+    # Test simple case
+    results_simple = model.estimate(data_simple)
+
+    # Verify simple case results
+    assert len(results_simple) == 1, "Should return one result"
+    assert "reward" in results_simple[0], "Result should contain reward"
+    assert "metadata" in results_simple[0], "Result should contain metadata"
+    assert "step_rewards" in results_simple[0]["metadata"], "Metadata should contain step_rewards"
+    assert len(results_simple[0]["metadata"]["step_rewards"]) == 3, "Should have 3 step rewards"
+    assert results_simple[0]["metadata"]["num_steps"] == 3, "Should have 3 steps"
+
+    # Test complex case
+    results_complex = model.estimate(data_complex)
+
+    # Verify complex case results
+    assert len(results_complex) == 1, "Should return one result"
+    assert "reward" in results_complex[0], "Result should contain reward"
+    assert "metadata" in results_complex[0], "Result should contain metadata"
+    assert "step_rewards" in results_complex[0]["metadata"], "Metadata should contain step_rewards"
+    assert len(results_complex[0]["metadata"]["step_rewards"]) == 4, "Should have 4 step rewards"
+    assert results_complex[0]["metadata"]["num_steps"] == 4, "Should have 4 steps"
+
+    # Verify reward value ranges
+    for result in results_simple + results_complex:
+        assert 0 <= result["reward"] <= 1, "Reward should be between 0 and 1"
+        for step_reward in result["metadata"]["step_rewards"]:
+            assert 0 <= step_reward <= 1, "Step rewards should be between 0 and 1"
+
+    # Test batch processing functionality
+    batch_data = data_simple + data_complex
+    batch_results = model.estimate(batch_data)
+    assert len(batch_results) == 2, "Should return two results for batch processing"
+
+    # Print detailed information for debugging
+    print("\nSimple problem results:")
+    print(f"Final reward: {results_simple[0]['reward']}")
+    print(f"Step rewards: {results_simple[0]['metadata']['step_rewards']}")
+
+    print("\nComplex problem results:")
+    print(f"Final reward: {results_complex[0]['reward']}")
+    print(f"Step rewards: {results_complex[0]['metadata']['step_rewards']}")
