87 changes: 87 additions & 0 deletions src/forge/data/judge.py
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from enum import Enum

try:
    from vllm.outputs import RequestOutput
except ImportError as e:
    print(f"Failed to import RequestOutput from vllm.outputs: {e}")
    RequestOutput = "RequestOutput"

from forge.controller.service.interface import ServiceInterface


class EvaluationMethodology(str, Enum):
    """Evaluation methodology for LLM Judge."""

    MAJORITY = "Majority"
    FIRST_SAMPLE = "First"
    PASS_N = "Pass N"


@dataclass
class LLMJudge:
Contributor: If LLMJudge here is evaluating responses, i.e. Pass@1, Majority, First Sample, etc., aren't these metrics rather than judges?

Contributor (Author): Aren't metrics just for logging?

The Judges can be used as part of the generation evaluation (analogous to Rewards).

Contributor: Yeah, "Metric" is sort of a loaded term, unfortunately. Maybe it's EvaluationMetric?

"""Simple interface for Judges utilizing LLMs."""

judge_model: ServiceInterface
Contributor: Instead of this class holding a ServiceInterface, can we instead pass the generated responses directly to evaluate()?

Contributor (Author): So have the main loop directly manage calling all N weak verifiers in the rollout?

I don't have a strong preference here, but doing so does increase the boilerplate/management load.
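
For concreteness, a rough sketch of that alternative shape (the class name and signature below are hypothetical, not part of this PR): the judge holds no service, and the rollout loop hands it the already-generated texts.

from dataclasses import dataclass


@dataclass
class StatelessJudge:
    """Hypothetical variant: no ServiceInterface field; the caller supplies the judge texts."""

    def evaluate(self, response: str, judge_texts: list[str]) -> bool:
        # Majority vote over texts the rollout loop already generated elsewhere.
        target = response.lower().strip()
        matches = sum(t.lower().strip() == target for t in judge_texts)
        return matches > len(judge_texts) // 2


# e.g. StatelessJudge().evaluate("yes", ["yes", "YES", "no"]) -> True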

Contributor: Hmm, so this is my mental model of how Weaver works:

  1. The generator generates K responses.
  2. The K responses go through N verifiers, producing K*N verifier results.
  3. The K*N results get distilled down to a scalar 0/1 through Weaver.

Meanwhile, pass@1, majority, first sample, and pass@K are more like "assuming we know the answer already, was the generator able to produce the correct result in K tries?"

So when it comes to this PR, it depends on what we're trying to accomplish: is it step 2?
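
A toy sketch of that three-step flow (function names and the aggregation rule are stand-ins, not Weaver's actual API): K generator samples go through N verifiers, and the K*N verdicts are collapsed to one 0/1 label per response.

from typing import Callable


def weaver_style_labels(
    responses: list[str],                    # K generator samples
    verifiers: list[Callable[[str], bool]],  # N weak verifiers
) -> list[int]:
    """Collapse N verifier verdicts into a single 0/1 label for each of the K responses."""
    labels = []
    for response in responses:
        verdicts = [verifier(response) for verifier in verifiers]  # N verdicts for this response
        labels.append(int(sum(verdicts) > len(verdicts) // 2))     # stand-in aggregation rule
    return labels


# e.g. weaver_style_labels(["4", "five", "4.0"], [lambda r: "4" in r, str.isdigit]) -> [1, 0, 0]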

Contributor (Author): Step 2 (I will update this to return a list/tensor of length K), where the N judges correspond to the N verifiers.

Good catch though, I should generalize this to K responses.

Contributor: OK, so if we're shooting for step 2, then I wouldn't necessarily include pass@1, majority, first sample, pass@K, etc., which are separate from the verifiers.

What we should show is something like N different models from RewardBench as individual services.

If we want to introduce a Judge concept, then I think we should do two things:

  • Rename Policy* to Inference*
  • Make the Judge a special instance of Policy that uses generate to turn the final result into scalars or whatever is needed
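
A very rough sketch of that second bullet, assuming an Inference-style service exposes a plain generate() call (every name below is hypothetical; none of it is forge's actual API): the Judge reuses the same generation surface and reduces the completions to a scalar.

import asyncio
from dataclasses import dataclass
from typing import Protocol


class InferenceLike(Protocol):
    async def generate(self, prompt: str) -> list[str]: ...


@dataclass
class JudgeAsInference:
    """Hypothetical Judge built on the same generate() surface as an Inference service."""

    backend: InferenceLike

    async def score(self, prompt: str, response: str) -> float:
        completions = await self.backend.generate(prompt)
        if not completions:
            return 0.0
        target = response.lower().strip()
        # Scalar output: fraction of completions that agree with the response.
        return sum(c.lower().strip() == target for c in completions) / len(completions)


class CannedBackend:
    """Stand-in backend that returns fixed completions."""

    async def generate(self, prompt: str) -> list[str]:
        return ["yes", "YES", "no"]


print(asyncio.run(JudgeAsInference(CannedBackend()).score("Is 2+2 equal to 4?", "yes")))  # ~0.67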

    methodology: EvaluationMethodology = EvaluationMethodology.MAJORITY

    async def _generate(self, prompt: str) -> RequestOutput:
        """Internally generate responses."""
        return await self.judge_model.generate.choose(prompt=prompt)

    async def evaluate_response(self, prompt: str, response: str) -> float:
        """Evaluate a response to a prompt."""
        outputs: RequestOutput = await self._generate(prompt)
        match self.methodology:
            case EvaluationMethodology.MAJORITY:
                return await self._majority_vote(response, outputs)
            case EvaluationMethodology.FIRST_SAMPLE:
                return await self._first_sample(response, outputs)
            case EvaluationMethodology.PASS_N:
                return await self._pass_n(response, outputs)
            case _:
                raise ValueError(f"Unknown evaluation methodology: {self.methodology}")

    async def _majority_vote(self, response: str, outputs: RequestOutput) -> bool:
        """
        Return whether more than half of the outputs match the response
        """
        matching = 0
        response_normalized = response.lower().strip()

        for output in outputs.outputs:
            output_normalized = output.text.lower().strip()
            if response_normalized == output_normalized:
                matching += 1
            print(output.text)

        return matching > (len(outputs.outputs) // 2)

    async def _first_sample(self, response: str, outputs: RequestOutput) -> bool:
        """
        Return whether the response matches the first output
        """
        first_output = outputs.outputs[0]
Contributor: Nit: maybe handle the edge case where the outputs are empty?
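
A minimal sketch of that guard, pulled out as a free function for illustration (the actual fix would live inside _first_sample):

def first_sample_matches(response: str, texts: list[str]) -> bool:
    # Guard against an empty completion list before indexing texts[0].
    if not texts:
        return False
    return texts[0].lower().strip() == response.lower().strip()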

        output_normalized = first_output.text.lower().strip()
        response_normalized = response.lower().strip()

        return output_normalized == response_normalized

    async def _pass_n(self, response: str, outputs: RequestOutput) -> bool:
        """
        Return whether any of the outputs match the response
        """
        response_normalized = response.lower().strip()

        for output in outputs.outputs:
            output_normalized = output.text.lower().strip()
            if response_normalized == output_normalized:
                return True

        return False
189 changes: 189 additions & 0 deletions tests/unit_tests/data/test_judge.py
@@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List
from unittest.mock import AsyncMock, Mock, patch

import pytest
from forge.controller.service.interface import ServiceInterface

from forge.data.judge import EvaluationMethodology, LLMJudge


# Mock classes to simulate VLLM RequestOutput structure
@dataclass
class MockCompletionOutput:
    text: str


@dataclass
class MockRequestOutput:
    outputs: List[MockCompletionOutput]


class TestLLMJudge:
    @pytest.fixture
    def mock_service(self):
        """Create a mock ServiceInterface for testing."""
        service = Mock(spec=ServiceInterface)
        service.generate = AsyncMock()
        return service

    @pytest.fixture
    def judge_majority(self, mock_service):
        return LLMJudge(
            judge_model=mock_service, methodology=EvaluationMethodology.MAJORITY
        )

    @pytest.fixture
    def judge_first_sample(self, mock_service):
        """Create an LLMJudge with FIRST_SAMPLE methodology."""
        return LLMJudge(
            judge_model=mock_service, methodology=EvaluationMethodology.FIRST_SAMPLE
        )

    @pytest.fixture
    def judge_pass_n(self, mock_service):
        """Create an LLMJudge with PASS_N methodology."""
        return LLMJudge(
            judge_model=mock_service, methodology=EvaluationMethodology.PASS_N
        )

    @pytest.mark.asyncio
    async def test_majority_vote_true_case(self, judge_majority):
        mock_outputs = [
            MockCompletionOutput(text="yes"),  # matches
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="YES"),  # matches (case insensitive)
            MockCompletionOutput(text="yes "),  # matches (stripped)
            MockCompletionOutput(text="maybe"),  # doesn't match
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(
            judge_majority, "_generate", return_value=mock_request_output
        ):
            result = await judge_majority.evaluate_response("What is 2+2?", "yes")
            assert result is True

    @pytest.mark.asyncio
    async def test_majority_vote_false_case(self, judge_majority):
        mock_outputs = [
            MockCompletionOutput(text="yes"),  # matches
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="maybe"),  # doesn't match
            MockCompletionOutput(text="YES"),  # matches (case insensitive)
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(
            judge_majority, "_generate", return_value=mock_request_output
        ):
            result = await judge_majority.evaluate_response("What is 2+2?", "yes")
            assert result is False

    @pytest.mark.asyncio
    async def test_first_sample_true_case(self, judge_first_sample):
        mock_outputs = [
            MockCompletionOutput(text="YES"),  # matches (case insensitive)
            MockCompletionOutput(text="no"),  # doesn't matter
            MockCompletionOutput(text="maybe"),  # doesn't matter
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(
            judge_first_sample, "_generate", return_value=mock_request_output
        ):
            result = await judge_first_sample.evaluate_response("What is 2+2?", "yes")
            assert result is True

    @pytest.mark.asyncio
    async def test_first_sample_false_case(self, judge_first_sample):
        mock_outputs = [
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="yes"),  # doesn't matter
            MockCompletionOutput(text="YES"),  # doesn't matter
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(
            judge_first_sample, "_generate", return_value=mock_request_output
        ):
            result = await judge_first_sample.evaluate_response("What is 2+2?", "yes")
            assert result is False

    @pytest.mark.asyncio
    async def test_pass_n_true_case(self, judge_pass_n):
        mock_outputs = [
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="maybe"),  # doesn't match
            MockCompletionOutput(text="YES"),  # matches (case insensitive)
            MockCompletionOutput(text="no"),  # doesn't match
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(judge_pass_n, "_generate", return_value=mock_request_output):
            result = await judge_pass_n.evaluate_response("What is 2+2?", "yes")
            assert result is True

    @pytest.mark.asyncio
    async def test_pass_n_false_case(self, judge_pass_n):
        mock_outputs = [
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="maybe"),  # doesn't match
            MockCompletionOutput(text="four"),  # doesn't match
            MockCompletionOutput(text="nope"),  # doesn't match
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(judge_pass_n, "_generate", return_value=mock_request_output):
            result = await judge_pass_n.evaluate_response("What is 2+2?", "yes")
            assert result is False

    @pytest.mark.asyncio
    async def test_case_insensitive_and_whitespace_handling(self, judge_majority):
        mock_outputs = [
            MockCompletionOutput(text="YES"),  # matches
            MockCompletionOutput(text=" yes "),  # matches (with whitespace)
            MockCompletionOutput(text="Yes"),  # matches
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="NO"),  # doesn't match
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(
            judge_majority, "_generate", return_value=mock_request_output
        ):
            result = await judge_majority.evaluate_response("What is 2+2?", " YES ")
            assert result is True

    @pytest.mark.asyncio
    async def test_empty_outputs_handling(self, judge_majority):
        """Test handling of empty outputs list."""
        mock_outputs = []
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(
            judge_majority, "_generate", return_value=mock_request_output
        ):
            result = await judge_majority.evaluate_response("What is 2+2?", "yes")
            assert result is False  # 0 out of 0 match, which is not > 0//2 = 0

    @pytest.mark.asyncio
    async def test_unknown_evaluation_methodology(self, mock_service):
        """Test that unknown evaluation methodology raises ValueError."""
        judge = LLMJudge(judge_model=mock_service, methodology="INVALID")

        mock_outputs = [MockCompletionOutput(text="yes")]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(judge, "_generate", return_value=mock_request_output):
            with pytest.raises(
                ValueError, match="Unknown evaluation methodology: INVALID"
            ):
                await judge.evaluate_response("What is 2+2?", "yes")