
Commit b18e03d

feat: Best-of-N Sampling with Process Reward Models (#118)
* Implements best-of-N Sampling with PRM support.
* Backends: adds specific classes for querying PRMs.
* Standard Library: adds a ScorerRequirement class which serves as the Requirement interface to PRMs.
* Provides a concrete implementation of these abstract interfaces using the HuggingFace inference engine and the `ibm-granite/granite-3.3-8b-lora-math-prm` PRM model.
1 parent e2746a1 commit b18e03d

File tree

11 files changed: +752 −8 lines

docs/examples/best_of_n/prm.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+"""Example of Using Best of N with PRMs"""
+
+from docs.examples.helper import w
+from mellea import start_session
+from mellea.backends.process_reward_models.huggingface.prms import (
+    HFGenerativePRM,
+    HFRegressionPRM,
+)
+from mellea.backends.types import ModelOption
+from mellea.stdlib.rewards.prm_scorer import PRMScorer
+from mellea.stdlib.sampling import BestofNSamplingStrategy
+
+# create a session for the generator using Granite 3.3 8B on Huggingface and a simple context [see below]
+m = start_session(backend_name="hf", model_options={ModelOption.MAX_NEW_TOKENS: 512})
+
+# initialize the PRM model
+prm_model = HFGenerativePRM(
+    model_name_or_path="ibm-granite/granite-3.3-8b-lora-math-prm",
+    score_token="Y",
+    generation_prompt="Is this response correct so far (Y/N)?",
+    step_separator="\n\n",
+)
+
+# can also initialize a Regression PRM model
+# prm_model = HFRegressionPRM(
+#     model_name_or_path="granite-3.3-8b-math-prm-regression",
+#     score_token="<end_of_step>",
+#     step_separator="\n\n")
+
+# create PRM scorer object
+prm = PRMScorer(prm_model=prm_model, preference_ordering="max")
+
+# Do Best of N sampling with the PRM scorer and an additional requirement
+BoN_prm = m.instruct(
+    "Sarah has 12 apples. She gives 5 of them to her friend. How many apples does Sarah have left?",
+    strategy=BestofNSamplingStrategy(loop_budget=3),
+    model_options={"temperature": 0.9, "do_sample": True},
+    requirements=["provide final answer like 'Final Answer:'", prm],
+)
+
+# print result
+print(f"***** BoN ****\n{w(BoN_prm)}\n*******")

mellea/backends/huggingface.py

Lines changed: 54 additions & 0 deletions
@@ -30,6 +30,7 @@
 from mellea.backends.cache import Cache, SimpleLRUCache
 from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
 from mellea.backends.model_ids import ModelIdentifier
+from mellea.backends.process_reward_models import PRM
 from mellea.backends.tools import (
     add_tools_from_context_actions,
     add_tools_from_model_options,
@@ -672,3 +673,56 @@ def __init__(
         self._generation_prompt_tokens = self._backend._tokenizer(
             self._generation_prompt, return_tensors="pt"
         ).to(self._backend._device)
+
+
+class HFProcessRewardModel(PRM, abc.ABC):
+    def __init__(
+        self, model_name_or_path: str, score_token: str, device: str | None = None
+    ):
+        """Initialize a PRM that works with a huggingface backend. Currently supported and tested with IBM Process Reward Models.
+
+        Args:
+            model_name_or_path (str): A local path to a PRM or a huggingface PRM.
+            score_token (str): Token whose logits correspond to the PRM score. Can be a step delimiter (for non-generative PRMs) or a correctness indicator (for generative PRMs).
+            device (str): The computational device to use ("cuda" for GPU, "mps" for Apple Silicon, or "cpu"); defaults to None. If not specified, the best available device is selected automatically.
+        """
+        super().__init__(model_name_or_path)
+
+        # auto-select a device if none was specified
+        self._device = device
+        if device is None:
+            device_name: str = (
+                "cuda"
+                if torch.cuda.is_available()
+                else "mps"
+                if torch.backends.mps.is_available()
+                else "cpu"
+            )
+            assert device_name is not None
+            self._device = torch.device(device_name)  # type: ignore
+
+        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
+            self.model_name_or_path, torch_dtype=torch.bfloat16
+        )
+        self.model.to(self._device)  # type: ignore
+        self.model.eval()
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
+
+        self._score_token = score_token
+        self._score_token_id = self.tokenizer.encode(
+            self._score_token, add_special_tokens=False
+        )[0]
+
+    def stepify(self, content: str, step_separator: str) -> list[str]:
+        """Splits the assistant response into steps to score.
+
+        Args:
+            content: assistant response to score
+            step_separator: string on which to separate the content into steps
+        """
+        # convert the assistant message into a list of non-empty steps
+        list_of_steps = [
+            step.strip() for step in content.split(step_separator) if step.strip() != ""
+        ]
+        return list_of_steps
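
Note: the concrete `HFGenerativePRM` and `HFRegressionPRM` classes used in the example (imported from `mellea.backends.process_reward_models.huggingface.prms`) are part of this commit but not shown in this excerpt. Purely as an illustration of how the pieces above fit together, a generative subclass of `HFProcessRewardModel` might score each step by the probability the model assigns to the score token; the prompt wording, hard-coded step separator, and mean aggregation below are assumptions for the sketch, not the repository's implementation.

# Hypothetical sketch only -- not the code added by this commit.
# It reuses the attributes set up in HFProcessRewardModel.__init__ above
# (self.model, self.tokenizer, self._device, self._score_token_id) and stepify().
import torch


class SketchGenerativePRM(HFProcessRewardModel):
    """Scores each reasoning step by the probability of the score token (e.g. "Y")."""

    def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]:
        step_scores: list[float] = []
        prefix = ""
        for step in self.stepify(response, "\n\n"):
            prefix = f"{prefix}\n\n{step}".strip()
            # Ask the PRM whether the partial response is correct so far (assumed prompt).
            prompt = f"{query}\n\n{prefix}\n\nIs this response correct so far (Y/N)?"
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self._device)
            with torch.no_grad():
                logits = self.model(**inputs).logits[0, -1]  # next-token logits
            probs = torch.softmax(logits, dim=-1)
            step_scores.append(probs[self._score_token_id].item())
        # Aggregate per-step scores into one final score (mean, as an assumption).
        final = sum(step_scores) / max(len(step_scores), 1)
        return [final], [step_scores]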

mellea/backends/litellm.py

Lines changed: 3 additions & 3 deletions
@@ -5,9 +5,9 @@
 from collections.abc import Callable
 from typing import Any

-import litellm
-import litellm.litellm_core_utils
-import litellm.litellm_core_utils.get_supported_openai_params
+import litellm  # type: ignore
+import litellm.litellm_core_utils  # type: ignore
+import litellm.litellm_core_utils.get_supported_openai_params  # type: ignore

 import mellea.backends.model_ids as model_ids
 from mellea.backends import BaseModelSubclass

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+"""Abstract interfaces for Backends that implement Process Reward Models (can be adapted to include other scorers)"""
+
+import abc
+
+
+class PRM(abc.ABC):
+    def __init__(self, model_name_or_path):
+        # Leave the implementation of the model to the inheriting class
+        self.model_name_or_path = model_name_or_path
+
+    @abc.abstractmethod
+    def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]:
+        """Returns a final score and per-step scores for the input to the model"""
+        ...
+
+    @abc.abstractmethod
+    def stepify(self, response: str, step_separator: str) -> list[str]:
+        """Splits the assistant response into steps to score
+
+        Args:
+            response: assistant response to score
+            step_separator: string on which to separate the response into steps
+        """
+        ...
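
For reference, the interface only obliges a subclass to provide `score` and `stepify`; how (or whether) a model is loaded is left entirely to the implementation. A minimal toy subclass, illustrative only and not part of this commit, could look like the following.

# Illustrative toy scorer -- not part of this commit.
class LengthPRM(PRM):
    """Rewards shorter steps; exists only to show the interface contract."""

    def stepify(self, response: str, step_separator: str) -> list[str]:
        return [s.strip() for s in response.split(step_separator) if s.strip()]

    def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]:
        steps = self.stepify(response, "\n\n")
        step_scores = [1.0 / (1.0 + len(s)) for s in steps]
        # Final score is the weakest step; per-step scores are returned alongside.
        return [min(step_scores, default=0.0)], [step_scores]
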
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+"""Process Reward Model Implementations with Huggingface backends"""
