feat: Best-of-N Sampling with Process Reward Models #118
Conversation
aashka-trivedi commented on Sep 3, 2025:
- Adds Best-of-N sampling methods. Currently supports only a single Requirement that acts as a scorer for Best-of-N (see the sketch after this list).
- Adds support for PRMs (generative and regression PRMs), with examples of how to use them with a BoN sampler.
- Adds an optional input to the session validate() function to allow for validations that need to take the model input into account (e.g., PRMs).
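To make the workflow concrete, here is a minimal, self-contained sketch of Best-of-N sampling with a PRM-style scorer. It is illustrative only: `best_of_n`, `generate`, and `score` are hypothetical stand-ins, not mellea's actual API; the point is that the scorer sees the model input as well as the output, which is why validate() gains an optional input argument.

```python
from typing import Callable

def best_of_n(
    prompt: str,
    generate: Callable[[str], str],       # produces one candidate completion
    score: Callable[[str, str], float],   # PRM-style scorer: (input, output) -> score
    n: int = 4,
) -> str:
    """Generate n candidates and return the one the scorer ranks highest."""
    candidates = [generate(prompt) for _ in range(n)]
    scores = [score(prompt, c) for c in candidates]
    best_index = max(range(n), key=lambda i: scores[i])
    return candidates[best_index]
```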
mellea/stdlib/rewards/prm.py (outdated):

            "<|im_end|>'"
        )
    else:
        asst_text = chat_template_to_turn
- Can you add a comment here explaining why this is the right thing to do?
- Given that granite and phi have special-purpose logic, should we at least be `logger.warn()`'ing here? (A possible shape for this is sketched below.)
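A hedged sketch of what that warning could look like. The helper name, the model-family checks, and the message are all hypothetical; only `chat_template_to_turn` and the granite/phi special-casing come from the diff above.

```python
import logging

logger = logging.getLogger(__name__)

def extract_assistant_text(model_id: str, chat_template_to_turn: str) -> str:
    """Hypothetical helper showing where the suggested warning could live."""
    if "granite" in model_id or "phi" in model_id:
        # Model-specific parsing of the chat template would go here (elided).
        return chat_template_to_turn
    # Generic fallback: no special-purpose handling for this model family,
    # so surface that fact instead of silently using the raw turn text.
    logger.warning(
        "No model-specific chat-template handling for %s; using the raw turn text.",
        model_id,
    )
    return chat_template_to_turn
```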
mellea/stdlib/rewards/prm.py (outdated):
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer


    class GenerativePRMForInference(torch.nn.Module):
        """
        Class for Generative Process Reward Models for Inference.

        Uses the Hugging Face backend to load the model (which is trained using LoRA adapters).
        """

        def __init__(
            self,
            model_path="ibm-granite/granite-3.3-8b-lora-math-prm",
One of our design goals is that everything in stdlib should be "targetable" to models/inference engines. I.e., we want a pretty clean separation between the abstractions in stdlib and the implementation details of how those abstractions get interpreted by a specific size of a specific version-set of a specific model family.
This means that we should avoid hard-coding anything in stdlib to either huggingface or a special model_id.
For example, we separate the implementation of constraint checking in backends from the Requirement interface in stdlib. Requirement checking using Granite 3.2 or 3.3 in huggingface or vllm uses the constraint-checking LoRA. For any other model family, or any other version-set/size of those models, or any other inference engine, we have a prompt-based method for doing the same.
In the case of requirement checking, there's an at-least-half-sensible approach to requirement actions on inference engines and models that don't support the constraint-checking LoRA. But it's okay not to have a "default fallback" in the way that requirement checking does. You can just throw NotImplementedError if the inference engine or model ID doesn't provide the appropriate functionality.
I do not have a solution ready for you in-hand, but here's one possibility to consider (a rough sketch follows the list):
- Create a `ProcessRewardModel` thing (probably a `Component`?) in `stdlib` that has an abstract interface which elides the exact implementation details at the level of models.
- You will need some specific functionality of a Backend in order to implement that Component's actual behavior. For this purpose, create a `PRMBackendMixin` class in `mellea.backends`.
- Add that Mixin to the Huggingface backend and implement the associated methods. Then move this model- and hf-specific code into that implementation of the mixin.
- If it's possible, model-specific stuff should additionally be factored out (perhaps into the Formatter?). Again, a mixin approach can be used if absolutely necessary.
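Roughly, the proposed split might look like the sketch below. The class names `ProcessRewardModel` and `PRMBackendMixin` come from this comment; the method names and signatures are assumptions for illustration only.

```python
from abc import ABC, abstractmethod

class ProcessRewardModel(ABC):
    """stdlib-level abstraction: what a PRM does, with no model/engine details."""

    @abstractmethod
    def score(self, model_input: str, model_output: str) -> list[float]:
        """Return per-step (or per-response) scores for a candidate output."""
        ...

class PRMBackendMixin(ABC):
    """Backend-level capability: how a concrete inference engine evaluates a PRM."""

    @abstractmethod
    def prm_score(
        self, prm: ProcessRewardModel, model_input: str, model_output: str
    ) -> list[float]:
        ...
```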
Meta: we're open to being wrong about how this sort of thing should be implemented in application-layer libraries. Is this sort of thing implemented in langchain, bee, dspy, etc? If so, we can look at their architectural choices and consider those as alternative options to the above proposal.
Yes, these goals make sense to me. However, in my experience PRM implementations are not very "standard", i.e., they really depend on the model training paradigm. But I do like the idea of creating a ProcessRewardModel Component and having classes that implement the forward pass to match the model specifics. Currently, I'm only implementing model classes that are compatible with the PRMs that IBM puts out, and we can leave other PRM implementations as a future contribution.
Another suggestion is to make a Scorer base class that is similar to a Requirement, except that instead of returning a bool indicating Requirement satisfaction, it returns a Score. This class could then call any backend/object that has a scorer function. This idea may need a much larger PR (and probably more thinking from a design perspective), but it would make BestOfNSampling cleaner when there are multiple requirements plus a scorer to choose the best. In the context of your suggestion, the ProcessRewardModel would be an instance of the Scorer class, which then calls the associated backend to get the responses.
I can implement the ProcessRewardModel as its own component for the purpose of this PR, and I'll leave it up to y'all to decide what to do about these types of "requirements" :)
Summary of sync convo:
We will refactor as follows:
- Add a `Scorer :> Requirement` component
- Add the `PRMBackendMixin` and implement it for Huggingface
- Add the PRM itself that uses this stuff
- Test :)
stdlib
- We should add a `Scorer`, either as a mixin or as a subclass of `Requirement`.
- The `Scorer`'s `score` method must return a non-None value.
- The `Scorer` must also define a preference ordering (is the goal to minimize the score or maximize the score? If the scores are categorical, what's the preference ordering?).

Adding this to `mellea.stdlib.requirements` is justified because it allows implementers of SamplingStrategies to differentiate between "normal" requirements and "ranking" requirements. A rough sketch of such a `Scorer` follows.
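The sketch below only illustrates the interface agreed on above; the `Preference` enum, the attribute name, and the `score` signature are assumptions, and in mellea the class would subclass (or mix into) `Requirement` rather than a bare ABC.

```python
from abc import ABC, abstractmethod
from enum import Enum

class Preference(Enum):
    MAXIMIZE = "maximize"
    MINIMIZE = "minimize"

class Scorer(ABC):  # in mellea, this would relate to Requirement instead of plain ABC
    """A Requirement-like component whose validation also yields an orderable score."""

    # Preference ordering: whether higher or lower scores are considered better.
    preference: Preference = Preference.MAXIMIZE

    @abstractmethod
    def score(self, model_input: str, model_output: str) -> float:
        """Return the score for a candidate output; must never return None."""
        ...
```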
Backend
Introduce a new Mixin for PRM stuff:

    class PRMBackendMixin:
        def prm_act(self, action: PRMBlah):
            ...

...which then gets added to each backend that supports PRM:

    class LocalHFBackend(FormatterBackend, AloraBackendMixin, PRMBackendMixin):
        ...

We now need a way for `mellea.stdlib.requirements.PRMScorer` to use the `PRMBackendMixin`. That can be done in `Scorer.validate`: when `PRMScorer.validate(..., b: Backend, ...)` gets called, we can check that `b` has the `PRMBackendMixin` and, if not, raise a `NotImplementedError` (see the sketch below).
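A minimal sketch of that validate-time capability check, using the `Scorer` and `PRMBackendMixin` sketched earlier in this thread. Method names and the payload passed to `prm_act` are assumptions, not the final API.

```python
class PRMScorer(Scorer):  # Scorer as in the earlier sketch; everything here is illustrative
    """A Scorer backed by a process reward model."""

    def score(self, model_input: str, model_output: str) -> float:
        # In this sketch, real scoring happens in validate() via the backend.
        raise NotImplementedError

    def validate(self, b, model_input: str, model_output: str):
        # Capability check: only backends that mix in PRMBackendMixin can run a PRM.
        if not isinstance(b, PRMBackendMixin):
            raise NotImplementedError(
                f"{type(b).__name__} does not implement PRMBackendMixin; "
                "a PRM-capable backend is required for this scorer."
            )
        return b.prm_act((model_input, model_output))
```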
Summary of second sync convo:
- We will add a `Scorer :> Requirement` as discussed above.
- Separate the PRM model from the backend entirely: add it as a separate component.
- Implement PRM model classes for Huggingface/vLLM etc. (in stdlib).
- The `PRMScorer` gets passed this PRM object (see the wiring sketch below).
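For concreteness, the agreed-upon wiring might look roughly like this. The `GenerativePRMForInference` class and the model path come from the diff above; the `PRMScorer` constructor signature is an assumption, not the final API.

```python
# Illustrative only: the PRM model is its own stdlib component with
# engine-specific implementations, and the scorer just holds a reference to it.
prm = GenerativePRMForInference(
    model_path="ibm-granite/granite-3.3-8b-lora-math-prm"  # Hugging Face-backed PRM
)
scorer = PRMScorer(prm)  # the scorer is handed the PRM object, not a backend
```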
Thanks for the contribution! The pre-commit checks found some type errors; could you please fix these prior to a code review? The fixes are probably small. FYI: you can run these checks locally by installing the pre-commit hooks, assuming you have already created a venv and installed Mellea as an editable package. Once installed, you can still commit past failing hooks if you need to.
#104 seems related (you can replace "aLoRA" with "LoRA" in the description on that issue -- same idea).
Hi @nrfulton, this branch is ready for review. As discussed, this PR implements the following: