Conversation

@aashka-trivedi
Contributor

  • Adds Best-of-N sampling methods. Currently supports only a single Requirement that acts as a scorer for Best-of-N
  • Adds support for PRMs (Generative and Regression PRMs), with examples of how to use them with a BoN sampler
  • Adds an optional input to the session validate() function to allow for validations that need to take the model input into account (e.g., PRMs)

@mergify

mergify bot commented Sep 3, 2025

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🟢 Enforce conventional commit

Wonderful, this rule succeeded.

Make sure that we follow https://www.conventionalcommits.org/en/v1.0.0/

  • title ~= ^(fix|feat|docs|style|refactor|perf|test|build|ci|chore|revert)(?:\(.+\))?:

@nrfulton nrfulton changed the title Best-of-N Sampling with Process Reward Models [feat] Best-of-N Sampling with Process Reward Models Sep 4, 2025
+ "<|im_end|>'"
)
else:
asst_text = chat_template_to_turn
Contributor

  1. Can you add a comment here explaining why this is the right thing to do?
  2. Given that granite and phi have special-purpose logic, should we be at least logger.warning()'ing here?
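A minimal sketch of the kind of fallback warning being suggested. The helper name and the assumption that a model identifier is in scope are illustrative, not taken from the PR; only chat_template_to_turn comes from the diff above.

import logging

logger = logging.getLogger(__name__)


def fallback_assistant_text(model_path: str, chat_template_to_turn: str) -> str:
    # Granite and phi get special-purpose template handling elsewhere; every other
    # model family falls back to the raw chat-templated turn, and we say so in the logs.
    logger.warning(
        "No model-specific PRM chat-template handling for %s; using the chat template output as-is.",
        model_path,
    )
    return chat_template_to_turn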

Comment on lines 1 to 13
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class GenerativePRMForInference(torch.nn.Module):
    """
    Class for Generative Process Reward Models for Inference
    Uses Huggingface backend to load the model (which is trained using LoRA adapters)
    """

    def __init__(
        self,
        model_path="ibm-granite/granite-3.3-8b-lora-math-prm",
Contributor

One of our design goals is that everything in stdlib should be "targetable" to models/inference engines. I.e., we want a pretty clean separation between the abstractions in stdlib and the implementation details of how those abstractions get interpreted for a specific size of a specific version-set of a specific model.

This means that we should avoid hard-coding anything in stdlib to either huggingface or a special model_id.

For example, we separate the implementation of constraint checking in back-ends from the Requirement interface in stdlib. Requirement checking using Granite 3.2 or 3.3 in huggingface or vllm uses the constraint-checking LoRA. For any other model family, or any other version-set/size of those models, or any other inference engine, we have a prompt-based method for doing the same.

In the case of requirement checking, there's an at-least-half-sensible approach toward requirement actions on inference engines and models that don't support the constraint checking LoRA. But it's okay not to have a "default fallback" in the way that requirement checking does. You can just throw NotImplementedError if the inference engine or model ID doesn't provide the appropriate functionality.

I do not have a solution ready for you in-hand, but here's one possibility to consider:

  1. Create a ProcessRewardModel thing (probably a Component?) in stdlib that has an abstract interface which elides the exact implementation details at the level of models.
  2. You will need some specific functionality of a Backend in order to implement that Component's actual behavior. For this purpose, create a PRMBackendMixin class in mellea.backends.
  3. Add that Mixin to the Huggingface backend and implement the associated methods. Then move this model and hf-specific code into that implementation of the mixin.
  4. If it's possible, model-specific stuff should be additionally factored out (perhaps into Formatter?). Again, a mixin approach can be used if absolutely necessary.
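A minimal sketch of what step 1 might look like. A plain ABC stands in for mellea's Component so the sketch stays self-contained, and the method name and signature are illustrative:

from abc import ABC, abstractmethod


class ProcessRewardModel(ABC):
    """Abstract PRM interface in stdlib; says nothing about backends, model IDs, or HF."""

    @abstractmethod
    def score(self, prompt: str, response_steps: list[str]) -> list[float]:
        """Return one reward per reasoning step of a candidate response."""
        ...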

Contributor


Meta: we're open to being wrong about how this sort of thing should be implemented in application-layer libraries. Is this sort of thing implemented in langchain, bee, dspy, etc? If so, we can look at their architectural choices and consider those as alternative options to the above proposal.

Contributor Author


Yes, these goals make sense to me. However, in my experience PRM implementations are not very "standard", i.e., it really depends on the model training paradigm. But I do like the idea of creating a ProcessRewardModel Component and having classes that implement the forward pass to match the model specifics. Currently, I'm only implementing model classes that are compatible with the PRMs that IBM puts out, and we can leave other PRM implementations as a future contribution.

Another suggestion is to make a Scorer base class that is similar to a Requirement except that instead of returning a bool indicating Requirement satisfaction, it returns a Score. This class can then call any backend/object that has a scorer function. This idea may need a much larger PR (maybe needs more thinking from a design perspective), but it would make BestOfNSampling cleaner if there are multiple requirements and a scorer to choose the best. In the context of your suggestion, the ProcessRewardModel would be an instance of the Scorer class, which then calls the associated backend to get the responses.

I can implement the ProcessRewardModel as its own component for the purpose of this PR, and I'll leave it up to y'all to decide what to do about these types of "requirements" :)

Contributor

@nrfulton nrfulton Sep 4, 2025


Summary of sync convo:

We will refactor as follows:

  1. Add a Scorer :> Requirement component
  2. add the PRMBackendMixin and implement for Huggingface
  3. add the PRM itself that uses this stuff
  4. test :)

stdlib

  1. We should add a Scorer either as a mixin or as a subclass of Requirement.
  2. The Scorer's score method must return a non-None value.
  3. The Scorer must also define a preference ordering (is the goal to min the score or max the score? If the scores are categorical what's the preference ordering?)

Adding this to mellea.stdlib.requirements is justified because it allows implementers of SamplingStrategies to differentiate between "normal" requirements and "ranking" requirements.
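A rough sketch of that Scorer, with the existing Requirement class stubbed out so the snippet is self-contained; all names here are illustrative rather than mellea's actual API:

class Requirement:  # stub standing in for mellea.stdlib.requirements.Requirement
    def validate(self, *args, **kwargs) -> bool:
        raise NotImplementedError


class Scorer(Requirement):
    """A Requirement that, in addition to validating, produces a score with a preference ordering."""

    def __init__(self, maximize: bool = True):
        # Preference ordering: True means higher scores are better, False means lower are better.
        self.maximize = maximize

    def score(self, prompt: str, response: str) -> float:
        # Must return a non-None value; subclasses implement the actual scoring.
        raise NotImplementedError

    def best(self, prompt: str, responses: list[str]) -> str:
        # Convenience for Best-of-N: pick the candidate preferred under the ordering.
        scores = [self.score(prompt, r) for r in responses]
        target = max(scores) if self.maximize else min(scores)
        return responses[scores.index(target)]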

Backend

Introduce a new Mixin for PRM stuff:

class PRMBackendMixin:
    def prm_act(self, action: "PRMBlah"):  # "PRMBlah" is a placeholder for the eventual PRM action type
        ...

which then gets added to each backend that supports PRM:

class LocalHFBackend(FormatterBackend, AloraBackendMixin, PRMBackendMixin):

We now need a way for mellea.stdlib.requirements.PRMScorer to use the PRMBackendMixin. That can be done in the Scorer.validate. Now, when PRMScorer.validate(...,b: Backend,...) gets called we can check that b has the PRMBackendMixin and if not raise a NotImplementedError.
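Under this first design, that check might look roughly like this (building on the Scorer sketch above and the PRMBackendMixin drafted in this comment; the validate signature and action payload are assumptions):

class PRMScorer(Scorer):
    """Scorer whose scoring is delegated to a backend that mixes in PRMBackendMixin."""

    def validate(self, b, prompt: str, response: str):
        if not isinstance(b, PRMBackendMixin):
            raise NotImplementedError(
                f"{type(b).__name__} does not implement PRMBackendMixin, "
                "so it cannot score responses with a process reward model."
            )
        # Hand the request to the backend; the exact action/payload type is still TBD above.
        return b.prm_act((prompt, response))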

Contributor Author

Summary of second sync convo:

  1. We will add a Scorer :> Requirement as discussed above
  2. Separate the PRM model from the backend entirely: add it as a separate component
  3. Implement PRM Model classes for Huggingface/VLLM etc (in stdlib)
  4. The PRMScorer gets passed this PRM object
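Under this revised design the pieces would fit together roughly as follows. The HuggingFace loading details and the step-splitting are assumptions for illustration, not the PR's actual code:

class HuggingFacePRM(ProcessRewardModel):
    """PRM component that owns its own HF model and tokenizer, independent of any mellea backend."""

    def __init__(self, model_path: str):
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_path)

    def score(self, prompt: str, response_steps: list[str]) -> list[float]:
        raise NotImplementedError  # generative vs. regression scoring lives in concrete subclasses


class PRMScorer(Scorer):
    """Scorer that is handed a PRM component and uses it to score candidates."""

    def __init__(self, prm: ProcessRewardModel, maximize: bool = True):
        super().__init__(maximize=maximize)
        self.prm = prm

    def score(self, prompt: str, response: str) -> float:
        # Average the per-step rewards; splitting on blank lines is only an assumption here.
        step_scores = self.prm.score(prompt, response.split("\n\n"))
        return sum(step_scores) / len(step_scores)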

@nrfulton
Contributor

nrfulton commented Sep 4, 2025

Thanks for the contribution!

The pre-commit checks found some type errors; could you please fix these prior to a code review? For these errors, the appropriate fix is probably to just add some ignores (e.g., # type: ignore comments).

FYI: you can run these checks locally by installing the pre-commit hooks. Assuming you have already created a venv and installed Mellea editable (uv pip install -e .), you can then install the pre-commit hooks by running the following commands in the root of your mellea checkout:

uv pip install -e . --group dev && pre-commit install

Once installed, you can still commit over errors using the -n (no-verify) flag; e.g., git commit -a -m 'this may not pass pre-commit checks' -n. However, prior to opening the PR for review, ensure that the latest commit does pass the pre-commit checks.

@nrfulton nrfulton changed the title [feat] Best-of-N Sampling with Process Reward Models feat: Best-of-N Sampling with Process Reward Models Sep 4, 2025
@nrfulton
Contributor

nrfulton commented Sep 4, 2025

#104 seems related (you can replace "aLoRA" with "LoRA" in the description on that issue -- same idea).

@aashka-trivedi aashka-trivedi marked this pull request as draft September 4, 2025 14:03
@aashka-trivedi aashka-trivedi marked this pull request as ready for review September 12, 2025 17:43
@aashka-trivedi
Contributor Author

Hi @nrfulton, this branch is ready for review. As discussed, this PR implements the following:

  1. A ScorerRequirement class for Requirements that return a score and can be used for Best-of-N
  2. A BestofNSamplingStrategy that requires exactly one ScorerRequirement to be passed, and selects the sample with the max/min ScorerRequirement score among those that satisfy all Requirements
  3. A PRM abstract class definition, with support in HuggingFace backends and specific implementations of Generative and Regression PRMs
  4. An example script that shows how these components should be used
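For readers skimming the thread, a hedged end-to-end sketch of how these pieces are meant to compose. The class name GenerativePRM, the constructor arguments, and the import paths are assumptions (imports for the new classes are omitted because their final module paths are defined by this PR); the example script added in the PR is the authoritative reference.

import mellea

# Hypothetical wiring; class and argument names below are illustrative, not the PR's actual signatures.
m = mellea.start_session()  # assumes a session backed by a HuggingFace backend configured elsewhere

prm = GenerativePRM(model_path="ibm-granite/granite-3.3-8b-lora-math-prm")  # placeholder PRM class name
scorer = ScorerRequirement(
    description="Prefer the candidate the PRM scores highest.",
    scorer=prm,
    preference_ordering="max",
)
strategy = BestofNSamplingStrategy(loop_budget=8)  # sample several candidates, keep the best valid one

result = m.instruct(
    "Solve 3x + 5 = 20 for x, showing each step.",
    requirements=[scorer],
    strategy=strategy,
)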

@nrfulton nrfulton merged commit b18e03d into generative-computing:main Sep 18, 2025
6 of 10 checks passed