-
Notifications
You must be signed in to change notification settings - Fork 16
Creates Judge Example as a wrapper on Policy #202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Jack-Khuu
wants to merge
17
commits into
main
Choose a base branch
from
judge2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 3 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
249a6ab
Push basic GenerativeJudge example
Jack-Khuu be26c39
Merge remote-tracking branch 'origin/main' into judge2
Jack-Khuu e88f58f
[Debug] Individual LLM/Reward Actors
Jack-Khuu 634fe59
Merge remote-tracking branch 'origin/main' into judge2
Jack-Khuu 336c997
remove unused
Jack-Khuu 6a01bd7
debug
Jack-Khuu 8c87d42
Merge remote-tracking branch 'origin/main' into judge2
Jack-Khuu f80ff68
Refactor to subclass policy
Jack-Khuu 53607fd
Light cleanup-still testing
Jack-Khuu 00bbffa
Need to test math
Jack-Khuu c5e7b07
Merge remote-tracking branch 'origin/main' into judge2
Jack-Khuu f3ae7da
Psh to switch machines
Jack-Khuu fa18b3e
Merge remote-tracking branch 'origin/main' into judge2
Jack-Khuu b67a95d
Clean up and simplify Judge
Jack-Khuu b5a9d70
Merge remote-tracking branch 'origin/main' into judge2
Jack-Khuu b695b3b
Rebase typo
Jack-Khuu 269b3f9
Merge branch 'main' into judge2
Jack-Khuu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""To run: | ||
export HF_HUB_DISABLE_XET=1 | ||
python -m apps.vllm.judge --config apps/vllm/llama3_8b.yaml | ||
""" | ||
|
||
import asyncio | ||
|
||
import os | ||
|
||
from forge.actors.generative_judge import LLMJudge | ||
from forge.cli.config import parse | ||
from forge.controller.provisioner import shutdown | ||
|
||
from omegaconf import DictConfig | ||
|
||
os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600" | ||
os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824" | ||
|
||
|
||
async def run(cfg: DictConfig): | ||
prompt = "What is the capital of Japan?" | ||
responses = ["Aardvark", "Durian", "Tokyo"] | ||
|
||
print("Spawning service...") | ||
judge_service_config = cfg.services.pop("judge") | ||
judge = await LLMJudge.options(**judge_service_config).as_service(policy_cfg=cfg) | ||
|
||
print(f"Prompt: {prompt}") | ||
print(f"Responses: {responses}\n") | ||
print("Requesting generation ...") | ||
evaluations: list[str] = await judge.generate.choose( | ||
prompt=prompt, | ||
responses=responses, | ||
) | ||
|
||
print("\nGeneration Results:") | ||
print("=" * 80) | ||
for batch, evaluation in enumerate(evaluations): | ||
print(f"Sample {batch + 1}") | ||
print(f"Evaluation: {evaluation}") | ||
print("-" * 80) | ||
|
||
print("\nShutting down...") | ||
await judge.shutdown() | ||
await shutdown() | ||
|
||
|
||
@parse | ||
def recipe_main(cfg: DictConfig) -> None: | ||
asyncio.run(run(cfg)) | ||
|
||
|
||
if __name__ == "__main__": | ||
recipe_main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import Any, Callable, Optional | ||
|
||
from monarch.actor import endpoint, ProcMesh | ||
|
||
from vllm.transformers_utils.tokenizer import get_tokenizer | ||
|
||
from forge.actors.policy import Policy | ||
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh | ||
from forge.data_models.completion import Completion | ||
|
||
|
||
@dataclass | ||
class LLMJudge(ForgeActor): | ||
""" | ||
`LLM-based Judges` are typically generative models which are then prompted | ||
to evaluate responses. These models NEED prompt engineering to evaluate | ||
and may require more postprocessing | ||
""" | ||
|
||
# Typically Policy, but effectively any service with `generate` method | ||
generator: ServiceInterface | None = None | ||
|
||
def __post_init__(self): | ||
super().__init__() | ||
self._judge_proc: ProcMesh | None = None | ||
|
||
@classmethod | ||
async def launch( | ||
cls: type["LLMJudge"], | ||
*, | ||
process_config: ProcessConfig, | ||
policy_cfg: Mapping, | ||
prompt_wrapper: Optional[Callable[[str, list[str]], str]] = None, | ||
output_postprocessor: Optional[Callable[[Any], Any]] = None, | ||
**kwargs, | ||
): | ||
judge_procs = await get_proc_mesh(process_config=process_config) | ||
policy = await Policy.options(**policy_cfg.services.policy).as_service( | ||
**policy_cfg.policy | ||
) | ||
|
||
actor_name = kwargs.pop("name", cls.__name__) | ||
llm_judge = await judge_procs.spawn( | ||
actor_name, | ||
cls, | ||
prompt_wrapper=prompt_wrapper, | ||
output_postprocessor=output_postprocessor, | ||
generator=policy, | ||
) | ||
llm_judge._judge_proc = judge_procs | ||
|
||
await llm_judge.setup.call() | ||
return llm_judge | ||
|
||
@endpoint | ||
async def setup(self): | ||
assert self.generator is not None, "Generator not initialized correctly" | ||
self.tokenizer = get_tokenizer(self.generator.engine_config.model) | ||
|
||
@classmethod | ||
async def shutdown(cls: type["LLMJudge"], actor: "LLMJudge"): | ||
assert ( | ||
actor.generator is not None | ||
), "Tried to shutdown a generator that was not initialized correctly" | ||
assert ( | ||
actor._judge_proc is not None | ||
), "Tried to shutdown a LLMJudge that was not initialized correctly" | ||
|
||
await actor.generator.shutdown() | ||
await stop_proc_mesh(actor._judge_proc) | ||
|
||
def _wrap_prompt(self, prompt: str, responses: list[str]) -> str: | ||
""" | ||
Construct the string being passed to the generator | ||
|
||
Note: This is not a "good" prompt, it just demonstrates how to make one | ||
""" | ||
|
||
system_prompt = f""" | ||
You are an expert evaluator. Evaluate the responses provided and return | ||
a single integer indicating which response is the most factually correct. | ||
Each response is formatted as [Response #<N>], where <N> represents the | ||
selection. Do not explain your reasoning, just provide a number. | ||
|
||
Here is the prompt that generated the responses: {prompt}. | ||
""" | ||
response_str = "\n".join( | ||
[f"[Response #{i+1}] {resp}" for i, resp in enumerate(responses)] | ||
) | ||
as_chat = [ | ||
{"role": "system", "content": system_prompt}, | ||
{"role": "user", "content": response_str}, | ||
] | ||
formatted_request = self.tokenizer.apply_chat_template( | ||
as_chat, tokenize=False, add_generation_prompt=True | ||
) | ||
return formatted_request | ||
|
||
def _postprocess_output(self, output: List[Completion]) -> list[str]: | ||
return [output.text for output in response.outputs] | ||
|
||
@endpoint | ||
async def generate( | ||
self, prompt: str, responses: list[str], priority: int = 0 | ||
) -> list[str]: | ||
wrapped_prompt: str = self._wrap_prompt(prompt, responses) | ||
response: List[Completion] = await self.generator.generate.choose( | ||
prompt=wrapped_prompt | ||
) | ||
return self._postprocess_output(response) | ||
|
||
|
||
@dataclass | ||
class RewardModelJudge(ForgeActor): | ||
""" | ||
`RewardModels` are typically discriminative models, post trained to | ||
evaluate responses without further prompting required. | ||
""" | ||
|
||
# Typically Policy, but effectively any service with `generate` method | ||
generator: ServiceInterface | None = None | ||
|
||
def __post_init__(self): | ||
super().__init__() | ||
self._judge_proc: ProcMesh | None = None | ||
|
||
@classmethod | ||
async def launch( | ||
cls: type["LLMJudge"], | ||
*, | ||
process_config: ProcessConfig, | ||
policy_cfg: Mapping, | ||
prompt_wrapper: Optional[Callable[[str, list[str]], str]] = None, | ||
output_postprocessor: Optional[Callable[[Any], Any]] = None, | ||
**kwargs, | ||
): | ||
judge_procs = await get_proc_mesh(process_config=process_config) | ||
policy = await Policy.options(**policy_cfg.services.policy).as_service( | ||
**policy_cfg.policy | ||
) | ||
|
||
actor_name = kwargs.pop("name", cls.__name__) | ||
llm_judge = await judge_procs.spawn( | ||
actor_name, | ||
cls, | ||
prompt_wrapper=prompt_wrapper, | ||
output_postprocessor=output_postprocessor, | ||
generator=policy, | ||
) | ||
llm_judge._judge_proc = judge_procs | ||
|
||
return llm_judge | ||
|
||
# TODO: Add formatting for reward models | ||
def _wrap_prompt(self, prompt: str, responses: list[str]) -> str: | ||
return prompt | ||
|
||
@endpoint | ||
async def generate( | ||
self, prompt: str, responses: list[str], priority: int = 0 | ||
) -> list[str]: | ||
wrapped_prompt: str = self._wrap_prompt(prompt, responses) | ||
return await self.generator.generate.choose(prompt=wrapped_prompt) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.