Commit d8e9868

(dcy) polish flake8, add multimodal_rewardmodel and test
1 parent 7314bff commit d8e9868

7 files changed: +359 -75 lines changed

ding/reward_model/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -13,6 +13,7 @@
 from .guided_cost_reward_model import GuidedCostRewardModel
 from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel
 from .icm_reward_model import ICMRewardModel
-# LLM/VLM reward model and verifier
+# LLM/VLM reward models and verifiers
 from .math_reward_model import MathRewardModel
 from .math_rule_reward_model import MathRuleRewardModel
+from .multi_modal_reward_model import MultiModalRewardModel
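With this change, MultiModalRewardModel is exposed at package level next to the existing LLM/VLM reward models, and the new file below registers it under the 'multi_modal' key. A minimal sanity check, assuming REWARD_MODEL_REGISTRY supports dict-style membership tests as elsewhere in ding.utils:

    # Quick check that the new export and the 'multi_modal' registry key are wired up.
    from ding.reward_model import MultiModalRewardModel
    from ding.utils import REWARD_MODEL_REGISTRY

    print(MultiModalRewardModel.config['model_name'])  # internlm/internlm-xcomposer2d5-7b-reward
    assert 'multi_modal' in REWARD_MODEL_REGISTRY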

ding/reward_model/math_reward_model.py

Lines changed: 13 additions & 18 deletions
@@ -49,21 +49,21 @@ def estimate(self, data: List[Dict]) -> List[Dict]:
         Estimate rewards for mathematical reasoning steps using Qwen2.5-Math-PRM-7B model.
         Arguments:
             - data (:obj:`List[Dict]`): List of dictionaries containing:
-                - system (:obj:`str`): System prompt for the model
-                - query (:obj:`str`): The mathematical query to be evaluated
-                - response (:obj:`List[str]`): List of reasoning steps
+                - system (:obj:`str`): System prompt for the model.
+                - query (:obj:`str`): The mathematical query to be evaluated.
+                - response (:obj:`List[str]`): List of reasoning steps.
         Returns:
             - reward (:obj:`List[Dict]`): List of dictionaries containing:
-                - reward (:obj:`float`): Final reward (last step reward)
+                - reward (:obj:`float`): Final reward (last step reward).
                 - metadata (:obj:`Dict`): Additional information including:
-                    - query (:obj:`str`): Original query
-                    - step_rewards (:obj:`List[float]`): Rewards for each reasoning step
-                    - num_steps (:obj:`int`): Number of reasoning steps
+                    - query (:obj:`str`): Original query.
+                    - step_rewards (:obj:`List[float]`): Rewards for each reasoning step.
+                    - num_steps (:obj:`int`): Number of reasoning steps.
         Shapes:
-            - input_ids (:obj:`torch.LongTensor`): :math:`(B, L)`, where B is batch size and L is sequence length
-            - outputs (:obj:`torch.FloatTensor`): :math:`(B, L, H)`, where H is hidden size
-            - token_masks (:obj:`torch.BoolTensor`): :math:`(B, L)`
-            - step_rewards (:obj:`List[List[float]]`): List of length B, each containing S rewards where S is num steps
+            - input_ids (:obj:`torch.LongTensor`): :math:`(B, L)`, where B is batch size and L is sequence length.
+            - outputs (:obj:`torch.FloatTensor`): :math:`(B, L, H)`, where H is hidden size.
+            - token_masks (:obj:`torch.BoolTensor`): :math:`(B, L)`.
+            - step_rewards (:obj:`List[List[float]]`): List of length B, each containing S rewards where S is num steps.
         Examples:
             >>> data = [{
             >>>     "system": "Please reason step by step...",
@@ -74,7 +74,6 @@ def estimate(self, data: List[Dict]) -> List[Dict]:
             >>> print(results[0]["reward"]) # 1.0
             >>> print(results[0]["metadata"]["step_rewards"]) # [0.8, 0.9, 1.0]
         """
-        # Process all samples in batch
        all_messages = []
        for item in data:
            messages = [
@@ -93,7 +92,6 @@ def estimate(self, data: List[Dict]) -> List[Dict]:
             ]
             all_messages.append(messages)

-        # Batch-convert to the model input format
         conversation_strs = [
             self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
             for messages in all_messages
@@ -104,24 +102,21 @@ def estimate(self, data: List[Dict]) -> List[Dict]:
             conversation_strs, return_tensors="pt", padding=True, truncation=True
         )["input_ids"].to(self.model.device)

-        # Get model outputs in batch
         with torch.no_grad():
             outputs = self.model(input_ids=input_ids)

-        # Compute step rewards for each sample
         step_sep_id = self.tokenizer.encode("<extra_0>")[0]
         token_masks = (input_ids == step_sep_id)
         batch_rewards = self.make_step_rewards(outputs[0], token_masks)

-        # Build the detailed result dictionaries
         results = []
         for item, step_rewards in zip(data, batch_rewards):
             results.append(
                 {
-                    "reward": step_rewards[-1] if step_rewards else 0.0,  # the last step's reward is used as the overall reward
+                    "reward": step_rewards[-1] if step_rewards else 0.0,
                     "metadata": {
                         "query": item['query'],
-                        "step_rewards": step_rewards,  # the reward for each step
+                        "step_rewards": step_rewards,
                         "num_steps": len(item['response']),
                     }
                 }
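For context, the documented call pattern can be exercised with a short script like the one below. It is only a sketch: the cfg fields are assumptions (the default config of MathRewardModel is not part of this hunk; the test file further down builds the model the same way), and running it requires the Qwen2.5-Math-PRM-7B weights.

    # Illustrative sketch only; cfg fields and paths are hypothetical.
    import logging
    import torch
    from easydict import EasyDict
    from torch.utils.tensorboard import SummaryWriter
    from ding.reward_model import MathRewardModel

    cfg = EasyDict(type='math', model_name='Qwen/Qwen2.5-Math-PRM-7B')  # hypothetical fields
    logger = logging.getLogger(__name__)
    tb_logger = SummaryWriter('./log/math_rm')
    model = MathRewardModel(cfg, 'cuda' if torch.cuda.is_available() else 'cpu', logger, tb_logger)

    data = [{
        'system': 'Please reason step by step, and put your final answer within \\boxed{}.',
        'query': 'What is 2 + 3 * 4?',
        'response': ['First compute 3 * 4 = 12.', 'Then 2 + 12 = 14.', 'The answer is \\boxed{14}.'],
    }]
    results = model.estimate(data)
    print(results[0]['reward'])                    # reward of the final step
    print(results[0]['metadata']['step_rewards'])  # one reward per reasoning step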

ding/reward_model/math_rule_reward_model.py

Lines changed: 5 additions & 5 deletions
@@ -135,11 +135,11 @@ def _extract_final_answer(self, text: str) -> Optional[str]:
         """
         Extract the final answer from text.
         Supports various formats:
-        1. "The answer is X"
-        2. "Therefore, X is the answer"
-        3. "X" (if only one number)
-        4. "\\boxed{X}"
-        5. "= X" (expression after equals sign)
+        1. "The answer is X".
+        2. "Therefore, X is the answer".
+        3. "X" (if only one number).
+        4. "\\boxed{X}".
+        5. "= X" (expression after equals sign).
         6. Last LaTeX expression like \\frac{a}{b}, \\sqrt{x}, etc.
         """
         # Try to extract boxed content
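Only the boxed-content branch of the extraction is visible in this hunk. As a rough illustration of how formats 1 and 4 from the list above could be matched, here is a minimal standalone sketch; the helper name, patterns, and fallback order are assumptions rather than the module's actual implementation.

    # Hypothetical helper illustrating formats 1 and 4 above; not the real code.
    import re
    from typing import Optional

    def extract_answer_sketch(text: str) -> Optional[str]:
        # Format 4: \boxed{X}
        boxed = re.search(r'\\boxed\{([^{}]+)\}', text)
        if boxed:
            return boxed.group(1).strip()
        # Format 1: "The answer is X"
        stated = re.search(r'[Tt]he answer is\s+([^\s.,]+)', text)
        if stated:
            return stated.group(1).strip()
        return None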

ding/reward_model/multi_modal_reward_model.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
+from typing import List, Dict
+from easydict import EasyDict
+from torch.utils.tensorboard import SummaryWriter
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+from ding.utils import REWARD_MODEL_REGISTRY
+from .base_reward_model import BaseRewardModel
+
+
+@REWARD_MODEL_REGISTRY.register('multi_modal')
+class MultiModalRewardModel(BaseRewardModel):
+    config = dict(
+        type='multi_modal',
+        model_name='internlm/internlm-xcomposer2d5-7b-reward',
+        hd_num=9,  # Number of high-definition patches for image processing
+    )
+
+    def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWriter') -> None:
+        self.cfg = config
+        self.device = device
+        self.logger = logger
+        self.tb_logger = tb_logger
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.cfg.model_name, trust_remote_code=True, local_files_only=True
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.cfg.model_name, torch_dtype=torch.float16, trust_remote_code=True
+        )
+
+        self.model.tokenizer = self.tokenizer
+        self.model.cuda().eval()
+
+    def estimate(self, data: List[Dict], image: List[str], output_mode: str = 'score') -> List[Dict]:
+        """
+        Estimate rewards for multi-modal inputs using internlm-xcomposer model.
+
+        Arguments:
+            data (List[Dict]): List of chat dictionaries, each containing:
+                - chat (List[Dict]): List of messages, each message is a dict with:
+                    - role (str): Either "user" or "assistant"
+                    - content (str): The message content
+            image (List[str]): List of image paths. If fewer images than chats, last image will be reused
+            output_mode (str, optional): Evaluation mode. Defaults to 'score'.
+                - 'score': Return reward scores for each chat
+                - 'rank': Return ranking indices (0 is best) for all chats
+                - 'compare': Compare first two chats (returns 1.0 for better, 0.0 for worse)
+
+        Returns:
+            List[Dict]: Results depending on output_mode:
+                - For 'score' mode:
+                    [{
+                        'reward': float,  # Reward score
+                        'metadata': {
+                            'mode': 'score',
+                            'chat_idx': int,  # Index of the chat
+                            'image_path': str  # Path of the image used
+                        }
+                    }, ...]
+                - For 'rank' mode:
+                    [{
+                        'rank': int,  # Ranking position (0 is best)
+                        'metadata': {
+                            'mode': 'rank',
+                            'chat_idx': int,
+                            'image_path': str
+                        }
+                    }, ...]
+                - For 'compare' mode:
+                    [{
+                        'reward': float,  # 1.0 for better, 0.0 for worse
+                        'metadata': {
+                            'mode': 'compare',
+                            'chat_idx': int,
+                            'image_path': str,
+                            'compared_with': int  # Index of the compared chat
+                        }
+                    }, ...]
+        """
+        # Get chat data
+        chats = [item['chat'] for item in data]
+
+        with torch.autocast(device_type='cuda', dtype=torch.float16):
+            if output_mode == 'score':
+                # Ensure each chat has a corresponding image, use the last image if not enough
+                if len(image) < len(chats):
+                    image = image + [image[-1]] * (len(chats) - len(image))
+
+                # Get scores for each chat
+                scores = []
+                for chat, img in zip(chats, image):
+                    score = self.model.get_score(chat, [img], hd_num=self.cfg.hd_num)
+                    scores.append(score)
+
+                return [
+                    {
+                        'reward': float(score),
+                        'metadata': {
+                            'mode': 'score',
+                            'chat_idx': idx,
+                            'image_path': img
+                        }
+                    } for idx, (score, img) in enumerate(zip(scores, image))
+                ]
+
+            elif output_mode == 'rank':
+                # Use the first image for ranking
+                img = image[0]
+                ranks = self.model.rank(chats, [[img]] * len(chats), hd_num=self.cfg.hd_num)
+
+                return [
+                    {
+                        'rank': int(rank),
+                        'metadata': {
+                            'mode': 'rank',
+                            'chat_idx': idx,
+                            'image_path': img
+                        }
+                    } for idx, rank in enumerate(ranks)
+                ]
+
+            elif output_mode == 'compare':
+                if len(data) < 2:
+                    raise ValueError("Compare mode requires at least 2 samples")
+
+                # Use the first image for comparison
+                img = image[0]
+                is_better = self.model.compare(chats[0], [img], chats[1], [img], hd_num=self.cfg.hd_num)
+
+                return [
+                    {
+                        'reward': 1.0 if is_better else 0.0,
+                        'metadata': {
+                            'mode': 'compare',
+                            'chat_idx': 0,
+                            'image_path': img,
+                            'compared_with': 1
+                        }
+                    }, {
+                        'reward': 0.0 if is_better else 1.0,
+                        'metadata': {
+                            'mode': 'compare',
+                            'chat_idx': 1,
+                            'image_path': img,
+                            'compared_with': 0
+                        }
+                    }
+                ]
+            else:
+                raise ValueError(f"Invalid output mode: {output_mode}")
+
+    def train(self):
+        """Training is not implemented for this reward model"""
+        self.logger.warning("Training is not implemented for this reward model")
+        pass
+
+    def collect_data(self, data: list) -> None:
+        """Data collection is not needed for this reward model"""
+        pass
+
+    def clear_data(self) -> None:
+        """Data clearing is not needed for this reward model"""
+        pass
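The sketch below shows how the new class could be exercised end to end, based only on the signatures and docstring above. It is illustrative, not part of the commit: the chat contents, image path, and logger setup are hypothetical, and actually running it needs a CUDA device plus a local copy of the internlm/internlm-xcomposer2d5-7b-reward weights (the constructor loads the tokenizer with local_files_only=True).

    # Hypothetical usage example; paths and chat content are made up for illustration.
    import logging
    from easydict import EasyDict
    from torch.utils.tensorboard import SummaryWriter
    from ding.reward_model import MultiModalRewardModel

    cfg = EasyDict(MultiModalRewardModel.config)
    logger = logging.getLogger(__name__)
    tb_logger = SummaryWriter('./log/multi_modal_rm')
    rm = MultiModalRewardModel(cfg, 'cuda', logger, tb_logger)

    data = [
        {'chat': [
            {'role': 'user', 'content': 'What is shown in the image?'},
            {'role': 'assistant', 'content': 'A cat sitting on a windowsill.'},
        ]},
        {'chat': [
            {'role': 'user', 'content': 'What is shown in the image?'},
            {'role': 'assistant', 'content': 'A picture of the ocean at sunset.'},
        ]},
    ]
    images = ['./examples/cat.jpg']  # fewer images than chats: the last one is reused

    scores = rm.estimate(data, images, output_mode='score')        # one reward per chat
    ranking = rm.estimate(data, images, output_mode='rank')        # 0 is the best chat
    comparison = rm.estimate(data, images, output_mode='compare')  # first two chats only
    print(scores[0]['reward'], ranking[0]['rank'], comparison[0]['reward'])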

ding/reward_model/tests/test_math_reward_model.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@ def test_math_reward_model():
     # Initialize reward model
     model = MathRewardModel(cfg, "cuda" if torch.cuda.is_available() else "cpu", logger, tb_logger)

-    # Test case 1: Simple math problem
+    # Simple math problem
     data_simple = [
         {
             "system": "Please reason step by step...",
@@ -30,7 +30,7 @@ def test_math_reward_model():
         }
     ]

-    # Test case 2: Complex word problem
+    # Complex word problem
     data_complex = [
         {
             "system": "Please reason step by step, and put your final answer within \\boxed{}.",
