From d0089d123b93e017246277636d308437f21b0bdb Mon Sep 17 00:00:00 2001 From: aashka-trivedi Date: Wed, 3 Sep 2025 19:56:01 +0000 Subject: [PATCH 1/6] pass optional input to validate fn --- mellea/stdlib/session.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 3c398db8..dbc7162e 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -300,9 +300,11 @@ def act( else: # Default validation strategy just validates all of the provided requirements. if strategy.validate is None: - strategy.validate = lambda reqs, val_ctx, output: self.validate( - reqs, output=output - ) + strategy.validate = ( + lambda reqs, val_ctx, output, input=None: self.validate( # type: ignore + reqs, output=output, input=input + ) + ) # type: ignore # Default generation strategy just generates from context. if strategy.generate is None: @@ -483,6 +485,7 @@ def validate( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, generate_logs: list[GenerateLog] | None = None, + input: CBlock | None = None, ) -> list[ValidationResult]: """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" # Turn a solitary requirement in to a list of requirements, and then reqify if needed. @@ -492,7 +495,17 @@ def validate( validation_target_ctx = self.ctx else: validation_target_ctx = SimpleContext() - validation_target_ctx.insert(output) + if input is not None: + # some validators may need input as well as output + validation_target_ctx.insert_turn( + ContextTurn( + input, + output, # type: ignore + ), # type: ignore + generate_logs=generate_logs, + ) + else: + validation_target_ctx.insert(output) rvs = [] for requirement in reqs: val_result = requirement.validate( From 424de736eab53590428cfc9a5a1b2058d41ef07f Mon Sep 17 00:00:00 2001 From: aashka-trivedi Date: Wed, 3 Sep 2025 20:15:25 +0000 Subject: [PATCH 2/6] Best of N Sampling with PRM support --- docs/examples/best_of_n/prm.py | 29 +++ mellea/stdlib/rewards/__init__.py | 0 mellea/stdlib/rewards/prm.py | 293 ++++++++++++++++++++++++++++ mellea/stdlib/rewards/prm_scorer.py | 119 +++++++++++ mellea/stdlib/sampling.py | 172 ++++++++++++++++ 5 files changed, 613 insertions(+) create mode 100644 docs/examples/best_of_n/prm.py create mode 100644 mellea/stdlib/rewards/__init__.py create mode 100644 mellea/stdlib/rewards/prm.py create mode 100644 mellea/stdlib/rewards/prm_scorer.py diff --git a/docs/examples/best_of_n/prm.py b/docs/examples/best_of_n/prm.py new file mode 100644 index 00000000..8be2dd08 --- /dev/null +++ b/docs/examples/best_of_n/prm.py @@ -0,0 +1,29 @@ +"""Example of Using Best of N with PRMs""" + +from docs.examples.helper import w +from mellea import start_session +from mellea.backends.types import ModelOption +from mellea.stdlib.rewards.prm_scorer import PRMScorer +from mellea.stdlib.sampling import BestofNSamplingStrategy + +# create a session using Granite 3.3 8B on Huggingface and a simple context [see below] +m = start_session(backend_name="hf", model_options={ModelOption.MAX_NEW_TOKENS: 1024}) + +# create PRM scorer object +prm = PRMScorer( + model_version="ibm-granite/granite-3.3-8b-lora-math-prm", + prm_type="generative", + correct_token="Y", + generation_prompt="Is this response correct so far (Y/N)?", + step_splitter="\n\n", +) + +# Do Best of N sampling with the PRM scorer +BoN_prm = m.instruct( + "Sarah has 12 apples. She gives 5 of them to her friend. 
How many apples does Sarah have left?", + strategy=BestofNSamplingStrategy(loop_budget=3, requirements=[prm]), + model_options={"temperature": 0.9, "do_sample": True}, +) + +# print result +print(f"***** BoN ****\n{w(BoN_prm)}\n*******") diff --git a/mellea/stdlib/rewards/__init__.py b/mellea/stdlib/rewards/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mellea/stdlib/rewards/prm.py b/mellea/stdlib/rewards/prm.py new file mode 100644 index 00000000..fe1320e6 --- /dev/null +++ b/mellea/stdlib/rewards/prm.py @@ -0,0 +1,293 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +class GenerativePRMForInference(torch.nn.Module): + """ + Class for Generative Process Reward Models for Inference + Uses Huggingface backend to load the model (which is trained using LoRA adapters) + """ + + def __init__( + self, + model_path="ibm-granite/granite-3.3-8b-lora-math-prm", + correct_token="Y", + generation_prompt="Is this response correct so far (Y/N)?", + load_in_bf16=True, + device=None, + ) -> None: + super().__init__() + + if not load_in_bf16: + self.model = AutoModelForCausalLM.from_pretrained( + model_path, device_map="auto" + ) + else: + self.model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=torch.bfloat16, device_map="auto" + ) + + if device is not None: + self.model.to(device) + self.device = self.model.device + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.tokenizer.truncation_side = "left" # prevents truncation from right (default): needed since we always want to have the last step and last generation prompt from the context. + self.correct_token = correct_token + self.correct_token_id = self.tokenizer.encode( + self.correct_token, add_special_tokens=False + )[0] + self.generation_prompt = generation_prompt + self.softmax = torch.nn.Softmax(dim=-1) + + def forward(self, raw_inputs): + """ + Expects a raw_batch of (questions: List[str], steps: List[List[str]]) + Return the aggregated score for each problem (i.e., the average of the per-step scores), along with the per-step scores + """ + + # get un-tokenized batch + batches = self.prepare_batch(raw_inputs) + # each element of the batch consists of a list of num_steps messages corresponding to a single input, which we need to handle + all_rewards = [] + all_rewards_per_step = [] + + chat_template_to_turn = self.tokenizer.apply_chat_template( + [{"role": "assistant", "content": self.correct_token}], + tokenize=False, + add_generation_prompt=False, + ) + if "system" in chat_template_to_turn: + if "granite" in self.model.config.model_type.lower(): + # for granite, apply_chat_template also adds system prompt + asst_text = ( + "<|start_of_role|>assistant<|end_of_role|>" + + self.correct_token + + "<|end_of_text|>" + ) + elif "phi" in self.model.config.model_type.lower(): + # phi reasoning also applies the system prompt + asst_text = ( + "<|im_start|>assistant<|im_sep|>" + + self.correct_token + + "<|im_end|>'" + ) + else: + asst_text = chat_template_to_turn + asst_toks = self.tokenizer( + asst_text, add_special_tokens=False, return_tensors="pt" + )["input_ids"][0] + asst_toks_before_correct_token = asst_toks[ + : torch.where(asst_toks == self.correct_token_id)[0].item() + ].tolist() + + # each element in batch contains a question and the response + for i in batches: + batches[i] = batches[i].to(self.model.device) + + with torch.no_grad(): + model_outputs = self.model(**batches) + logits = model_outputs.logits # (bsz, seq_len, vocab_size) + + for batch_idx in 
range(logits.shape[0]): + per_input_rewards = [] + # for each element in the batch (i.e., each input) + # we need to get logits for all tokens where the token in "Y" (in assistant turn) + # find batch index for assistant turn "Y", not just the correct_token_id + correct_token_indices = torch.where( + batches["input_ids"][batch_idx] == self.correct_token_id + )[0].tolist() + prm_indices = [] + for t_idx in correct_token_indices: + if ( + batches["input_ids"][batch_idx][ + t_idx - len(asst_toks_before_correct_token) : t_idx + ].tolist() + == asst_toks_before_correct_token + ): + prm_indices.append( + t_idx - 1 + ) # the logits for token i predict the token i+1: so, we need to look at the PREVIOUS token logits + + assert len(prm_indices) > 0 + # convert logits to probabilities and get the probability of the correct token id as reward + for prm_idx in prm_indices: + per_input_rewards.append( + self.softmax(logits[batch_idx, prm_idx, :])[ + self.correct_token_id + ].item() + ) + + # aggregate. return final rewards + all_rewards_per_step.append(per_input_rewards) + sum = 0 + for reward in per_input_rewards: + sum += reward + per_input_reward = sum / len(per_input_rewards) + all_rewards.append(per_input_reward) + + return all_rewards, all_rewards_per_step + + def prepare_batch(self, raw_batch): + """ + Expects a raw_batch of (question, list_of_steps). The list of steps is joined with the step_eos token + prepare_batch() function splits each step into an individual response, and prepares an input batch + prepare batch for forward pass + """ + + questions, list_of_steps = raw_batch + assert len(questions) == len(list_of_steps) + + inputs = [] + for i in range(len(questions)): + user_content = questions[i] + steps = list_of_steps[i] + msgs = [] + for s_idx, step in enumerate(steps): + # apply chat template as expected by RM input + if s_idx == 0: + msgs.append( + { + "role": "user", + "content": user_content + + " " + + step + + " " + + self.generation_prompt, + } + ) + else: + # first add last assistant turn + msgs.append({"role": "assistant", "content": self.correct_token}) + msgs.append( + {"role": "user", "content": step + " " + self.generation_prompt} + ) + + # append the last asst turn + msgs.append({"role": "assistant", "content": self.correct_token}) + + input_message = self.tokenizer.apply_chat_template( + msgs, add_generation_prompt=False, tokenize=False + ) + + inputs.append(input_message) + + return self.tokenizer( + inputs, return_tensors="pt", padding=True, truncation=True + ) + + +class RegressionPRMForInference(torch.nn.Module): + """ + Class for Regression (non-generative) Process Reward Models for Inference + Uses Huggingface backend to load the model + All regression process reward models trained by the GMA team at IBM research use a special step token, + """ + + def __init__( + self, + model_path: str, + step_eos: str = "", + load_in_bf16: bool = True, + device=None, + ) -> None: + super().__init__() + + # Load the model + self.model: AutoModelForCausalLM + if not load_in_bf16: + self.model = AutoModelForCausalLM.from_pretrained( # type: ignore + model_path, device_map="auto" + ) + else: + self.model = AutoModelForCausalLM.from_pretrained( # type: ignore + model_path, torch_dtype=torch.bfloat16, device_map="auto" + ) + self.device = self.model.device + self.config = self.model.config + + # get the token IDs for the step separator token + self.step_eos = step_eos + self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_path) + # 
self.tokenizer.add_tokens(self.step_eos) + self.step_eos_id = self.tokenizer.encode( + self.step_eos, add_special_tokens=False + )[0] + + # load the PRM head + self.prm_head = torch.nn.Linear( + self.model.config.hidden_size, 2, bias=False, dtype=self.model.dtype + ).to(self.model.device) + state = torch.load(model_path + "/added_params.bin") + self.load_state_dict(state, strict=False) + self.model.eval() + + self.softmax = torch.nn.Softmax(dim=-1) + + def forward(self, raw_batch): + """ + Expects a raw_batch of (questions: List[str], steps: List[List[str]]) + Return the aggregated score for each problem (i.e., the average of the per-step scores), along with the per-step scores + """ + + # tokenizes the batch and concatenates the list of steps into a single step-separated response + batch = self.prepare_batch(raw_batch).to(self.device) + + with torch.no_grad(): + model_outputs = self.model(**batch, output_hidden_states=True) + # all logits + all_prm_logits = self.prm_head(model_outputs["hidden_states"][-1]).squeeze( + -1 + ) + + # get logits for each end of step i.e. logits for step_eos positions in the input + prm_probs = [] + rewards = [] + for idx in range(all_prm_logits.shape[0]): + prm_indices = torch.where(batch["input_ids"][idx] == self.step_eos_id)[0] + if prm_indices.shape[0] == 0: + # no match found-- model did not produce outputs in correct step-wise format + prm_probs.append([None]) + reward = None + else: + # head produces two logits, the second one is the logit for the correct answer + # convert logits to probabilities using softmax + # return list of floats instead of list of tensors + prm_probs_per_sample = [ + t.item() + for t in self.softmax(all_prm_logits[idx][prm_indices])[:, 1] + ] + prm_probs.append(prm_probs_per_sample) + + reward = sum(prm_probs_per_sample) / len(prm_probs_per_sample) + rewards.append(reward) + + return rewards, prm_probs + + def prepare_batch(self, raw_batch): + """ + Tokenize and prepare batch for forward pass + Expects a raw_batch of (question, list_of_steps). 
The list of steps is joined with the step_eos token + """ + + questions, list_of_steps = raw_batch + assert len(questions) == len(list_of_steps) + + inputs = [] + for i in range(len(questions)): + text_with_steps_marked = "" + + for step in list_of_steps[i]: + text_with_steps_marked += f"{step} {self.step_eos}" + + message = [ + {"role": "user", "content": questions[i]}, + {"role": "assistant", "content": text_with_steps_marked}, + ] + input = self.tokenizer.apply_chat_template(message, tokenize=False) + inputs.append(input) + + # tokenize data for the RM + batch = self.tokenizer( + inputs, return_tensors="pt", padding=True, truncation=True + ) + return batch diff --git a/mellea/stdlib/rewards/prm_scorer.py b/mellea/stdlib/rewards/prm_scorer.py new file mode 100644 index 00000000..998dec45 --- /dev/null +++ b/mellea/stdlib/rewards/prm_scorer.py @@ -0,0 +1,119 @@ +import torch + +from mellea.stdlib.base import CBlock, Context +from mellea.stdlib.chat import Message +from mellea.stdlib.requirement import Requirement, ValidationResult +from mellea.stdlib.rewards.prm import ( + GenerativePRMForInference, + RegressionPRMForInference, +) + + +class PRMScorer(Requirement): + """A process reward model scorer based on local huggingface backend.""" + + def __init__( + self, + *, + model_version: str = "ibm-granite/granite-3.3-8b-lora-math-prm", + device: str | None = None, + step_splitter="\n\n", + prm_type: str = "generative", + **prm_kwargs, + ): + """ + + Args: + model_version: The version of the model, defaults to "ibm-granite/granite-3.3-8b-lora-math-prm". + device: The computational device to use ("cuda" for GPU, "mps" for Apple Silicon, or "cpu"), defaults to None. If not specified, the best available device will be automatically selected. + correct_token: PRM generated token that indicates step is correct + generation_prompt: Generation prompt required for the PRM scorer + step_splitter: string on which assistant response is split into steps + prm_type: type of prm tobe used. must be either `generative` or `regression` + prm_kwargs: args for PRM. For Generative, pass `correct_token`, `generation_prompt`. For Regression, pass `step_token` + """ + super().__init__(check_only=True, validation_fn=lambda c: self._prm_validate(c)) + + self._model_version = model_version + + # auto-device if not more specific + self._device = device + if device is None: + device_name: str = ( + "cuda" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" + ) + assert device_name is not None + self._device = torch.device(device_name) # type: ignore + + self.step_splitter = step_splitter + assert prm_type.lower() in ["generative", "regression"], ( + "prm_type must be either generative or regression" + ) + self.prm_type = prm_type.lower() + self.prm_kwargs = prm_kwargs + + def _prm_validate(self, ctx: Context): + """ + Returns PRM score of last turn of context + """ + last_turn = ctx.last_turn() + assert last_turn is not None + + # This requirement can handle only complete turns with both + # a user message and an assistant message + + assert last_turn.model_input is not None and last_turn.output is not None + assert last_turn.output.value is not None + + user_msg = last_turn.model_input + + # Handle the variety of possible user input. 
+ if isinstance(user_msg, CBlock) and user_msg.value is not None: + user_query = user_msg.value + elif isinstance(user_msg, Message) and user_msg.content != "": + user_query = user_msg.content + else: + user_query = str(user_msg) + + assistant_content = last_turn.output.value + + # convert assistant message into a list of steps + list_of_steps = [ + step.strip() + for step in assistant_content.split(self.step_splitter) + if step.strip != "" + ] + + # Load model + model: GenerativePRMForInference | RegressionPRMForInference + if self.prm_type == "generative": + model = GenerativePRMForInference( + model_path=self._model_version, + load_in_bf16=True, + device=self._device, + **self.prm_kwargs, + ) + model.to(self._device) + elif self.prm_type == "regression": + model = RegressionPRMForInference( + model_path=self._model_version, + load_in_bf16=True, + device=self._device, + **self.prm_kwargs, + ) # type: ignore[no-redef] + else: + raise NotImplementedError + + rewards, rewards_per_step = model(([user_query], [list_of_steps])) + + # return single reward item for the response + assert len(rewards) == 1 + + # offload and delete model before returning rewards + del model + + return ValidationResult(result=True, reason=None, score=rewards[0]) diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index ee6ab431..cb0da644 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -381,3 +381,175 @@ def repair( ) return next_action + + +class BestofNSamplingStrategy(BaseSamplingStrategy): + """ + Sampling strategy that selects the best response from a set of samples as given by a Requirement Scorer + """ + + def __init__( + self, + *, + loop_budget: int = 1, + validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] + | None = None, + generate: ( + Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] + | None + ) = None, + requirements: list[Requirement], + ): + """Initialize a new instance of the class with default parameters. + + Args: + loop_budget: Number of times to iterate through the process. Must be greater than 0. + validate: Function to validate the results against requirements. If None, validation is provided later through setter. + generate: Function to generate new model output thunks. If None, generate is provided later through setter. + requirements: List of requirements to test against. If None, test all requirements attached to the given instruction. + + Raises: + AssertionError: If loop_budget is not greater than 0. + AssertionError: If there is more/less than one requirements + """ + super().__init__( + loop_budget=loop_budget, + validate=validate, + generate=generate, + requirements=requirements, + ) + + self.requirements = requirements + assert len(self.requirements) == 1 + + def sample( + self, + action: Component, + context: Context, + requirements: list[Requirement], + *, + show_progress: bool = True, + generate_logs: list[GenerateLog] | None = None, + validation_ctx: Context | None = None, + ) -> SamplingResult: + """This method performs a sampling operation based on the given instruction. + + Args: + action : The action object to be sampled. + context: The context to be passed to the sampling strategy. + show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. + generate_logs: If provided, the generations will be logged. + requirements: List of requirements to test against (merged with global requirements). 
+ validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. + + Returns: + SamplingResult: A result object indicating the success or failure of the sampling process. + + Raises: + AssertionError: Asserts that all required components (repair, select_from_failure, validate, and generate) are provided before proceeding with the sampling. + """ + assert self.validate is not None, "Validation must be provided." + assert self.generate is not None, "Generate must be provided." + + # just to be sure to not cause issues to the OG context + ctx = context.copy() + validation_ctx = validation_ctx if validation_ctx is not None else context + assert validation_ctx is not None, "Validation context must be provided." + + flog = FancyLogger.get_logger() + + sampled_results: list[ModelOutputThunk] = [] + sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] + sampled_actions: list[Component] = [] + sampled_val_scores: list[float] = [] + + # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress + # flag to determine whether we should show the pbar. + show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO + + reqs = [] + if self.requirements is not None: + reqs += self.requirements + elif requirements is not None: + reqs += requirements + + reqs = list(set(reqs)) + assert len(reqs) == 1, "Bets of n only supports one requirement" + + loop_count = 0 + loop_budget_range_iterator = ( + tqdm.tqdm(range(self.loop_budget)) # type: ignore + if show_progress + else range(self.loop_budget) # type: ignore + ) + + new_action = deepcopy(action) + for _ in loop_budget_range_iterator: # type: ignore + loop_count += 1 + if not show_progress: + flog.info(f"Running loop {loop_count} of {self.loop_budget}") + + # run a generation pass + result = self.generate(new_action, ctx, generate_logs) + + # validation pass + # action has user turn + val_scores = self.validate( + reqs, + validation_ctx, + result, + input=action._description, # type: ignore + ) + + # match up reqs with scores + constraint_scores = list(zip(reqs, val_scores)) + + # collect all data + sampled_results.append(result) + sampled_scores.append(constraint_scores) + sampled_actions.append(new_action) + # only a single requirement is used for BestofNSampling + sampled_val_scores.append( + val_scores[0]._score # type: ignore + ) + + best_result, best_score = max( + zip(sampled_results, sampled_val_scores), key=lambda x: x[1] + ) + + return SamplingResult( + best_result, + success=True, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_actions=sampled_actions, + ) + + @staticmethod + def select_from_failure( + sampled_actions: list[Component], + sampled_results: list[ModelOutputThunk], + sampled_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> int: + # simply returns the first attempt if all loops fail + return 0 + + @staticmethod + def repair( + ctx: Context, + past_actions: list[Component], + past_results: list[ModelOutputThunk], + past_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> Component: + pa = past_actions[-1] + if isinstance(pa, Instruction): + last_failed_reqs: list[Requirement] = [ + s[0] for s in past_val[-1] if not s[1] + ] + last_failed_reqs_str = "* " + "\n* ".join( + [str(r.description) for r in last_failed_reqs] + ) + return pa.copy_and_repair( + repair_string=f"The following requirements failed before:\n{last_failed_reqs_str}" + ) + return past_actions[-1] From 
f5b050507fe7009761db691464f00041aa112d9b Mon Sep 17 00:00:00 2001 From: aashka-trivedi Date: Thu, 4 Sep 2025 16:36:13 +0000 Subject: [PATCH 3/6] implement ScorerRequirement class --- mellea/stdlib/requirement.py | 78 +++++++++++++++++++++++++++++ mellea/stdlib/rewards/prm_scorer.py | 12 +++-- 2 files changed, 87 insertions(+), 3 deletions(-) diff --git a/mellea/stdlib/requirement.py b/mellea/stdlib/requirement.py index 1aa849fa..495092cb 100644 --- a/mellea/stdlib/requirement.py +++ b/mellea/stdlib/requirement.py @@ -178,6 +178,84 @@ def __init__(self, description: str, alora: Alora | None = None): self.alora = alora +class ScorerRequirement(Requirement): + """A requirement that always returns a non-None score. The scorer must also define a preference ordering to indicate whether the goal is to maximize or minimize the score.""" + + def __init__( + self, + description: str | None = None, + validation_fn: Callable[[Context], ValidationResult] | None = None, + preference_ordering: str = "max", + *, + output_to_bool: Callable[[CBlock | str], bool] | None = default_output_to_bool, + check_only: bool = False, + ): + """A requirement that is validated by an ALora. + + Args: + description: See `Requirement.__init__` + validation_fn: If provided, this function will be executed instead of using LLM-as-a-Judge. This function must return a valid score + preference_ordering: indicates whether the goal is to maximize or minimize the score. must be either "max" or "min". Defaults to None + output_to_bool: See `Requirement.__init__` + check_only: See `Requirement.__init__` + """ + super().__init__( + description, + validation_fn=validation_fn, + output_to_bool=output_to_bool, + check_only=check_only, + ) + + if preference_ordering.lower() not in ["max", "min"]: + raise NotImplementedError + self.preference_ordering: str = preference_ordering.lower() + + def validate( + self, + backend: Backend, + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + generate_logs: list[GenerateLog] | None = None, + ) -> ValidationResult: + """Chooses the appropriate validation strategy and applies that strategy. Asserts that the returned ValidationResult has a valid score.""" + if self.validation_fn is not None: + # Python validation strategy + validation_result = self.validation_fn(ctx) + assert validation_result._score is not None, ( + "ScorerRequirement must have a score that is not None" + ) + return validation_result + else: + # LLMaJ validation strategy. This includes ALora because the backend generate call will appropriately dispatch. + # For ScorerRequirement, provide score of 1 for result=True, 0 for result=False + assert self.output_to_bool is not None + last_output = ctx.last_output() + assert isinstance(last_output, ModelOutputThunk), ( + " Context has no appropriate last output" + ) + + # Create a copy of the requirement that holds the output + # and its template gets populated with the output correctly. + req_copy = copy(self) + req_copy._output = last_output.value + llm_as_a_judge_result = backend.generate_from_context( + req_copy, + ctx, + format=format, + model_options=model_options, + generate_logs=generate_logs, + ) + result = self.output_to_bool(llm_as_a_judge_result) + + return ValidationResult( + result=result, + reason=llm_as_a_judge_result.value, + score=1 if result else 0, + ) + + def reqify(r: str | Requirement) -> Requirement: """Maps strings to Requirements. 
diff --git a/mellea/stdlib/rewards/prm_scorer.py b/mellea/stdlib/rewards/prm_scorer.py index 998dec45..9d56e62d 100644 --- a/mellea/stdlib/rewards/prm_scorer.py +++ b/mellea/stdlib/rewards/prm_scorer.py @@ -2,20 +2,21 @@ from mellea.stdlib.base import CBlock, Context from mellea.stdlib.chat import Message -from mellea.stdlib.requirement import Requirement, ValidationResult +from mellea.stdlib.requirement import ScorerRequirement, ValidationResult from mellea.stdlib.rewards.prm import ( GenerativePRMForInference, RegressionPRMForInference, ) -class PRMScorer(Requirement): +class PRMScorer(ScorerRequirement): """A process reward model scorer based on local huggingface backend.""" def __init__( self, *, model_version: str = "ibm-granite/granite-3.3-8b-lora-math-prm", + preference_ordering: str = "max", device: str | None = None, step_splitter="\n\n", prm_type: str = "generative", @@ -25,6 +26,7 @@ def __init__( Args: model_version: The version of the model, defaults to "ibm-granite/granite-3.3-8b-lora-math-prm". + preference_ordering: indicates whether the goal is to maximize or minimize the score. must be either "max" or "min" device: The computational device to use ("cuda" for GPU, "mps" for Apple Silicon, or "cpu"), defaults to None. If not specified, the best available device will be automatically selected. correct_token: PRM generated token that indicates step is correct generation_prompt: Generation prompt required for the PRM scorer @@ -32,7 +34,11 @@ def __init__( prm_type: type of prm tobe used. must be either `generative` or `regression` prm_kwargs: args for PRM. For Generative, pass `correct_token`, `generation_prompt`. For Regression, pass `step_token` """ - super().__init__(check_only=True, validation_fn=lambda c: self._prm_validate(c)) + super().__init__( + check_only=True, + validation_fn=lambda c: self._prm_validate(c), + preference_ordering=preference_ordering, + ) self._model_version = model_version From 7e4be67b7954535351f7ffc8cdc0d960266d5f27 Mon Sep 17 00:00:00 2001 From: aashka-trivedi Date: Thu, 4 Sep 2025 19:09:41 +0000 Subject: [PATCH 4/6] BestofNSampling with support for multiple requirements --- docs/examples/best_of_n/prm.py | 5 +- mellea/stdlib/sampling.py | 143 +++++++++++++++++++++------------ 2 files changed, 94 insertions(+), 54 deletions(-) diff --git a/docs/examples/best_of_n/prm.py b/docs/examples/best_of_n/prm.py index 8be2dd08..95321b60 100644 --- a/docs/examples/best_of_n/prm.py +++ b/docs/examples/best_of_n/prm.py @@ -18,11 +18,12 @@ step_splitter="\n\n", ) -# Do Best of N sampling with the PRM scorer +# Do Best of N sampling with the PRM scorer and an additional requirement BoN_prm = m.instruct( "Sarah has 12 apples. She gives 5 of them to her friend. 
How many apples does Sarah have left?", - strategy=BestofNSamplingStrategy(loop_budget=3, requirements=[prm]), + strategy=BestofNSamplingStrategy(loop_budget=3), model_options={"temperature": 0.9, "do_sample": True}, + requirements=["provide final answer like 'Final Answer:'", prm], ) # print result diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index cb0da644..a771bc19 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -19,7 +19,7 @@ ) from mellea.stdlib.chat import Message from mellea.stdlib.instruction import Instruction -from mellea.stdlib.requirement import Requirement, ValidationResult +from mellea.stdlib.requirement import Requirement, ScorerRequirement, ValidationResult class SamplingResult(CBlock): @@ -388,40 +388,6 @@ class BestofNSamplingStrategy(BaseSamplingStrategy): Sampling strategy that selects the best response from a set of samples as given by a Requirement Scorer """ - def __init__( - self, - *, - loop_budget: int = 1, - validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] - | None = None, - generate: ( - Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] - | None - ) = None, - requirements: list[Requirement], - ): - """Initialize a new instance of the class with default parameters. - - Args: - loop_budget: Number of times to iterate through the process. Must be greater than 0. - validate: Function to validate the results against requirements. If None, validation is provided later through setter. - generate: Function to generate new model output thunks. If None, generate is provided later through setter. - requirements: List of requirements to test against. If None, test all requirements attached to the given instruction. - - Raises: - AssertionError: If loop_budget is not greater than 0. - AssertionError: If there is more/less than one requirements - """ - super().__init__( - loop_budget=loop_budget, - validate=validate, - generate=generate, - requirements=requirements, - ) - - self.requirements = requirements - assert len(self.requirements) == 1 - def sample( self, action: Component, @@ -461,7 +427,12 @@ def sample( sampled_results: list[ModelOutputThunk] = [] sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] sampled_actions: list[Component] = [] - sampled_val_scores: list[float] = [] + + successful_sampled_results: list[ModelOutputThunk] = [] + successful_sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] + successful_sampled_actions: list[Component] = [] + + # sampled_val_scores: list[float] = [] # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress # flag to determine whether we should show the pbar. 
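As a rough illustration (an editorial sketch, not part of the patch), the hunks that follow rework `BestofNSamplingStrategy.sample` so that any number of ordinary requirements can be attached while exactly one `ScorerRequirement` ranks the candidates. Condensed into a single function, the selection rule they implement looks approximately like this; `ScorerRequirement`, `ValidationResult` truthiness, and the private `_score` field are the ones introduced in the diffs above:

    from mellea.stdlib.requirement import ScorerRequirement

    def pick_best(samples):
        """samples: list of (result, [(requirement, validation_result), ...]) pairs."""
        # Keep only candidates for which every requirement passed.
        passing = [
            (result, scores)
            for result, scores in samples
            if all(bool(v) for _, v in scores)
        ]
        if not passing:
            return None  # the strategy then falls back to select_from_failure()

        def scorer_score(scores):
            # Exactly one ScorerRequirement is present; its score ranks the sample.
            return next(v._score for r, v in scores if isinstance(r, ScorerRequirement))

        return max(passing, key=lambda pair: scorer_score(pair[1]))
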
@@ -474,7 +445,17 @@ def sample( reqs += requirements reqs = list(set(reqs)) - assert len(reqs) == 1, "Bets of n only supports one requirement" + + # check that there is exactly one ScorerRequirement + scorer_requirements = 0 + for req in reqs: + # strict typecheck for scorer requirement + if isinstance(req, ScorerRequirement): + scorer_requirements += 1 + + assert scorer_requirements == 1, ( + "BestOfNSamplingStrategy requires exactly one ScorerRequirement" + ) loop_count = 0 loop_budget_range_iterator = ( @@ -508,22 +489,69 @@ def sample( sampled_results.append(result) sampled_scores.append(constraint_scores) sampled_actions.append(new_action) - # only a single requirement is used for BestofNSampling - sampled_val_scores.append( - val_scores[0]._score # type: ignore + + # check if requirements pass else repair and re-sample + # if all vals are true, save it and continue to get next sample + if all(bool(s[1]) for s in constraint_scores): + flog.info("SUCCESS") + successful_sampled_results.append(result) + successful_sampled_scores.append(constraint_scores) + successful_sampled_actions.append(new_action) + + else: + # log partial success and continue + count_valid = len([s for s in constraint_scores if bool(s[1])]) + flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}") + + # If we did not pass all constraints, update the instruction and try again. + new_action = self.repair( + ctx, sampled_actions, sampled_results, sampled_scores + ) + + # find max reward amongst results for which all requirements have passed + if len(successful_sampled_scores) > 0: + scores: list[float] = [] + + for sample in successful_sampled_scores: + for req, val_score in sample: + if isinstance(req, ScorerRequirement): + assert val_score._score is not None + scores.append(val_score._score) + + assert len(successful_sampled_results) == len(scores) + + best_result, best_score = max( + zip(successful_sampled_results, scores), key=lambda x: x[1] ) - best_result, best_score = max( - zip(sampled_results, sampled_val_scores), key=lambda x: x[1] - ) + return SamplingResult( + best_result, + success=True, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_actions=sampled_actions, + ) - return SamplingResult( - best_result, - success=True, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_actions=sampled_actions, - ) + # if all failures, call select from failure + else: + flog.info( + f"Invoking select_from_failure after {len(sampled_results)} failed attempts." + ) + + # if no valid result could be determined, find a last resort. + best_failed_index = self.select_from_failure( + sampled_actions, sampled_results, sampled_scores + ) + assert best_failed_index < len(sampled_results), ( + "The select_from_failure method did not return a valid result. It has to selected from failed_results." 
+ ) + return SamplingResult( + sampled_results[best_failed_index], + success=False, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_actions=sampled_actions, + ) @staticmethod def select_from_failure( @@ -531,8 +559,19 @@ def select_from_failure( sampled_results: list[ModelOutputThunk], sampled_val: list[list[tuple[Requirement, ValidationResult]]], ) -> int: - # simply returns the first attempt if all loops fail - return 0 + # select attempt with highest ScoreRequirementScore if all loops fail + + scores: list[float | None] = [] + + for sample in sampled_val: + for req, val_score in sample: + if isinstance(req, ScorerRequirement): + assert val_score._score is not None + scores.append(val_score._score) + + assert len(sampled_results) == len(scores) + + return scores.index(max(scores)) # type: ignore @staticmethod def repair( From 4945002d0f196c11253372e0d9d322c35941c15f Mon Sep 17 00:00:00 2001 From: aashka-trivedi Date: Fri, 12 Sep 2025 17:37:20 +0000 Subject: [PATCH 5/6] implement PRM classes in backend --- docs/examples/best_of_n/prm.py | 28 +- mellea/backends/huggingface.py | 54 ++++ .../process_reward_models/__init__.py | 24 ++ .../huggingface/__init__.py | 1 + .../process_reward_models/huggingface/prms.py | 254 +++++++++++++++ mellea/stdlib/rewards/prm.py | 293 ------------------ mellea/stdlib/rewards/prm_scorer.py | 78 +---- 7 files changed, 358 insertions(+), 374 deletions(-) create mode 100644 mellea/backends/process_reward_models/__init__.py create mode 100644 mellea/backends/process_reward_models/huggingface/__init__.py create mode 100644 mellea/backends/process_reward_models/huggingface/prms.py delete mode 100644 mellea/stdlib/rewards/prm.py diff --git a/docs/examples/best_of_n/prm.py b/docs/examples/best_of_n/prm.py index 95321b60..a945c625 100644 --- a/docs/examples/best_of_n/prm.py +++ b/docs/examples/best_of_n/prm.py @@ -2,22 +2,34 @@ from docs.examples.helper import w from mellea import start_session +from mellea.backends.process_reward_models.huggingface.prms import ( + HFGenerativePRM, + HFRegressionPRM, +) from mellea.backends.types import ModelOption from mellea.stdlib.rewards.prm_scorer import PRMScorer from mellea.stdlib.sampling import BestofNSamplingStrategy -# create a session using Granite 3.3 8B on Huggingface and a simple context [see below] -m = start_session(backend_name="hf", model_options={ModelOption.MAX_NEW_TOKENS: 1024}) +# create a session for the generator using Granite 3.3 8B on Huggingface and a simple context [see below] +m = start_session(backend_name="hf", model_options={ModelOption.MAX_NEW_TOKENS: 512}) -# create PRM scorer object -prm = PRMScorer( - model_version="ibm-granite/granite-3.3-8b-lora-math-prm", - prm_type="generative", - correct_token="Y", +# initialize the PRM model +prm_model = HFGenerativePRM( + model_name_or_path="ibm-granite/granite-3.3-8b-lora-math-prm", + score_token="Y", generation_prompt="Is this response correct so far (Y/N)?", - step_splitter="\n\n", + step_separator="\n\n", ) +# # can also initialize a Regression PRM model +# prm_model = HFRegressionPRM( +# model_name_or_path = "granite-3.3-8b-math-prm-regression", +# score_token= "", +# step_separator= "\n\n") + +# create PRM scorer object +prm = PRMScorer(prm_model=prm_model, preference_ordering="max") + # Do Best of N sampling with the PRM scorer and an additional requirement BoN_prm = m.instruct( "Sarah has 12 apples. She gives 5 of them to her friend. 
How many apples does Sarah have left?", diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 360437bf..ee51994e 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -30,6 +30,7 @@ from mellea.backends.cache import Cache, SimpleLRUCache from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter from mellea.backends.model_ids import ModelIdentifier +from mellea.backends.process_reward_models import PRM from mellea.backends.tools import ( add_tools_from_context_actions, add_tools_from_model_options, @@ -670,3 +671,56 @@ def __init__( self._generation_prompt_tokens = self._backend._tokenizer( self._generation_prompt, return_tensors="pt" ).to(self._backend._device) + + +class HFProcessRewardModel(PRM, abc.ABC): + def __init__( + self, model_name_or_path: str, score_token: str, device: str | None = None + ): + """Initialize an PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models + + Args: + model_name_or_path (str): A local path to PRM or a huggingface PRM + score_token (str): token who's logits correspond to the PRM score. Can be a step demarker (for non-generative PRMs) or a correctness indicator (for generative PRMs) + device (str): device: The computational device to use ("cuda" for GPU, "mps" for Apple Silicon, or "cpu"), defaults to None. If not specified, the best available device will be automatically selected. + """ + super().__init__(model_name_or_path) + + # auto-device if not more specific + self._device = device + if device is None: + device_name: str = ( + "cuda" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" + ) + assert device_name is not None + self._device = torch.device(device_name) # type: ignore + + self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + self.model_name_or_path, torch_dtype=torch.bfloat16 + ) + self.model.to(self._device) # type: ignore + self.model.eval() + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path) + + self._score_token = score_token + self._score_token_id = self.tokenizer.encode( + self._score_token, add_special_tokens=False + )[0] + + def stepify(self, content: str, step_separator: str) -> list[str]: + """Splits the assistant response into steps to score + + Args: + content: assistant response to score + step_separator: string on which to separate the content into steps + """ + + # convert assistant message into a list of steps + list_of_steps = [ + step.strip() for step in content.split(step_separator) if step.strip != "" + ] + return list_of_steps diff --git a/mellea/backends/process_reward_models/__init__.py b/mellea/backends/process_reward_models/__init__.py new file mode 100644 index 00000000..ae911f98 --- /dev/null +++ b/mellea/backends/process_reward_models/__init__.py @@ -0,0 +1,24 @@ +"""Abstract interfaces for Backends that implement Process Reward Models (can be adapted to include other scorers)""" + +import abc + + +class PRM(abc.ABC): + def __init__(self, model_name_or_path): + # Leave implementation of model to inheriting class + self.model_name_or_path = model_name_or_path + + @abc.abstractmethod + def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]: + """Returns a final score and per-step score to the input of the model""" + ... 
+ + @abc.abstractmethod + def stepify(self, response: str, step_separator: str) -> list[str]: + """Splits the assistant response into steps to score + + Args: + response: assistant response to score + step_separator: string on which to separate the response into steps + """ + ... diff --git a/mellea/backends/process_reward_models/huggingface/__init__.py b/mellea/backends/process_reward_models/huggingface/__init__.py new file mode 100644 index 00000000..3b259046 --- /dev/null +++ b/mellea/backends/process_reward_models/huggingface/__init__.py @@ -0,0 +1 @@ +"""Process Reward Model Implementations with Huggingface backends""" diff --git a/mellea/backends/process_reward_models/huggingface/prms.py b/mellea/backends/process_reward_models/huggingface/prms.py new file mode 100644 index 00000000..2bac7afb --- /dev/null +++ b/mellea/backends/process_reward_models/huggingface/prms.py @@ -0,0 +1,254 @@ +import torch +from transformers.tokenization_utils_base import BatchEncoding + +from mellea.backends.huggingface import HFProcessRewardModel + + +class HFGenerativePRM(HFProcessRewardModel): + def __init__( + self, + model_name_or_path: str = "ibm-granite/granite-3.3-8b-lora-math-prm", + score_token: str = "Y", + device: str | None = None, + generation_prompt: str = "Is this response correct so far (Y/N)?", + step_separator: str = "\n\n", + ): + """Initialize a Generative PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models + + Args: + model_name_or_path (str): A local path to PRM or a huggingface PRM + score_token (str): token who's logits correspond to the PRM score. Usually is a correctness indicator (for generative PRMs) + device (str): pointer to device + generation_prompt (str): Optional prompt to be added before generation + step_separator (str): string on which to separate the content into steps + """ + super().__init__(model_name_or_path, score_token, device) + self.generation_prompt = ( + generation_prompt if generation_prompt is not None else "" + ) + self.step_separator = step_separator + self.softmax = torch.nn.Softmax(dim=-1) + + def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]: + """Returns a final and per-step score for a given input query and response + + Args: + query (str): User query + response (str): Assistant Response to score + """ + + list_of_steps = self.stepify(response, self.step_separator) + # get tokenized batch + batches = self.prepare_inputs(query, list_of_steps) + all_rewards = [] + all_rewards_per_step = [] + + # find the chat turn where assistant message starts to find the correct placement of the score token + # add empty system prompt to prevent model from adding its own system prompt + chat_template_to_turn = self.tokenizer.apply_chat_template( + [ + {"role": "system", "content": ""}, + {"role": "assistant", "content": self._score_token}, + ], + tokenize=False, + add_generation_prompt=False, + ) + # removing the system prompt by finding the assistant turn, which usually starts like <|..|>assistant<|..> + asst_text = chat_template_to_turn[chat_template_to_turn.find(">assistant<") :][ + 1: + ] + asst_toks = self.tokenizer( + asst_text, add_special_tokens=False, return_tensors="pt" + )["input_ids"][0] + asst_toks_before_correct_token = asst_toks[ + : torch.where(asst_toks == self._score_token_id)[ + 0 + ].item() # type: ignore + ].tolist() # type : ignore + + # move each item of the batch to the device + for i in batches: + batches[i] = batches[i].to(self.model.device) + + with 
torch.no_grad(): + model_outputs = self.model(**batches) + logits = model_outputs.logits # (bsz, seq_len, vocab_size) + + for batch_idx in range(logits.shape[0]): + per_input_rewards = [] + # for each element in the batch (i.e., each input) + # we need to get logits for all tokens where the token is self._score_token (in assistant turn) + # find batch index for **assistant** turn is self._score_token, not just the self._score_token_id + correct_token_indices = torch.where( + batches["input_ids"][batch_idx] == self._score_token_id + )[0].tolist() + prm_indices = [] + for t_idx in correct_token_indices: + if ( + batches["input_ids"][batch_idx][ + t_idx - len(asst_toks_before_correct_token) : t_idx + ].tolist() + == asst_toks_before_correct_token + ): + prm_indices.append( + t_idx - 1 + ) # the logits for token i predict the token i+1: so, we need to look at the **previous** token logits + + assert len(prm_indices) > 0 + # convert logits to probabilities and get the probability of the correct token id as reward + for prm_idx in prm_indices: + per_input_rewards.append( + self.softmax(logits[batch_idx, prm_idx, :])[ + self._score_token_id + ].item() + ) + + # aggregate. return final rewards + all_rewards_per_step.append(per_input_rewards) + sum = 0 + for reward in per_input_rewards: + sum += reward + per_input_reward = sum / len(per_input_rewards) + all_rewards.append(per_input_reward) + + return all_rewards, all_rewards_per_step + + def prepare_inputs(self, user_content: str, steps: list[str]) -> BatchEncoding: + """Prepare the inputs for inference with the model + + Args: + user_content (str): the user query + steps (List(str)): assistant response, broken down into steps + """ + msgs = [] + for s_idx, step in enumerate(steps): + # apply chat template as expected by the reward model + # rewards are calculated from the logit of self._score_token as produced by the assistant + if s_idx == 0: + msgs.append( + { + "role": "user", + "content": user_content + + " " + + step + + " " + + self.generation_prompt, + } + ) + else: + # first add last assistant turn + msgs.append({"role": "assistant", "content": self._score_token}) + msgs.append( + {"role": "user", "content": step + " " + self.generation_prompt} + ) + + # append last assistant turn + msgs.append({"role": "assistant", "content": self._score_token}) + input_message = self.tokenizer.apply_chat_template( + msgs, add_generation_prompt=False, tokenize=False + ) + return self.tokenizer( + [input_message], return_tensors="pt", padding=True, truncation=True + ) + + +class HFRegressionPRM(HFProcessRewardModel): + def __init__( + self, + model_name_or_path: str, + score_token: str = "", + device: str | None = None, + step_separator: str = "\n\n", + ): + """Initialize a Regression PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models + + Args: + model_name_or_path (str): A local path to PRM or a huggingface PRM + score_token (str): token who's logits correspond to the PRM score. Usually is a step demarker (for non-generative PRMs) + backend (LocalHFBackend): Mained as a pointer to the backend to which this this PRM is attached. 
+ step_separator (str): string on which to separate the input content into steps + """ + super().__init__(model_name_or_path, score_token, device) + + # initialize PRM head + self.prm_head = torch.nn.Linear( + self.model.config.hidden_size, 2, bias=False, dtype=self.model.dtype + ).to(self.model.device) + + state = torch.load(model_name_or_path + "/added_params.bin") + # need to do this-- we save model dict as `prm_head.weight` during training + new_state_dict = {} + for k, v in state.items(): + new_k = k.replace("prm_head.", "") + new_state_dict[new_k] = v + + self.prm_head.load_state_dict(new_state_dict) + self.prm_head.eval() + + self.step_separator = step_separator + self.softmax = torch.nn.Softmax(dim=-1) + + def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]: + """Returns a final and per-step score for a given input query and response + + Args: + query (str): User query + response (str): Assistant Response to score + """ + + list_of_steps = self.stepify(response, self.step_separator) + # tokenizes the batch and concatenates the list of steps into a single step-separated response + batch = self.prepare_inputs(query, list_of_steps) + # move each item of the batch to the device + for i in batch: + batch[i] = batch[i].to(self.model.device) + + with torch.no_grad(): + model_outputs = self.model(**batch, output_hidden_states=True) + # all logits + all_prm_logits = self.prm_head(model_outputs["hidden_states"][-1]).squeeze( + -1 + ) + + # get logits for each end of step i.e. logits for step_eos positions in the input + prm_probs = [] + rewards = [] + for idx in range(all_prm_logits.shape[0]): + prm_indices = torch.where(batch["input_ids"][idx] == self._score_token_id)[ + 0 + ] + assert prm_indices.shape[0] > 0 + # head produces two logits, the second one is the logit for the correct answer + # convert logits to probabilities using softmax + # return list of floats instead of list of tensors + prm_probs_per_sample = [ + t.item() for t in self.softmax(all_prm_logits[idx][prm_indices])[:, 1] + ] + prm_probs.append(prm_probs_per_sample) + + reward = sum(prm_probs_per_sample) / len(prm_probs_per_sample) + rewards.append(reward) + + return rewards, prm_probs + + def prepare_inputs(self, user_content: str, steps: list[str]) -> BatchEncoding: + """Prepare the inputs for inference with the model + + Args: + user_content (str): the user query + steps (List(str)): assistant response, broken down into steps + """ + text_with_steps_marked = "" + + for step in steps: + text_with_steps_marked += f"{step} {self._score_token}" + + message = [ + {"role": "user", "content": user_content}, + {"role": "assistant", "content": text_with_steps_marked}, + ] + input_message = self.tokenizer.apply_chat_template(message, tokenize=False) + + return self.tokenizer( + [input_message], return_tensors="pt", padding=True, truncation=True + ) diff --git a/mellea/stdlib/rewards/prm.py b/mellea/stdlib/rewards/prm.py deleted file mode 100644 index fe1320e6..00000000 --- a/mellea/stdlib/rewards/prm.py +++ /dev/null @@ -1,293 +0,0 @@ -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - - -class GenerativePRMForInference(torch.nn.Module): - """ - Class for Generative Process Reward Models for Inference - Uses Huggingface backend to load the model (which is trained using LoRA adapters) - """ - - def __init__( - self, - model_path="ibm-granite/granite-3.3-8b-lora-math-prm", - correct_token="Y", - generation_prompt="Is this response correct so far (Y/N)?", - load_in_bf16=True, 
- device=None, - ) -> None: - super().__init__() - - if not load_in_bf16: - self.model = AutoModelForCausalLM.from_pretrained( - model_path, device_map="auto" - ) - else: - self.model = AutoModelForCausalLM.from_pretrained( - model_path, torch_dtype=torch.bfloat16, device_map="auto" - ) - - if device is not None: - self.model.to(device) - self.device = self.model.device - self.tokenizer = AutoTokenizer.from_pretrained(model_path) - self.tokenizer.truncation_side = "left" # prevents truncation from right (default): needed since we always want to have the last step and last generation prompt from the context. - self.correct_token = correct_token - self.correct_token_id = self.tokenizer.encode( - self.correct_token, add_special_tokens=False - )[0] - self.generation_prompt = generation_prompt - self.softmax = torch.nn.Softmax(dim=-1) - - def forward(self, raw_inputs): - """ - Expects a raw_batch of (questions: List[str], steps: List[List[str]]) - Return the aggregated score for each problem (i.e., the average of the per-step scores), along with the per-step scores - """ - - # get un-tokenized batch - batches = self.prepare_batch(raw_inputs) - # each element of the batch consists of a list of num_steps messages corresponding to a single input, which we need to handle - all_rewards = [] - all_rewards_per_step = [] - - chat_template_to_turn = self.tokenizer.apply_chat_template( - [{"role": "assistant", "content": self.correct_token}], - tokenize=False, - add_generation_prompt=False, - ) - if "system" in chat_template_to_turn: - if "granite" in self.model.config.model_type.lower(): - # for granite, apply_chat_template also adds system prompt - asst_text = ( - "<|start_of_role|>assistant<|end_of_role|>" - + self.correct_token - + "<|end_of_text|>" - ) - elif "phi" in self.model.config.model_type.lower(): - # phi reasoning also applies the system prompt - asst_text = ( - "<|im_start|>assistant<|im_sep|>" - + self.correct_token - + "<|im_end|>'" - ) - else: - asst_text = chat_template_to_turn - asst_toks = self.tokenizer( - asst_text, add_special_tokens=False, return_tensors="pt" - )["input_ids"][0] - asst_toks_before_correct_token = asst_toks[ - : torch.where(asst_toks == self.correct_token_id)[0].item() - ].tolist() - - # each element in batch contains a question and the response - for i in batches: - batches[i] = batches[i].to(self.model.device) - - with torch.no_grad(): - model_outputs = self.model(**batches) - logits = model_outputs.logits # (bsz, seq_len, vocab_size) - - for batch_idx in range(logits.shape[0]): - per_input_rewards = [] - # for each element in the batch (i.e., each input) - # we need to get logits for all tokens where the token in "Y" (in assistant turn) - # find batch index for assistant turn "Y", not just the correct_token_id - correct_token_indices = torch.where( - batches["input_ids"][batch_idx] == self.correct_token_id - )[0].tolist() - prm_indices = [] - for t_idx in correct_token_indices: - if ( - batches["input_ids"][batch_idx][ - t_idx - len(asst_toks_before_correct_token) : t_idx - ].tolist() - == asst_toks_before_correct_token - ): - prm_indices.append( - t_idx - 1 - ) # the logits for token i predict the token i+1: so, we need to look at the PREVIOUS token logits - - assert len(prm_indices) > 0 - # convert logits to probabilities and get the probability of the correct token id as reward - for prm_idx in prm_indices: - per_input_rewards.append( - self.softmax(logits[batch_idx, prm_idx, :])[ - self.correct_token_id - ].item() - ) - - # aggregate. 
return final rewards - all_rewards_per_step.append(per_input_rewards) - sum = 0 - for reward in per_input_rewards: - sum += reward - per_input_reward = sum / len(per_input_rewards) - all_rewards.append(per_input_reward) - - return all_rewards, all_rewards_per_step - - def prepare_batch(self, raw_batch): - """ - Expects a raw_batch of (question, list_of_steps). The list of steps is joined with the step_eos token - prepare_batch() function splits each step into an individual response, and prepares an input batch - prepare batch for forward pass - """ - - questions, list_of_steps = raw_batch - assert len(questions) == len(list_of_steps) - - inputs = [] - for i in range(len(questions)): - user_content = questions[i] - steps = list_of_steps[i] - msgs = [] - for s_idx, step in enumerate(steps): - # apply chat template as expected by RM input - if s_idx == 0: - msgs.append( - { - "role": "user", - "content": user_content - + " " - + step - + " " - + self.generation_prompt, - } - ) - else: - # first add last assistant turn - msgs.append({"role": "assistant", "content": self.correct_token}) - msgs.append( - {"role": "user", "content": step + " " + self.generation_prompt} - ) - - # append the last asst turn - msgs.append({"role": "assistant", "content": self.correct_token}) - - input_message = self.tokenizer.apply_chat_template( - msgs, add_generation_prompt=False, tokenize=False - ) - - inputs.append(input_message) - - return self.tokenizer( - inputs, return_tensors="pt", padding=True, truncation=True - ) - - -class RegressionPRMForInference(torch.nn.Module): - """ - Class for Regression (non-generative) Process Reward Models for Inference - Uses Huggingface backend to load the model - All regression process reward models trained by the GMA team at IBM research use a special step token, - """ - - def __init__( - self, - model_path: str, - step_eos: str = "", - load_in_bf16: bool = True, - device=None, - ) -> None: - super().__init__() - - # Load the model - self.model: AutoModelForCausalLM - if not load_in_bf16: - self.model = AutoModelForCausalLM.from_pretrained( # type: ignore - model_path, device_map="auto" - ) - else: - self.model = AutoModelForCausalLM.from_pretrained( # type: ignore - model_path, torch_dtype=torch.bfloat16, device_map="auto" - ) - self.device = self.model.device - self.config = self.model.config - - # get the token IDs for the step separator token - self.step_eos = step_eos - self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_path) - # self.tokenizer.add_tokens(self.step_eos) - self.step_eos_id = self.tokenizer.encode( - self.step_eos, add_special_tokens=False - )[0] - - # load the PRM head - self.prm_head = torch.nn.Linear( - self.model.config.hidden_size, 2, bias=False, dtype=self.model.dtype - ).to(self.model.device) - state = torch.load(model_path + "/added_params.bin") - self.load_state_dict(state, strict=False) - self.model.eval() - - self.softmax = torch.nn.Softmax(dim=-1) - - def forward(self, raw_batch): - """ - Expects a raw_batch of (questions: List[str], steps: List[List[str]]) - Return the aggregated score for each problem (i.e., the average of the per-step scores), along with the per-step scores - """ - - # tokenizes the batch and concatenates the list of steps into a single step-separated response - batch = self.prepare_batch(raw_batch).to(self.device) - - with torch.no_grad(): - model_outputs = self.model(**batch, output_hidden_states=True) - # all logits - all_prm_logits = self.prm_head(model_outputs["hidden_states"][-1]).squeeze( - -1 - ) - 
- # get logits for each end of step i.e. logits for step_eos positions in the input - prm_probs = [] - rewards = [] - for idx in range(all_prm_logits.shape[0]): - prm_indices = torch.where(batch["input_ids"][idx] == self.step_eos_id)[0] - if prm_indices.shape[0] == 0: - # no match found-- model did not produce outputs in correct step-wise format - prm_probs.append([None]) - reward = None - else: - # head produces two logits, the second one is the logit for the correct answer - # convert logits to probabilities using softmax - # return list of floats instead of list of tensors - prm_probs_per_sample = [ - t.item() - for t in self.softmax(all_prm_logits[idx][prm_indices])[:, 1] - ] - prm_probs.append(prm_probs_per_sample) - - reward = sum(prm_probs_per_sample) / len(prm_probs_per_sample) - rewards.append(reward) - - return rewards, prm_probs - - def prepare_batch(self, raw_batch): - """ - Tokenize and prepare batch for forward pass - Expects a raw_batch of (question, list_of_steps). The list of steps is joined with the step_eos token - """ - - questions, list_of_steps = raw_batch - assert len(questions) == len(list_of_steps) - - inputs = [] - for i in range(len(questions)): - text_with_steps_marked = "" - - for step in list_of_steps[i]: - text_with_steps_marked += f"{step} {self.step_eos}" - - message = [ - {"role": "user", "content": questions[i]}, - {"role": "assistant", "content": text_with_steps_marked}, - ] - input = self.tokenizer.apply_chat_template(message, tokenize=False) - inputs.append(input) - - # tokenize data for the RM - batch = self.tokenizer( - inputs, return_tensors="pt", padding=True, truncation=True - ) - return batch diff --git a/mellea/stdlib/rewards/prm_scorer.py b/mellea/stdlib/rewards/prm_scorer.py index 9d56e62d..0c46dcbe 100644 --- a/mellea/stdlib/rewards/prm_scorer.py +++ b/mellea/stdlib/rewards/prm_scorer.py @@ -1,38 +1,20 @@ -import torch - +from mellea.backends.huggingface import HFProcessRewardModel from mellea.stdlib.base import CBlock, Context from mellea.stdlib.chat import Message from mellea.stdlib.requirement import ScorerRequirement, ValidationResult -from mellea.stdlib.rewards.prm import ( - GenerativePRMForInference, - RegressionPRMForInference, -) class PRMScorer(ScorerRequirement): """A process reward model scorer based on local huggingface backend.""" def __init__( - self, - *, - model_version: str = "ibm-granite/granite-3.3-8b-lora-math-prm", - preference_ordering: str = "max", - device: str | None = None, - step_splitter="\n\n", - prm_type: str = "generative", - **prm_kwargs, + self, *, prm_model: HFProcessRewardModel, preference_ordering: str = "max" ): """ Args: - model_version: The version of the model, defaults to "ibm-granite/granite-3.3-8b-lora-math-prm". + prm_model: The PRM model preference_ordering: indicates whether the goal is to maximize or minimize the score. must be either "max" or "min" - device: The computational device to use ("cuda" for GPU, "mps" for Apple Silicon, or "cpu"), defaults to None. If not specified, the best available device will be automatically selected. - correct_token: PRM generated token that indicates step is correct - generation_prompt: Generation prompt required for the PRM scorer - step_splitter: string on which assistant response is split into steps - prm_type: type of prm tobe used. must be either `generative` or `regression` - prm_kwargs: args for PRM. For Generative, pass `correct_token`, `generation_prompt`. 
For Regression, pass `step_token` """ super().__init__( check_only=True, @@ -40,27 +22,7 @@ def __init__( preference_ordering=preference_ordering, ) - self._model_version = model_version - - # auto-device if not more specific - self._device = device - if device is None: - device_name: str = ( - "cuda" - if torch.cuda.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" - ) - assert device_name is not None - self._device = torch.device(device_name) # type: ignore - - self.step_splitter = step_splitter - assert prm_type.lower() in ["generative", "regression"], ( - "prm_type must be either generative or regression" - ) - self.prm_type = prm_type.lower() - self.prm_kwargs = prm_kwargs + self.model: HFProcessRewardModel = prm_model def _prm_validate(self, ctx: Context): """ @@ -87,39 +49,9 @@ def _prm_validate(self, ctx: Context): assistant_content = last_turn.output.value - # convert assistant message into a list of steps - list_of_steps = [ - step.strip() - for step in assistant_content.split(self.step_splitter) - if step.strip != "" - ] - - # Load model - model: GenerativePRMForInference | RegressionPRMForInference - if self.prm_type == "generative": - model = GenerativePRMForInference( - model_path=self._model_version, - load_in_bf16=True, - device=self._device, - **self.prm_kwargs, - ) - model.to(self._device) - elif self.prm_type == "regression": - model = RegressionPRMForInference( - model_path=self._model_version, - load_in_bf16=True, - device=self._device, - **self.prm_kwargs, - ) # type: ignore[no-redef] - else: - raise NotImplementedError - - rewards, rewards_per_step = model(([user_query], [list_of_steps])) + rewards, rewards_per_step = self.model.score(user_query, assistant_content) # return single reward item for the response assert len(rewards) == 1 - # offload and delete model before returning rewards - del model - return ValidationResult(result=True, reason=None, score=rewards[0]) From 9ee660ea3d76bb28ed418c6cb18fc6f64cc0642f Mon Sep 17 00:00:00 2001 From: aashka-trivedi Date: Wed, 17 Sep 2025 13:43:08 +0000 Subject: [PATCH 6/6] use req.preference_ordering for BoN Sampling --- mellea/backends/litellm.py | 6 +++--- mellea/stdlib/sampling.py | 16 +++++++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index 031bae41..b330bbcd 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -5,9 +5,9 @@ from collections.abc import Callable from typing import Any -import litellm -import litellm.litellm_core_utils -import litellm.litellm_core_utils.get_supported_openai_params +import litellm # type: ignore +import litellm.litellm_core_utils # type: ignore +import litellm.litellm_core_utils.get_supported_openai_params # type: ignore import mellea.backends.model_ids as model_ids from mellea.backends import BaseModelSubclass diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index a771bc19..845e3730 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -511,18 +511,28 @@ def sample( # find max reward amongst results for which all requirements have passed if len(successful_sampled_scores) > 0: scores: list[float] = [] + scorer_preference_ordering = None for sample in successful_sampled_scores: for req, val_score in sample: if isinstance(req, ScorerRequirement): assert val_score._score is not None scores.append(val_score._score) + scorer_preference_ordering = req.preference_ordering assert len(successful_sampled_results) == len(scores) 
+ assert scorer_preference_ordering is not None - best_result, best_score = max( - zip(successful_sampled_results, scores), key=lambda x: x[1] - ) + if scorer_preference_ordering == "max": + best_result, best_score = max( + zip(successful_sampled_results, scores), key=lambda x: x[1] + ) + elif scorer_preference_ordering == "min": + best_result, best_score = min( + zip(successful_sampled_results, scores), key=lambda x: x[1] + ) + else: + raise NotImplementedError return SamplingResult( best_result,
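
# A minimal usage sketch of the refactored scorer, assuming an
# HFProcessRewardModel that accepts a model id in its constructor (the real
# constructor signature may differ; only its score(question, response)
# method is relied on by PRMScorer above). preference_ordering decides which
# passing candidate best-of-N selection keeps: "max" picks the candidate with
# the highest aggregated reward, "min" the lowest.

from mellea.backends.huggingface import HFProcessRewardModel
from mellea.stdlib.rewards.prm_scorer import PRMScorer

# Assumed constructor argument, shown for illustration only.
prm_model = HFProcessRewardModel("ibm-granite/granite-3.3-8b-lora-math-prm")

# This ScorerRequirement is what the best-of-N sampler ranks candidates by.
prm_scorer = PRMScorer(prm_model=prm_model, preference_ordering="max")

# The model returns one aggregated reward per response plus per-step rewards;
# PRMScorer surfaces the aggregated value as the ValidationResult score.
rewards, rewards_per_step = prm_model.score(
    "A train travels 120 km in 2 hours. What is its average speed?",
    "120 / 2 = 60\n\nThe average speed is 60 km/h.",
)
print(rewards[0], rewards_per_step)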