@@ -30,6 +30,7 @@
 from mellea.backends.cache import Cache, SimpleLRUCache
 from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
 from mellea.backends.model_ids import ModelIdentifier
+from mellea.backends.process_reward_models import PRM
 from mellea.backends.tools import (
     add_tools_from_context_actions,
     add_tools_from_model_options,
@@ -672,3 +673,56 @@ def __init__( |
         self._generation_prompt_tokens = self._backend._tokenizer(
             self._generation_prompt, return_tensors="pt"
         ).to(self._backend._device)
+
+
+class HFProcessRewardModel(PRM, abc.ABC):
+    def __init__(
+        self, model_name_or_path: str, score_token: str, device: str | None = None
+    ):
| 682 | + """Initialize an PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models |
| 683 | +
|
| 684 | + Args: |
| 685 | + model_name_or_path (str): A local path to PRM or a huggingface PRM |
| 686 | + score_token (str): token who's logits correspond to the PRM score. Can be a step demarker (for non-generative PRMs) or a correctness indicator (for generative PRMs) |
| 687 | + device (str): device: The computational device to use ("cuda" for GPU, "mps" for Apple Silicon, or "cpu"), defaults to None. If not specified, the best available device will be automatically selected. |
| 688 | + """ |
+        super().__init__(model_name_or_path)
+
+        # Auto-select the best available device if none was specified.
+        self._device = device
+        if device is None:
+            device_name: str = (
+                "cuda"
+                if torch.cuda.is_available()
+                else "mps"
+                if torch.backends.mps.is_available()
+                else "cpu"
+            )
+            assert device_name is not None
+            self._device = torch.device(device_name)  # type: ignore
+
+        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
+            self.model_name_or_path, torch_dtype=torch.bfloat16
+        )
+        self.model.to(self._device)  # type: ignore
+        self.model.eval()
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
+
+        self._score_token = score_token
+        self._score_token_id = self.tokenizer.encode(
+            self._score_token, add_special_tokens=False
+        )[0]
+
+    def stepify(self, content: str, step_separator: str) -> list[str]:
+        """Splits the assistant response into steps to score.
+
+        Args:
+            content: Assistant response to score.
+            step_separator: String on which to split the content into steps.
+        """
+        # Convert the assistant message into a list of non-empty steps.
+        list_of_steps = [
+            step.strip()
+            for step in content.split(step_separator)
+            if step.strip() != ""
+        ]
+        return list_of_steps
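
For context, a minimal usage sketch of the new class follows. It is illustrative only and not part of this diff: the concrete subclass MyStepPRM, its score_steps method, the "path/to/prm" checkpoint, and the "\n\n" step separator are all hypothetical placeholders; the real scoring interface is whatever the PRM base class defines, and a concrete subclass must implement any abstract methods it declares.

# Illustrative sketch only (assumptions noted above): a hypothetical subclass that
# scores each step via the probability of the configured score token.
import torch


class MyStepPRM(HFProcessRewardModel):  # hypothetical subclass name
    def score_steps(self, query: str, response: str) -> list[float]:  # hypothetical method
        # Split the assistant response into steps using a placeholder separator.
        steps = self.stepify(response, step_separator="\n\n")
        scores: list[float] = []
        for step in steps:
            inputs = self.tokenizer(query + "\n" + step, return_tensors="pt").to(
                self._device
            )
            with torch.no_grad():
                # Logits for the final position of the encoded query + step.
                logits = self.model(**inputs).logits[0, -1]
            # Probability mass assigned to the score token after this step.
            scores.append(torch.softmax(logits, dim=-1)[self._score_token_id].item())
        return scores


# prm = MyStepPRM("path/to/prm", score_token="<score>")  # placeholder path and token
# step_scores = prm.score_steps("What is 2 + 2?", "2 + 2 = 4.\n\nSo the answer is 4.")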