Creates Judge Example as a wrapper on Policy #202
Open
Jack-Khuu wants to merge 17 commits into main from judge2
Changes from 10 of 17 commits

Commits
249a6ab  Push basic GenerativeJudge example (Jack-Khuu)
be26c39  Merge remote-tracking branch 'origin/main' into judge2 (Jack-Khuu)
e88f58f  [Debug] Individual LLM/Reward Actors (Jack-Khuu)
634fe59  Merge remote-tracking branch 'origin/main' into judge2 (Jack-Khuu)
336c997  remove unused (Jack-Khuu)
6a01bd7  debug (Jack-Khuu)
8c87d42  Merge remote-tracking branch 'origin/main' into judge2 (Jack-Khuu)
f80ff68  Refactor to subclass policy (Jack-Khuu)
53607fd  Light cleanup-still testing (Jack-Khuu)
00bbffa  Need to test math (Jack-Khuu)
c5e7b07  Merge remote-tracking branch 'origin/main' into judge2 (Jack-Khuu)
f3ae7da  Psh to switch machines (Jack-Khuu)
fa18b3e  Merge remote-tracking branch 'origin/main' into judge2 (Jack-Khuu)
b67a95d  Clean up and simplify Judge (Jack-Khuu)
b5a9d70  Merge remote-tracking branch 'origin/main' into judge2 (Jack-Khuu)
b695b3b  Rebase typo (Jack-Khuu)
269b3f9  Merge branch 'main' into judge2 (Jack-Khuu)
apps/vllm/judge.py (new file)
@@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""To run:
export HF_HUB_DISABLE_XET=1
python -m apps.vllm.judge --config apps/vllm/llama3_8b.yaml
"""

import asyncio
import os

from forge.actors.judge import EvaluationMode, Judge
from forge.cli.config import parse
from forge.controller.provisioner import shutdown
from forge.observability.metric_actors import get_or_create_metric_logger
from omegaconf import DictConfig

os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"


async def run(cfg: DictConfig):
    metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
    mlogger = await get_or_create_metric_logger()
    await mlogger.init_backends.call_one(metric_logging_cfg)

    prompt = "What is the capital of Japan?"
    responses = ["Aardvark", "Durian", "Tokyo"]

    print("Spawning service...")
    judge = await Judge.options(**cfg.services.policy).as_service(**cfg.policy)

    print(f"Prompt: {prompt}")
    print(f"Responses: {responses}\n")
    print("Evaluating responses...")
    best_response_evaluations: list[str] = await judge.evaluate.route(
        prompt=prompt, responses=responses, evaluation_mode=EvaluationMode.BEST_RESPONSE
    )
    response_check_evaluations: list[str] = await judge.evaluate.route(
        prompt=prompt,
        responses=responses,
        evaluation_mode=EvaluationMode.RESPONSE_CHECK,
    )

    print("\nGeneration Results:")
    print("=" * 80)
    for batch, (best, fact) in enumerate(
        zip(best_response_evaluations, response_check_evaluations)
    ):
        print(f"Sample {batch + 1}")
        print(f"Evaluation (BEST_RESPONSE): {best}")
        print(f"Evaluation (RESPONSE_CHECK): {fact}")
        print("-" * 80)

    print("\nShutting down...")
    await judge.shutdown()
    await shutdown()


@parse
def recipe_main(cfg: DictConfig) -> None:
    asyncio.run(run(cfg))


if __name__ == "__main__":
    recipe_main()
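The example script above only exercises BEST_RESPONSE and RESPONSE_CHECK. MATH_CHECK additionally requires a ground_truth, so a call would look roughly like the sketch below. This is not part of the PR; the prompt, responses, and ground_truth values are illustrative placeholders, and the snippet assumes it is added inside the run() coroutine where judge and EvaluationMode are in scope.

# Hypothetical addition to run(); values are illustrative only.
math_prompt = "What is 7 * 8?"
math_responses = ["56", "The answer is fifty-six.", "54"]

# MATH_CHECK requires a ground_truth; the Judge raises if it is missing.
math_evaluations: list[str] = await judge.evaluate.route(
    prompt=math_prompt,
    responses=math_responses,
    ground_truth="56",
    evaluation_mode=EvaluationMode.MATH_CHECK,
)
print(f"Evaluation (MATH_CHECK): {math_evaluations}")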
forge/actors/judge.py (new file)
@@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from enum import auto, Enum

from monarch.actor import endpoint

from forge.actors.policy import Policy
from forge.data_models.completion import Completion


class EvaluationMode(Enum):
    """Enum for selecting how a judge should evaluate the provided args."""

    BEST_RESPONSE = auto()
    RESPONSE_CHECK = auto()
    MATH_CHECK = auto()


@dataclass
class Judge(Policy):
    """
    LLM-based judges are typically generative models that are prompted to
    evaluate responses. These models need prompt engineering to evaluate
    and may require additional postprocessing.
    """

    def _math_check(
        self,
        prompt: str,
        responses: list[str],
        ground_truth: None | str = None,
    ) -> str:
        """
        Construct the generator input. Format the request such that the generator
        will return a comma-separated list with a [[GOOD]] or [[BAD]] evaluation
        for each response, corresponding to whether the model thinks the response
        matches the provided ground_truth. Specifically, the generator is prompted
        to check for mathematical equivalence.

        Note: This is not a "good" prompt; it just demonstrates how to make one.
        """
        if ground_truth is None:
            raise ValueError("MATH_CHECK requires a ground_truth to evaluate against")

        system_prompt = f"""
        You are a math professor. Given the prompt and ground truth solution, evaluate
        each of the provided attempts and return whether the final solution is
        numerically equivalent to the ground truth.

        Each response is formatted as [Response #<N>], where <N> represents the
        attempt.

        Your answer should be a comma separated list of "[[GOOD]]" or "[[BAD]]",
        corresponding to the same order as the responses provided.

        - If the answer is irrelevant to the prompt, return "[[BAD]]".
        - If you are not confident that solution and attempt are equivalent, return "[[BAD]]"
        - Only return "[[GOOD]]" if the attempt is numerically equivalent

        Do not explain your reasoning, just provide your evaluations.
        ---
        Here is the prompt that generated the responses: {prompt}.
        ---
        Here is the ground truth: {ground_truth}
        """
        response_str = "\n".join(
            [f"[Response #{i+1}] {resp}" for i, resp in enumerate(responses)]
        )
        as_chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": response_str},
        ]
        tokenizer = self.processor.tokenizer.tokenizer
        formatted_request = tokenizer.apply_chat_template(
            as_chat, tokenize=False, add_generation_prompt=True
        )
        return formatted_request

    def _response_check(
        self,
        prompt: str,
        responses: list[str],
        ground_truth: None | str = None,
    ) -> str:
        """
        Construct the generator input. Format the request such that the generator
        will return a comma-separated list with a [[GOOD]] or [[BAD]] evaluation
        for each response, corresponding to whether the model thinks it correctly
        answers the prompt.

        Note: This is not a "good" prompt; it just demonstrates how to make one.
        """
        system_prompt = f"""
        You are an expert fact checker. Given a prompt and response attempts, evaluate
        each attempt and return whether it accurately answers the prompt.
        Each response is formatted as [Response #<N>], where <N> represents the
        attempt.

        Your answer should be a comma separated list of "[[GOOD]]" or "[[BAD]]",
        corresponding to the same order as the responses provided.

        - If the answer is irrelevant to the prompt, return "[[BAD]]".
        - If you are not confident that the answer accurately answers the prompt, return "[[BAD]]"
        - Only return "[[GOOD]]" if the attempt accurately answers the prompt

        Do not explain your reasoning, just provide your evaluations.
        Here is the prompt that generated the responses: {prompt}.
        """
        response_str = "\n".join(
            [f"[Response #{i+1}] {resp}" for i, resp in enumerate(responses)]
        )
        as_chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": response_str},
        ]
        tokenizer = self.processor.tokenizer.tokenizer
        formatted_request = tokenizer.apply_chat_template(
            as_chat, tokenize=False, add_generation_prompt=True
        )
        return formatted_request

    def _best_check(
        self,
        prompt: str,
        responses: list[str],
        ground_truth: None | str = None,
    ) -> str:
        """
        Construct the generator input. Format the request such that the generator
        will respond with a single integer corresponding to the response the model
        thinks is most factually correct.

        Note: This is not a "good" prompt; it just demonstrates how to make one.
        """
        system_prompt = f"""
        You are an expert evaluator. Evaluate the responses provided and return
        a single integer indicating which response is the most factually correct.
        Each response is formatted as [Response #<N>], where <N> represents the
        selection. Do not explain your reasoning, just provide a number.

        Here is the prompt that generated the responses: {prompt}.
        """
        response_str = "\n".join(
            [f"[Response #{i+1}] {resp}" for i, resp in enumerate(responses)]
        )
        as_chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": response_str},
        ]
        tokenizer = self.processor.tokenizer.tokenizer
        formatted_request = tokenizer.apply_chat_template(
            as_chat, tokenize=False, add_generation_prompt=True
        )
        return formatted_request

    def _postprocess_output(self, outputs: list[Completion]) -> list[str]:
        return [output.text for output in outputs]

    @endpoint
    async def evaluate(
        self,
        prompt: str,
        responses: None | list[str] = None,
        ground_truth: None | str = None,
        evaluation_mode: EvaluationMode = EvaluationMode.BEST_RESPONSE,
    ) -> list[str]:
        # Map each evaluation mode to the prompt builder that wraps the request.
        _prompting: dict = {
            EvaluationMode.BEST_RESPONSE: self._best_check,
            EvaluationMode.RESPONSE_CHECK: self._response_check,
            EvaluationMode.MATH_CHECK: self._math_check,
        }

        wrapped_prompt: str = _prompting[evaluation_mode](
            prompt, responses, ground_truth
        )
        response: list[Completion] = await self.generate._method(self, wrapped_prompt)
        return self._postprocess_output(response)


@dataclass
class RewardModelJudge(Policy):
    """
    Reward models are typically discriminative models, post-trained to
    evaluate responses without further prompting required.
    """

    # TODO: Add reward model formatting
    def _wrap_prompt(
        self, prompt: str, responses: list[str], ground_truth: None | str = None
    ) -> str:
        return prompt

    def _postprocess_output(
        self, outputs: list[Completion], ground_truth: None | str = None
    ) -> list[str]:
        return [output.text for output in outputs]

    @endpoint
    async def evaluate(
        self,
        prompt: str,
        responses: list[str],
    ) -> list[str]:
        wrapped_prompt: str = self._wrap_prompt(prompt, responses)
        response: list[Completion] = await self.generate._method(self, wrapped_prompt)
        return self._postprocess_output(response)
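The docstrings above describe two output conventions: RESPONSE_CHECK and MATH_CHECK ask the generator for a comma-separated list of [[GOOD]]/[[BAD]] markers, while BEST_RESPONSE asks for a single integer. The Judge itself does not parse these (_postprocess_output returns the raw completion text), so a caller would need helpers along the lines of the sketch below. This is an illustrative helper, not part of this PR, and it assumes the model actually followed the requested format.

import re

def parse_good_bad(evaluation: str) -> list[bool]:
    """Turn a '[[GOOD]], [[BAD]], ...' judge completion into booleans (True == GOOD)."""
    markers = re.findall(r"\[\[(GOOD|BAD)\]\]", evaluation)
    return [marker == "GOOD" for marker in markers]

def parse_best_index(evaluation: str) -> int | None:
    """Extract the 1-based response index from a BEST_RESPONSE completion, if present."""
    match = re.search(r"\d+", evaluation)
    return int(match.group()) if match else None

# Example with hard-coded judge outputs:
print(parse_good_bad("[[BAD]], [[BAD]], [[GOOD]]"))  # [False, False, True]
print(parse_best_index("3"))                         # 3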