
Commit 424de73

Best of N Sampling with PRM support
1 parent d0089d1 commit 424de73

File tree: 5 files changed, +613 -0 lines


docs/examples/best_of_n/prm.py

Lines changed: 29 additions & 0 deletions
"""Example of Using Best of N with PRMs"""

from docs.examples.helper import w
from mellea import start_session
from mellea.backends.types import ModelOption
from mellea.stdlib.rewards.prm_scorer import PRMScorer
from mellea.stdlib.sampling import BestofNSamplingStrategy

# create a session using Granite 3.3 8B on Huggingface and a simple context [see below]
m = start_session(backend_name="hf", model_options={ModelOption.MAX_NEW_TOKENS: 1024})

# create PRM scorer object
prm = PRMScorer(
    model_version="ibm-granite/granite-3.3-8b-lora-math-prm",
    prm_type="generative",
    correct_token="Y",
    generation_prompt="Is this response correct so far (Y/N)?",
    step_splitter="\n\n",
)

# Do Best of N sampling with the PRM scorer
BoN_prm = m.instruct(
    "Sarah has 12 apples. She gives 5 of them to her friend. How many apples does Sarah have left?",
    strategy=BestofNSamplingStrategy(loop_budget=3, requirements=[prm]),
    model_options={"temperature": 0.9, "do_sample": True},
)

# print result
print(f"***** BoN ****\n{w(BoN_prm)}\n*******")
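
For context, BestofNSamplingStrategy (imported from mellea.stdlib.sampling; its implementation is not part of this diff) drives a generate-then-rank loop: produce loop_budget candidate completions, score each one with the supplied requirements (here the PRM scorer), and keep the best. The following is a generic, self-contained sketch of that pattern with placeholder generate/score callables, not the library's actual code:

import random

def best_of_n(generate, score, prompt, n=3):
    """Draw n candidate responses and keep the one with the highest score."""
    candidates = [generate(prompt) for _ in range(n)]
    scores = [score(prompt, candidate) for candidate in candidates]
    best_idx = max(range(n), key=lambda i: scores[i])
    return candidates[best_idx], scores[best_idx]

# toy stand-ins so the sketch runs on its own; in the example above the candidates
# come from the session's backend and the scores from the PRM
best, best_score = best_of_n(
    generate=lambda p: f"candidate answer {random.random():.3f}",
    score=lambda p, c: random.random(),
    prompt="Sarah has 12 apples...",
)
print(best, best_score)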

mellea/stdlib/rewards/__init__.py

Whitespace-only changes.

mellea/stdlib/rewards/prm.py

Lines changed: 293 additions & 0 deletions
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class GenerativePRMForInference(torch.nn.Module):
    """
    Class for generative Process Reward Models for inference.
    Uses the Hugging Face backend to load the model (which is trained using LoRA adapters).
    """

    def __init__(
        self,
        model_path="ibm-granite/granite-3.3-8b-lora-math-prm",
        correct_token="Y",
        generation_prompt="Is this response correct so far (Y/N)?",
        load_in_bf16=True,
        device=None,
    ) -> None:
        super().__init__()

        if not load_in_bf16:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path, device_map="auto"
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path, torch_dtype=torch.bfloat16, device_map="auto"
            )

        if device is not None:
            self.model.to(device)
        self.device = self.model.device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # truncate from the left instead of the right (the default): we always want to keep
        # the last step and the last generation prompt in the context
        self.tokenizer.truncation_side = "left"
        self.correct_token = correct_token
        self.correct_token_id = self.tokenizer.encode(
            self.correct_token, add_special_tokens=False
        )[0]
        self.generation_prompt = generation_prompt
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, raw_inputs):
        """
        Expects a raw batch of (questions: List[str], steps: List[List[str]]).
        Returns the aggregated score for each problem (i.e., the average of the per-step
        scores), along with the per-step scores.
        """

        # get the un-tokenized batch; each element of the batch is a chat with one message
        # per step for a single input
        batches = self.prepare_batch(raw_inputs)
        all_rewards = []
        all_rewards_per_step = []

        chat_template_to_turn = self.tokenizer.apply_chat_template(
            [{"role": "assistant", "content": self.correct_token}],
            tokenize=False,
            add_generation_prompt=False,
        )
        if "system" in chat_template_to_turn:
            if "granite" in self.model.config.model_type.lower():
                # for granite, apply_chat_template also adds the system prompt
                asst_text = (
                    "<|start_of_role|>assistant<|end_of_role|>"
                    + self.correct_token
                    + "<|end_of_text|>"
                )
            elif "phi" in self.model.config.model_type.lower():
                # phi reasoning also applies the system prompt
                asst_text = (
                    "<|im_start|>assistant<|im_sep|>"
                    + self.correct_token
                    + "<|im_end|>"
                )
        else:
            asst_text = chat_template_to_turn
        asst_toks = self.tokenizer(
            asst_text, add_special_tokens=False, return_tensors="pt"
        )["input_ids"][0]
        asst_toks_before_correct_token = asst_toks[
            : torch.where(asst_toks == self.correct_token_id)[0].item()
        ].tolist()

        # move every tensor of the tokenized batch to the model device
        for i in batches:
            batches[i] = batches[i].to(self.model.device)

        with torch.no_grad():
            model_outputs = self.model(**batches)
        logits = model_outputs.logits  # (bsz, seq_len, vocab_size)

        for batch_idx in range(logits.shape[0]):
            per_input_rewards = []
            # for each element in the batch (i.e., each input), we need the logits at every
            # position where the token is "Y" inside an assistant turn, not just any
            # occurrence of correct_token_id
            correct_token_indices = torch.where(
                batches["input_ids"][batch_idx] == self.correct_token_id
            )[0].tolist()
            prm_indices = []
            for t_idx in correct_token_indices:
                if (
                    batches["input_ids"][batch_idx][
                        t_idx - len(asst_toks_before_correct_token) : t_idx
                    ].tolist()
                    == asst_toks_before_correct_token
                ):
                    # the logits for token i predict token i+1, so we need to look at the
                    # PREVIOUS token's logits
                    prm_indices.append(t_idx - 1)

            assert len(prm_indices) > 0
            # convert logits to probabilities and use the probability of the correct token
            # id as the reward
            for prm_idx in prm_indices:
                per_input_rewards.append(
                    self.softmax(logits[batch_idx, prm_idx, :])[
                        self.correct_token_id
                    ].item()
                )

            # aggregate and collect the final rewards
            all_rewards_per_step.append(per_input_rewards)
            per_input_reward = sum(per_input_rewards) / len(per_input_rewards)
            all_rewards.append(per_input_reward)

        return all_rewards, all_rewards_per_step

    def prepare_batch(self, raw_batch):
        """
        Expects a raw batch of (questions: List[str], steps: List[List[str]]).
        Builds one chat per input: the question and the first step form the first user turn,
        each later step becomes its own user turn, every user turn ends with the generation
        prompt, and an assistant turn containing the correct token follows every step.
        The chats are then tokenized into a padded batch for the forward pass.
        """

        questions, list_of_steps = raw_batch
        assert len(questions) == len(list_of_steps)

        inputs = []
        for i in range(len(questions)):
            user_content = questions[i]
            steps = list_of_steps[i]
            msgs = []
            for s_idx, step in enumerate(steps):
                # apply the chat template as expected by the RM input
                if s_idx == 0:
                    msgs.append(
                        {
                            "role": "user",
                            "content": user_content
                            + " "
                            + step
                            + " "
                            + self.generation_prompt,
                        }
                    )
                else:
                    # first add the previous assistant turn
                    msgs.append({"role": "assistant", "content": self.correct_token})
                    msgs.append(
                        {"role": "user", "content": step + " " + self.generation_prompt}
                    )

            # append the last assistant turn
            msgs.append({"role": "assistant", "content": self.correct_token})

            input_message = self.tokenizer.apply_chat_template(
                msgs, add_generation_prompt=False, tokenize=False
            )

            inputs.append(input_message)

        return self.tokenizer(
            inputs, return_tensors="pt", padding=True, truncation=True
        )

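Both PRM classes share the calling convention described in the forward() docstrings: a tuple of parallel lists, one question per candidate response and one list of pre-split reasoning steps per response. A minimal usage sketch for the generative PRM, not part of this commit (the question text and step splits are illustrative, and the default checkpoint requires access to the ibm-granite/granite-3.3-8b-lora-math-prm weights):

from mellea.stdlib.rewards.prm import GenerativePRMForInference

# load the generative PRM with its default checkpoint and prompts
prm = GenerativePRMForInference()

question = (
    "Sarah has 12 apples. She gives 5 of them to her friend. "
    "How many apples does Sarah have left?"
)
# one entry per candidate response; each response is already split into steps
questions = [question, question]
steps = [
    ["Sarah starts with 12 apples.", "She gives away 5, so 12 - 5 = 7.", "The answer is 7."],
    ["Adding the apples gives 12 + 5 = 17.", "The answer is 17."],
]

rewards, rewards_per_step = prm.forward((questions, steps))
# rewards[i] is the mean probability of the correct token ("Y") over the steps of
# response i; rewards_per_step[i] holds the individual per-step probabilities
print(rewards)
print(rewards_per_step)
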
class RegressionPRMForInference(torch.nn.Module):
    """
    Class for regression (non-generative) Process Reward Models for inference.
    Uses the Hugging Face backend to load the model.
    All regression Process Reward Models trained by the GMA team at IBM Research use a
    special step token, <end_of_step>.
    """

    def __init__(
        self,
        model_path: str,
        step_eos: str = "<end_of_step>",
        load_in_bf16: bool = True,
        device=None,
    ) -> None:
        super().__init__()

        # Load the model
        self.model: AutoModelForCausalLM
        if not load_in_bf16:
            self.model = AutoModelForCausalLM.from_pretrained(  # type: ignore
                model_path, device_map="auto"
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(  # type: ignore
                model_path, torch_dtype=torch.bfloat16, device_map="auto"
            )
        self.device = self.model.device
        self.config = self.model.config

        # get the token ID of the step separator token
        self.step_eos = step_eos
        self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_path)
        # self.tokenizer.add_tokens(self.step_eos)
        self.step_eos_id = self.tokenizer.encode(
            self.step_eos, add_special_tokens=False
        )[0]

        # load the PRM head
        self.prm_head = torch.nn.Linear(
            self.model.config.hidden_size, 2, bias=False, dtype=self.model.dtype
        ).to(self.model.device)
        state = torch.load(model_path + "/added_params.bin")
        self.load_state_dict(state, strict=False)
        self.model.eval()

        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, raw_batch):
        """
        Expects a raw batch of (questions: List[str], steps: List[List[str]]).
        Returns the aggregated score for each problem (i.e., the average of the per-step
        scores), along with the per-step scores.
        """

        # tokenize the batch, concatenating each list of steps into a single
        # step-separated response
        batch = self.prepare_batch(raw_batch).to(self.device)

        with torch.no_grad():
            model_outputs = self.model(**batch, output_hidden_states=True)
            # PRM head logits for every position
            all_prm_logits = self.prm_head(model_outputs["hidden_states"][-1]).squeeze(-1)

        # get the logits at each end of step, i.e., at the step_eos positions in the input
        prm_probs = []
        rewards = []
        for idx in range(all_prm_logits.shape[0]):
            prm_indices = torch.where(batch["input_ids"][idx] == self.step_eos_id)[0]
            if prm_indices.shape[0] == 0:
                # no match found: the model did not produce output in the correct
                # step-wise format
                prm_probs.append([None])
                reward = None
            else:
                # the head produces two logits; the second one is the logit for the correct
                # answer. Convert the logits to probabilities with a softmax and return a
                # list of floats instead of a list of tensors.
                prm_probs_per_sample = [
                    t.item()
                    for t in self.softmax(all_prm_logits[idx][prm_indices])[:, 1]
                ]
                prm_probs.append(prm_probs_per_sample)

                reward = sum(prm_probs_per_sample) / len(prm_probs_per_sample)
            rewards.append(reward)

        return rewards, prm_probs

    def prepare_batch(self, raw_batch):
        """
        Tokenize and prepare the batch for the forward pass.
        Expects a raw batch of (questions, list_of_steps). The list of steps is joined with
        the step_eos token.
        """

        questions, list_of_steps = raw_batch
        assert len(questions) == len(list_of_steps)

        inputs = []
        for i in range(len(questions)):
            text_with_steps_marked = ""

            for step in list_of_steps[i]:
                text_with_steps_marked += f"{step} {self.step_eos}"

            message = [
                {"role": "user", "content": questions[i]},
                {"role": "assistant", "content": text_with_steps_marked},
            ]
            input_message = self.tokenizer.apply_chat_template(message, tokenize=False)
            inputs.append(input_message)

        # tokenize the data for the RM
        batch = self.tokenizer(
            inputs, return_tensors="pt", padding=True, truncation=True
        )
        return batch
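
The regression variant is used the same way, but its checkpoint must be a local directory that also contains the added_params.bin file with the PRM head weights loaded in __init__. A hedged sketch (the path below is a placeholder, not a checkpoint shipped with this commit):

from mellea.stdlib.rewards.prm import RegressionPRMForInference

# placeholder path: a local checkpoint directory that includes added_params.bin
prm = RegressionPRMForInference(model_path="/path/to/regression-prm-checkpoint")

questions = ["Sarah has 12 apples. She gives 5 of them to her friend. How many apples are left?"]
steps = [["Start with 12 apples.", "12 - 5 = 7.", "The answer is 7."]]

rewards, per_step_probs = prm.forward((questions, steps))
# prepare_batch() joins the steps with the <end_of_step> marker, so the PRM head is read
# out at each step boundary; rewards[0] is the mean of the per-step probabilities
print(rewards[0], per_step_probs[0])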
