
Commit ea83f44

root committed
(dcy) add multimodal_rewardmodel
1 parent 7314bff commit ea83f44

7 files changed: +340, -57 lines


ding/reward_model/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@
 # LLM/VLM reward model and verifier
 from .math_reward_model import MathRewardModel
 from .math_rule_reward_model import MathRuleRewardModel
+from .multi_modal_reward_model import MultiModalRewardModel

ding/reward_model/math_reward_model.py

Lines changed: 2 additions & 6 deletions
@@ -11,9 +11,9 @@
 @REWARD_MODEL_REGISTRY.register('math')
 class MathRewardModel(BaseRewardModel):
     config = dict(
-        # (str) The type of the reward model.
+        # The type of the reward model.
         type='math',
-        # (str) The name of the tokenizer and model
+        # The name of the tokenizer and model
         model_name='Qwen/Qwen2.5-Math-PRM-7B',
     )

@@ -23,7 +23,6 @@ def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWri
         self.logger = logger
         self.tb_logger = tb_logger

-        # Initialize the tokenizer and model
         self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name, trust_remote_code=True)
         self.model = AutoModel.from_pretrained(
             self.cfg.model_name, device_map=self.device, torch_dtype=torch.bfloat16, trust_remote_code=True
@@ -99,16 +98,13 @@ def estimate(self, data: List[Dict]) -> List[Dict]:
             for messages in all_messages
         ]

-        # Batch-encode the inputs
         input_ids = self.tokenizer(
             conversation_strs, return_tensors="pt", padding=True, truncation=True
         )["input_ids"].to(self.model.device)

-        # Get the model outputs in a single batch
         with torch.no_grad():
             outputs = self.model(input_ids=input_ids)

-        # Compute step rewards for each sample
         step_sep_id = self.tokenizer.encode("<extra_0>")[0]
         token_masks = (input_ids == step_sep_id)
         batch_rewards = self.make_step_rewards(outputs[0], token_masks)
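For context, `make_step_rewards` (defined elsewhere in math_reward_model.py and untouched by this commit) turns the model outputs at the `<extra_0>` separator positions into per-step scores. A minimal sketch of how such a helper typically works for Qwen2.5-Math-PRM-style models, assuming a two-class (incorrect/correct) head; the repo's actual implementation may differ:

import torch
import torch.nn.functional as F

def make_step_rewards(logits: torch.Tensor, token_masks: torch.Tensor) -> list:
    # logits: (batch, seq_len, 2) classification scores; token_masks is True at <extra_0> positions
    probabilities = F.softmax(logits, dim=-1)
    probabilities = probabilities * token_masks.unsqueeze(-1)  # zero out non-separator tokens
    all_scores = []
    for sample in probabilities:  # (seq_len, 2)
        # keep only the separator positions and take the probability of "step is correct"
        positive_probs = sample[sample != 0].view(-1, 2)[:, 1]
        all_scores.append(positive_probs.cpu().tolist())
    return all_scores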

ding/reward_model/math_rule_reward_model.py

Lines changed: 0 additions & 1 deletion
@@ -112,7 +112,6 @@ def _process_response_answer(self, response: str) -> Tuple[Optional[float], Opti
                 if self.logger:
                     self.logger.debug(f"Error processing expression '{expr}': {e}")

-        # If all attempts fail, return None
         return None, None

     def _check_answer_match(self, pred: Optional[float], target: Optional[float]) -> bool:
ding/reward_model/multi_modal_reward_model.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
+from typing import List, Dict
+from easydict import EasyDict
+from torch.utils.tensorboard import SummaryWriter
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+from ding.utils import REWARD_MODEL_REGISTRY
+from .base_reward_model import BaseRewardModel
+
+
+@REWARD_MODEL_REGISTRY.register('multi_modal')
+class MultiModalRewardModel(BaseRewardModel):
+    config = dict(
+        type='multi_modal',
+        model_name='internlm/internlm-xcomposer2d5-7b-reward',
+        hd_num=9,  # Number of high-definition patches for image processing
+    )
+
+    def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWriter') -> None:
+        self.cfg = config
+        self.device = device
+        self.logger = logger
+        self.tb_logger = tb_logger
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.cfg.model_name, trust_remote_code=True, local_files_only=True
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.cfg.model_name, torch_dtype=torch.float16, trust_remote_code=True
+        )
+
+        self.model.tokenizer = self.tokenizer
+        self.model.cuda().eval()
+
+    def estimate(self, data: List[Dict], image: List[str], output_mode: str = 'score') -> List[Dict]:
+        """
+        Estimate rewards for multi-modal inputs using the internlm-xcomposer reward model.
+
+        Arguments:
+            data (List[Dict]): List of chat dictionaries, each containing:
+                - chat (List[Dict]): List of messages, each message is a dict with:
+                    - role (str): Either "user" or "assistant"
+                    - content (str): The message content
+            image (List[str]): List of image paths. If fewer images than chats are given, the last image is reused.
+            output_mode (str, optional): Evaluation mode. Defaults to 'score'.
+                - 'score': Return reward scores for each chat
+                - 'rank': Return ranking indices (0 is best) for all chats
+                - 'compare': Compare the first two chats (returns 1.0 for the better one, 0.0 for the worse one)
+
+        Returns:
+            List[Dict]: Results depending on output_mode:
+                - For 'score' mode:
+                    [{
+                        'reward': float,  # Reward score
+                        'metadata': {
+                            'mode': 'score',
+                            'chat_idx': int,  # Index of the chat
+                            'image_path': str  # Path of the image used
+                        }
+                    }, ...]
+                - For 'rank' mode:
+                    [{
+                        'rank': int,  # Ranking position (0 is best)
+                        'metadata': {
+                            'mode': 'rank',
+                            'chat_idx': int,
+                            'image_path': str
+                        }
+                    }, ...]
+                - For 'compare' mode:
+                    [{
+                        'reward': float,  # 1.0 for better, 0.0 for worse
+                        'metadata': {
+                            'mode': 'compare',
+                            'chat_idx': int,
+                            'image_path': str,
+                            'compared_with': int  # Index of the compared chat
+                        }
+                    }, ...]
+        """
+        # Get chat data
+        chats = [item['chat'] for item in data]
+
+        with torch.autocast(device_type='cuda', dtype=torch.float16):
+            if output_mode == 'score':
+                # Ensure each chat has a corresponding image; reuse the last image if there are not enough
+                if len(image) < len(chats):
+                    image = image + [image[-1]] * (len(chats) - len(image))
+
+                # Get scores for each chat
+                scores = []
+                for chat, img in zip(chats, image):
+                    score = self.model.get_score(chat, [img], hd_num=self.cfg.hd_num)
+                    scores.append(score)
+
+                return [
+                    {
+                        'reward': float(score),
+                        'metadata': {
+                            'mode': 'score',
+                            'chat_idx': idx,
+                            'image_path': img
+                        }
+                    } for idx, (score, img) in enumerate(zip(scores, image))
+                ]
+
+            elif output_mode == 'rank':
+                # Use the first image for ranking
+                img = image[0]
+                ranks = self.model.rank(chats, [[img]] * len(chats), hd_num=self.cfg.hd_num)
+
+                return [
+                    {
+                        'rank': int(rank),
+                        'metadata': {
+                            'mode': 'rank',
+                            'chat_idx': idx,
+                            'image_path': img
+                        }
+                    } for idx, rank in enumerate(ranks)
+                ]
+
+            elif output_mode == 'compare':
+                if len(data) < 2:
+                    raise ValueError("Compare mode requires at least 2 samples")
+
+                # Use the first image for comparison
+                img = image[0]
+                is_better = self.model.compare(chats[0], [img], chats[1], [img], hd_num=self.cfg.hd_num)
+
+                return [
+                    {
+                        'reward': 1.0 if is_better else 0.0,
+                        'metadata': {
+                            'mode': 'compare',
+                            'chat_idx': 0,
+                            'image_path': img,
+                            'compared_with': 1
+                        }
+                    }, {
+                        'reward': 0.0 if is_better else 1.0,
+                        'metadata': {
+                            'mode': 'compare',
+                            'chat_idx': 1,
+                            'image_path': img,
+                            'compared_with': 0
+                        }
+                    }
+                ]
+            else:
+                raise ValueError(f"Invalid output mode: {output_mode}")
+
+    def train(self):
+        """Training is not implemented for this reward model."""
+        self.logger.warning("Training is not implemented for this reward model")
+        pass
+
+    def collect_data(self, data: list) -> None:
+        """Data collection is not needed for this reward model."""
+        pass
+
+    def clear_data(self) -> None:
+        """Data clearing is not needed for this reward model."""
+        pass
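For reference, a minimal usage sketch of the class added above, based on its docstring; the image path, chat contents, and logger name are hypothetical placeholders, and loading the model requires a GPU since the class calls `.cuda()`:

import logging
from easydict import EasyDict
from ding.reward_model import MultiModalRewardModel

cfg = EasyDict(MultiModalRewardModel.config)
model = MultiModalRewardModel(cfg, device="cuda", logger=logging.getLogger("mm_rm"), tb_logger=None)

data = [
    {"chat": [{"role": "user", "content": "Describe the image."},
              {"role": "assistant", "content": "A cat is sitting on a red sofa."}]},
    {"chat": [{"role": "user", "content": "Describe the image."},
              {"role": "assistant", "content": "The picture shows an empty street."}]},
]
# 'score' returns one reward per chat; 'rank' and 'compare' follow the metadata formats documented above.
results = model.estimate(data, image=["./example.jpg"], output_mode="score")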

ding/reward_model/tests/test_math_reward_model.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@ def test_math_reward_model():
     # Initialize reward model
     model = MathRewardModel(cfg, "cuda" if torch.cuda.is_available() else "cpu", logger, tb_logger)

-    # Test case 1: Simple math problem
+    # Simple math problem
     data_simple = [
         {
             "system": "Please reason step by step...",
@@ -30,7 +30,7 @@ def test_math_reward_model():
         }
     ]

-    # Test case 2: Complex word problem
+    # Complex word problem
     data_complex = [
         {
             "system": "Please reason step by step, and put your final answer within \\boxed{}.",

ding/reward_model/tests/test_math_rule_reward_model.py

Lines changed: 53 additions & 48 deletions
@@ -4,6 +4,7 @@
 from easydict import EasyDict
 from ding.reward_model.math_rule_reward_model import MathRuleRewardModel

+
 @pytest.fixture
 def reward_model():
     return MathRuleRewardModel(
@@ -19,24 +20,26 @@ def reward_model():

 @pytest.mark.envtest
 def test_math_rule_reward_model_correct_answer(reward_model):
-    data_correct = [{
-        "system": "Please answer this math problem...",
-        "query": (
-            "The school now introduces a new color, silver, for the flag design. "
-            "Crestview's school colors are now purple, gold, and silver. "
-            "The students are designing a flag using three solid-colored horizontal stripes. "
-            "Using one, two, or all three of the school colors, how many different flags "
-            "are possible if adjacent stripes may be the same color?"
-        ),
-        "response": (
-            "Crestview's school colors—purple, gold, and silver—can be used to design "
-            "a flag with three horizontal stripes, where each stripe can be any of the "
-            "three colors and adjacent stripes may be the same. Since each of the three "
-            "stripes has three independent color choices, the total number of possible "
-            "flag designs is 27"
-        ),
-        "answer": r"27"
-    }]
+    data_correct = [
+        {
+            "system": "Please answer this math problem...",
+            "query": (
+                "The school now introduces a new color, silver, for the flag design. "
+                "Crestview's school colors are now purple, gold, and silver. "
+                "The students are designing a flag using three solid-colored horizontal stripes. "
+                "Using one, two, or all three of the school colors, how many different flags "
+                "are possible if adjacent stripes may be the same color?"
+            ),
+            "response": (
+                "Crestview's school colors—purple, gold, and silver—can be used to design "
+                "a flag with three horizontal stripes, where each stripe can be any of the "
+                "three colors and adjacent stripes may be the same. Since each of the three "
+                "stripes has three independent color choices, the total number of possible "
+                "flag designs is 27"
+            ),
+            "answer": r"27"
+        }
+    ]

     # Test the case with correct answer
     rewards = reward_model.estimate(data_correct)
@@ -48,26 +51,28 @@ def test_math_rule_reward_model_correct_answer(reward_model):

 @pytest.mark.envtest
 def test_math_rule_reward_model_wrong_answer(reward_model):
-    data_wrong = [{
-        "system": "Please answer this math problem...",
-        "query": (
-            "The school now introduces a new color, silver, for the flag design. "
-            "Crestview's school colors are now purple, gold, and silver. "
-            "The students are designing a flag using three solid-colored horizontal stripes. "
-            "Using one, two, or all three of the school colors, how many different flags "
-            "are possible if adjacent stripes may be the same color?"
-        ),
-        "response": (
-            r"The given point \(\left(\frac{\sqrt{3}}{2}, -\frac{1}{2}\right)\) lies on "
-            r"the unit circle, meaning its coordinates correspond to \((\cos \alpha, "
-            r"\sin \alpha)\). Since \(\cos \alpha = \frac{\sqrt{3}}{2}\) and "
-            r"\(\sin \alpha = -\frac{1}{2}\), the angle \(\alpha\) is in the "
-            r"**fourth quadrant**, where the reference angle is \(\frac{\pi}{6}\). "
-            r"Therefore, the smallest positive value of \(\alpha\) is "
-            r"\(2\pi - \frac{\pi}{6} = \frac{17\pi}{6}\)."
-        ),
-        "answer": r"\frac{11\pi}{6}"
-    }]
+    data_wrong = [
+        {
+            "system": "Please answer this math problem...",
+            "query": (
+                "The school now introduces a new color, silver, for the flag design. "
+                "Crestview's school colors are now purple, gold, and silver. "
+                "The students are designing a flag using three solid-colored horizontal stripes. "
+                "Using one, two, or all three of the school colors, how many different flags "
+                "are possible if adjacent stripes may be the same color?"
+            ),
+            "response": (
+                r"The given point \(\left(\frac{\sqrt{3}}{2}, -\frac{1}{2}\right)\) lies on "
+                r"the unit circle, meaning its coordinates correspond to \((\cos \alpha, "
+                r"\sin \alpha)\). Since \(\cos \alpha = \frac{\sqrt{3}}{2}\) and "
+                r"\(\sin \alpha = -\frac{1}{2}\), the angle \(\alpha\) is in the "
+                r"**fourth quadrant**, where the reference angle is \(\frac{\pi}{6}\). "
+                r"Therefore, the smallest positive value of \(\alpha\) is "
+                r"\(2\pi - \frac{\pi}{6} = \frac{17\pi}{6}\)."
+            ),
+            "answer": r"\frac{11\pi}{6}"
+        }
+    ]

     # Test the case with wrong answer
     rewards = reward_model.estimate(data_wrong)
@@ -79,12 +84,14 @@ def test_math_rule_reward_model_wrong_answer(reward_model):

 @pytest.mark.envtest
 def test_math_rule_reward_model_format_error(reward_model):
-    data_format_error = [{
-        "system": "Please answer this math problem...",
-        "query": "What is 2+2?",
-        "response": "The answer is four.",
-        "answer": r"4"
-    }]
+    data_format_error = [
+        {
+            "system": "Please answer this math problem...",
+            "query": "What is 2+2?",
+            "response": "The answer is four.",
+            "answer": r"4"
+        }
+    ]
     rewards_format = reward_model.estimate(data_format_error)
     assert len(rewards_format) == len(data_format_error)
     # This should be a format error because "four" cannot be processed as a numerical value
@@ -99,13 +106,11 @@ def test_math_rule_reward_model_special_expressions(reward_model):
            "query": "What is 1/2?",
            "response": r"The answer is \frac{1}{2}.",
            "answer": r"0.5"
-        },
-        {
+        }, {
            "query": "What is 50%?",
            "response": "The answer is 50%.",
            "answer": r"0.5"
-        },
-        {
+        }, {
            "query": "What is sqrt(4)?",
            "response": r"The answer is \sqrt{4} = 2.",
            "answer": r"2"
