60 changes: 60 additions & 0 deletions apps/vllm/judge.py
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""To run:
export HF_HUB_DISABLE_XET=1
python -m apps.vllm.judge --config apps/vllm/llama3_8b.yaml
"""

import asyncio
import os

from forge.actors.generative_judge import LLMJudge
from forge.cli.config import parse
from forge.controller.provisioner import shutdown

from omegaconf import DictConfig

os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"


async def run(cfg: DictConfig):
prompt = "What is the capital of Japan?"
responses = ["Aardvark", "Durian", "Tokyo"]

print("Spawning service...")
judge_service_config = cfg.services.pop("judge")
judge = await LLMJudge.options(**judge_service_config).as_service(policy_cfg=cfg)

print(f"Prompt: {prompt}")
print(f"Responses: {responses}\n")
print("Requesting generation ...")
evaluations: list[str] = await judge.generate.choose(
prompt=prompt,
responses=responses,
)

print("\nGeneration Results:")
print("=" * 80)
    for idx, evaluation in enumerate(evaluations):
        print(f"Sample {idx + 1}")
print(f"Evaluation: {evaluation}")
print("-" * 80)

print("\nShutting down...")
await judge.shutdown()
await shutdown()


@parse
def recipe_main(cfg: DictConfig) -> None:
asyncio.run(run(cfg))


if __name__ == "__main__":
recipe_main()
8 changes: 6 additions & 2 deletions apps/vllm/llama3_8b.yaml
@@ -5,15 +5,19 @@ policy:
     pipeline_parallel_size: 1
     enforce_eager: true
   sampling_config:
-    n: 2
+    n: 4
     guided_decoding: false
     max_tokens: 512
 
 services:
   policy:
     procs: ${policy.engine_config.tensor_parallel_size}
-    num_replicas: 4
+    num_replicas: 2
     with_gpus: true
+  judge:
+    procs: 1
+    num_replicas: 1
+    with_gpus: false
 
 
 # Optional, otherwise argparse fallback kicks in
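For orientation: the new `services.judge` block is what `apps/vllm/judge.py` pops off the config and hands to `LLMJudge.options(...)`, while the remaining config (`policy` plus `services.policy`) travels into `LLMJudge.launch`, which starts the Policy service internally. A minimal sketch of that flow, assuming the usual procs-times-replicas resource accounting (the helper name is illustrative):

from omegaconf import DictConfig

from forge.actors.generative_judge import LLMJudge


async def spawn_judge(cfg: DictConfig):
    # services.judge above: a single CPU-only proc for the judge actor itself.
    judge_cfg = cfg.services.pop("judge")
    # Everything else is forwarded to LLMJudge.launch, which runs
    #   Policy.options(**cfg.services.policy).as_service(**cfg.policy)
    # i.e. tensor_parallel_size procs per replica x 2 GPU-backed replicas.
    return await LLMJudge.options(**judge_cfg).as_service(policy_cfg=cfg)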
172 changes: 172 additions & 0 deletions src/forge/actors/generative_judge.py
@@ -0,0 +1,172 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Mapping, Optional

from monarch.actor import endpoint, ProcMesh

from vllm.transformers_utils.tokenizer import get_tokenizer

from forge.actors.policy import Policy
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
from forge.data_models.completion import Completion


@dataclass
class LLMJudge(ForgeActor):
"""
`LLM-based Judges` are typically generative models which are then prompted
to evaluate responses. These models NEED prompt engineering to evaluate
and may require more postprocessing
"""

# Typically Policy, but effectively any service with `generate` method
generator: ServiceInterface | None = None

def __post_init__(self):
super().__init__()
self._judge_proc: ProcMesh | None = None

@classmethod
async def launch(
cls: type["LLMJudge"],
*,
process_config: ProcessConfig,
policy_cfg: Mapping,
prompt_wrapper: Optional[Callable[[str, list[str]], str]] = None,
output_postprocessor: Optional[Callable[[Any], Any]] = None,
**kwargs,
):
judge_procs = await get_proc_mesh(process_config=process_config)
policy = await Policy.options(**policy_cfg.services.policy).as_service(
**policy_cfg.policy
)

actor_name = kwargs.pop("name", cls.__name__)
llm_judge = await judge_procs.spawn(
actor_name,
cls,
prompt_wrapper=prompt_wrapper,
output_postprocessor=output_postprocessor,
generator=policy,
)
llm_judge._judge_proc = judge_procs

await llm_judge.setup.call()
return llm_judge

@endpoint
async def setup(self):
assert self.generator is not None, "Generator not initialized correctly"
self.tokenizer = get_tokenizer(self.generator.engine_config.model)

@classmethod
async def shutdown(cls: type["LLMJudge"], actor: "LLMJudge"):
assert (
actor.generator is not None
), "Tried to shutdown a generator that was not initialized correctly"
assert (
actor._judge_proc is not None
), "Tried to shutdown a LLMJudge that was not initialized correctly"

await actor.generator.shutdown()
await stop_proc_mesh(actor._judge_proc)

def _wrap_prompt(self, prompt: str, responses: list[str]) -> str:
"""
Construct the string being passed to the generator

Note: This is not a "good" prompt, it just demonstrates how to make one
"""

system_prompt = f"""
You are an expert evaluator. Evaluate the responses provided and return
a single integer indicating which response is the most factually correct.
Each response is formatted as [Response #<N>], where <N> represents the
selection. Do not explain your reasoning, just provide a number.

Here is the prompt that generated the responses: {prompt}.
"""
response_str = "\n".join(
[f"[Response #{i+1}] {resp}" for i, resp in enumerate(responses)]
)
as_chat = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": response_str},
]
formatted_request = self.tokenizer.apply_chat_template(
as_chat, tokenize=False, add_generation_prompt=True
)
return formatted_request

    def _postprocess_output(self, output: list[Completion]) -> list[str]:
        return [completion.text for completion in output]

@endpoint
async def generate(
self, prompt: str, responses: list[str], priority: int = 0
) -> list[str]:
wrapped_prompt: str = self._wrap_prompt(prompt, responses)
        response: list[Completion] = await self.generator.generate.choose(
prompt=wrapped_prompt
)
return self._postprocess_output(response)


@dataclass
class RewardModelJudge(ForgeActor):
"""
`RewardModels` are typically discriminative models, post trained to
evaluate responses without further prompting required.
"""

# Typically Policy, but effectively any service with `generate` method
generator: ServiceInterface | None = None

def __post_init__(self):
super().__init__()
self._judge_proc: ProcMesh | None = None

@classmethod
async def launch(
cls: type["LLMJudge"],
*,
process_config: ProcessConfig,
policy_cfg: Mapping,
prompt_wrapper: Optional[Callable[[str, list[str]], str]] = None,
output_postprocessor: Optional[Callable[[Any], Any]] = None,
**kwargs,
):
judge_procs = await get_proc_mesh(process_config=process_config)
policy = await Policy.options(**policy_cfg.services.policy).as_service(
**policy_cfg.policy
)

        actor_name = kwargs.pop("name", cls.__name__)
        judge = await judge_procs.spawn(
            actor_name,
            cls,
            prompt_wrapper=prompt_wrapper,
            output_postprocessor=output_postprocessor,
            generator=policy,
        )
        judge._judge_proc = judge_procs

        return judge

# TODO: Add formatting for reward models
def _wrap_prompt(self, prompt: str, responses: list[str]) -> str:
return prompt

@endpoint
async def generate(
self, prompt: str, responses: list[str], priority: int = 0
    ) -> list[Completion]:
wrapped_prompt: str = self._wrap_prompt(prompt, responses)
return await self.generator.generate.choose(prompt=wrapped_prompt)
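`launch` accepts `prompt_wrapper` and `output_postprocessor` callables and forwards them to `spawn`, but the default `_wrap_prompt` / `_postprocess_output` above do not consult them yet. A sketch of caller-supplied hooks matching the annotated signatures; the hook names and the extra `as_service` kwargs are illustrative assumptions, not something this PR exercises:

# Hypothetical hooks matching the signatures declared on LLMJudge.launch:
#   prompt_wrapper: Callable[[str, list[str]], str]
#   output_postprocessor: Callable[[Any], Any]
def numbered_prompt(prompt: str, responses: list[str]) -> str:
    # Same shape as the default _wrap_prompt: one instruction followed by
    # the numbered candidate responses.
    numbered = "\n".join(f"[Response #{i + 1}] {r}" for i, r in enumerate(responses))
    return f"Pick the most factually correct answer to: {prompt}\n{numbered}"


def keep_first_token(completions) -> list[str]:
    # Keep only the leading token of each completion, e.g. the chosen index.
    return [c.text.strip().split()[0] for c in completions]


# Passed alongside policy_cfg when spawning the service (illustrative):
#   judge = await LLMJudge.options(**judge_cfg).as_service(
#       policy_cfg=cfg,
#       prompt_wrapper=numbered_prompt,
#       output_postprocessor=keep_first_token,
#   )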