108 changes: 108 additions & 0 deletions apps/vllm/judge.py
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""To run:
export HF_HUB_DISABLE_XET=1
python -m apps.vllm.judge --config apps/vllm/llama3_8b.yaml
"""

import asyncio
import os
from typing import Callable

from forge.actors.generative_judge import GenerativeJudge
from forge.actors.policy import Policy
from forge.cli.config import parse
from forge.controller.provisioner import shutdown

from omegaconf import DictConfig
from vllm.outputs import RequestOutput

# Relax hyperactor message limits (delivery timeout and max frame length).
os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"

from vllm.transformers_utils.tokenizer import get_tokenizer


def basic_selector_wrapper(model: str) -> Callable[[str, list[str]], str]:
    """
    Build a prompt wrapper that asks the judge to pick the most factually
    correct response by index.

    Note: this is not a "good" prompt setup; it just demonstrates how to write one.
    """

    def _wrapper(prompt: str, responses: list[str]) -> str:
        tokenizer = get_tokenizer(model)
        system_prompt = f"""
        You are an expert evaluator. Evaluate the responses provided and return
        a single integer indicating which response is the most factually correct.
        Each response is formatted as [Response #<N>], where <N> represents the
        selection. Do not explain your reasoning, just provide a number.

        Here is the prompt that generated the responses: {prompt}.
        """
        response_str = "\n".join(
            [f"[Response #{i+1}] {resp}" for i, resp in enumerate(responses)]
        )
        as_chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": response_str},
        ]
        formatted_request = tokenizer.apply_chat_template(
            as_chat, tokenize=False, add_generation_prompt=True
        )
        return formatted_request

    return _wrapper


def unroll_response(response: RequestOutput) -> list[str]:
    """Flatten a vLLM RequestOutput into the text of each sampled completion."""
    return [output.text for output in response.outputs]


async def run(cfg: DictConfig):
    prompt = "What is the capital of Japan?"
    responses = ["Aardvark", "Durian", "Tokyo"]

    print("Spawning service...")
    policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)
    evaluate = GenerativeJudge(
        policy,
        prompt_wrapper=basic_selector_wrapper(cfg.policy.engine_config.model),
        output_postprocessor=unroll_response,
    )

    print(f"Prompt: {prompt}")
    print(f"Responses: {responses}\n")

    try:
        async with policy.session():
            print("Requesting generation ...")
            evaluations = await evaluate.generate(
                prompt=prompt,
                responses=responses,
            )

            print("\nGeneration Results:")
            print("=" * 80)
            for batch, evaluation in enumerate(evaluations):
                print(f"Sample {batch + 1}")
                print(f"Evaluation: {evaluation}")
                print("-" * 80)

    finally:
        print("\nShutting down...")
        await policy.shutdown()
        await shutdown()


@parse
def recipe_main(cfg: DictConfig) -> None:
    asyncio.run(run(cfg))


if __name__ == "__main__":
    recipe_main()
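
The app above returns the judge's raw completions via `unroll_response`. The same `output_postprocessor` hook could instead parse the judgments directly; the sketch below is a hypothetical helper (not part of this PR) that assumes the judge follows the system prompt and emits a bare integer per sampled completion.

import re

from vllm.outputs import RequestOutput


def select_best_index(response: RequestOutput) -> list[int]:
    """Parse each sampled judgment into a 1-based response index.

    Falls back to -1 when no integer can be found in the sampled text.
    """
    selections = []
    for output in response.outputs:
        # Take the first integer in the completion as the judge's selection.
        match = re.search(r"\d+", output.text)
        selections.append(int(match.group()) if match else -1)
    return selections

With `n: 4` in `sampling_config`, this would yield four parsed selections per request, which could then be reduced (e.g. majority vote) into a single winner.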
2 changes: 1 addition & 1 deletion apps/vllm/llama3_8b.yaml
@@ -5,7 +5,7 @@ policy:
    pipeline_parallel_size: 1
    enforce_eager: true
  sampling_config:
-   n: 2
+   n: 4
    guided_decoding: false
    max_tokens: 512

52 changes: 52 additions & 0 deletions src/forge/actors/generative_judge.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class GenerativeJudge:
    """
    Wrapper that adds custom prompting and postprocessing around a generator
    for generative judging. Represents a single verifier, which could be an
    LLM-based reward model or an LLM-based judge.

    - `RewardModels` are typically discriminative models post-trained to
      evaluate responses. These models are specialized and need less prompting.

    - `LLM-based Judges` are typically generative models that are prompted to
      evaluate responses. These models need prompt engineering to evaluate
      and may require more postprocessing.
    """

    # Typically a Policy, but effectively any service with a `generate` method
    generator: ServiceInterface
    prompt_wrapper: Optional[Callable[[str, list[str]], str]] = None
    output_postprocessor: Optional[Callable[[Any], Any]] = None

    def _wrap_prompt(self, prompt: str, responses: list[str]) -> str:
        """
        Construct the string being passed to the generator
        """
        if self.prompt_wrapper:
            return self.prompt_wrapper(prompt, responses)
        return prompt

    def _postprocess_output(self, output: Any) -> Any:
        """
        Postprocess generation results (metrics, aggregation, reducing)
        """
        if self.output_postprocessor:
            return self.output_postprocessor(output)
        return output

    async def generate(self, prompt: str, responses: list[str], priority: int = 0):
        """
        Wrap the prompt, route it to the generator service, and postprocess the result
        """
        wrapped_prompt: str = self._wrap_prompt(prompt, responses)
        response = await self.generator.generate.choose(prompt=wrapped_prompt)
        return self._postprocess_output(response)
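
For reviewers who want to exercise the hooks without a running Policy service, here is a minimal sketch (not part of this PR) that drives `GenerativeJudge` with a hypothetical in-memory stub; the only assumption about the generator is the `await generator.generate.choose(prompt=...)` interface used above.

import asyncio

from forge.actors.generative_judge import GenerativeJudge


class _StubEndpoint:
    async def choose(self, prompt: str) -> str:
        # A real Policy returns a vLLM RequestOutput; a plain string keeps the sketch self-contained.
        return f"[echo] {prompt}"


class _StubGenerator:
    generate = _StubEndpoint()


async def _demo() -> None:
    judge = GenerativeJudge(
        generator=_StubGenerator(),
        prompt_wrapper=lambda prompt, responses: prompt + "\n" + "\n".join(responses),
        output_postprocessor=lambda output: output.upper(),
    )
    # Flow: _wrap_prompt -> generator.generate.choose -> _postprocess_output
    print(await judge.generate("Pick the best:", ["a", "b"]))


if __name__ == "__main__":
    asyncio.run(_demo())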