Adds basic LLMJudges #167

Changes from all commits: eefb9b2, 3e57be6, b256f32, df9b32d, c3f2ec5

@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from enum import Enum

try:
    from vllm.outputs import RequestOutput
except ImportError as e:
    print(f"Failed to import RequestOutput from vllm.outputs: {e}")
    RequestOutput = "RequestOutput"

from forge.controller.service.interface import ServiceInterface


class EvaluationMethodology(str, Enum):
    """Evaluation methodology for LLM Judge."""

    MAJORITY = "Majority"
    FIRST_SAMPLE = "First"
    PASS_N = "Pass N"


@dataclass
class LLMJudge:
    """Simple interface for Judges utilizing LLMs."""

    judge_model: ServiceInterface

Comment: Instead of this class holding a ServiceInterface, can we instead pass the generated responses directly to evaluate()?

Reply: So have the main loop directly manage calling all N weak verifiers in the rollout? I don't have a strong preference here, but doing so does increase the boilerplate/management load.

Reply: Hmm, so this is my mental model of how Weaver works: […]

Meanwhile, for pass@1, majority, first sample, and pass@k, they're more like "assuming we already know the answer, was the generator able to produce the correct result in K tries?" So when it comes to this PR, it depends on what we're trying to accomplish - is it step 2?

Reply: Step 2 (I'll update this to return a list/tensor of length K), where N judges correspond to N verifiers. Good catch though - I should generalize this to K responses.

Reply: OK, so if we're shooting for step 2, then I wouldn't necessarily include pass@1, majority, first sample, pass@K, etc., which are separate from the verifiers. What we should show is N different models from RewardBench as individual services. If we want to introduce a Judge concept, then I think we should do two things: […]
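
For illustration only, a rough sketch of the alternative shape discussed in this thread: a judge that does not own a ServiceInterface and instead receives pre-generated judge samples. The class name `StatelessLLMJudge` and the `evaluate()` signature are hypothetical, not part of this PR; only the `EvaluationMethodology` enum is reused from the diff.

```python
# Hypothetical sketch of the suggested alternative (not part of this PR):
# the caller generates the judge samples itself and passes them in as strings.
from dataclasses import dataclass

from forge.data.judge import EvaluationMethodology  # enum defined in this diff


@dataclass
class StatelessLLMJudge:
    methodology: EvaluationMethodology = EvaluationMethodology.MAJORITY

    def evaluate(self, response: str, judge_outputs: list[str]) -> bool:
        """Score `response` against already-generated judge outputs."""
        target = response.lower().strip()
        matches = [target == out.lower().strip() for out in judge_outputs]
        if not matches:
            return False
        if self.methodology is EvaluationMethodology.MAJORITY:
            return sum(matches) > len(matches) // 2
        if self.methodology is EvaluationMethodology.FIRST_SAMPLE:
            return matches[0]
        return any(matches)  # PASS_N
```

This variant pushes generation into the rollout loop, which is exactly the boilerplate/management trade-off the reply above points at.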

    methodology: EvaluationMethodology = EvaluationMethodology.MAJORITY

    async def _generate(self, prompt: str) -> RequestOutput:
        """Internally generate responses."""
        return await self.judge_model.generate.choose(prompt=prompt)

    async def evaluate_response(self, prompt: str, response: str) -> float:
        """Evaluate a response to a prompt."""
        outputs: RequestOutput = await self._generate(prompt)
        match self.methodology:
            case EvaluationMethodology.MAJORITY:
                return await self._majority_vote(response, outputs)
            case EvaluationMethodology.FIRST_SAMPLE:
                return await self._first_sample(response, outputs)
            case EvaluationMethodology.PASS_N:
                return await self._pass_n(response, outputs)
            case _:
                raise ValueError(f"Unknown evaluation methodology: {self.methodology}")

    async def _majority_vote(self, response: str, outputs: RequestOutput) -> bool:
        """
        Return whether at least half of the outputs match the response
        """
        matching = 0
        response_normalized = response.lower().strip()

        for output in outputs.outputs:
            output_normalized = output.text.lower().strip()
            if response_normalized == output_normalized:
                matching += 1
            print(output.text)

        return matching > (len(outputs.outputs) // 2)

    async def _first_sample(self, response: str, outputs: RequestOutput) -> bool:
        """
        Returns whether there is a match to the first output
        """
        first_output = outputs.outputs[0]

Comment: nit: maybe handle the edge case when the outputs are empty?
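
A minimal way to address that nit, assuming the desired behavior is to treat an empty sample list as a non-match (the PR does not specify this), would be an early return at the top of `_first_sample`:

```python
        # Hypothetical guard for the nit above (the chosen behavior is an assumption):
        # treat "judge produced no samples" as a non-match instead of raising IndexError.
        if not outputs.outputs:
            return False
```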

        output_normalized = first_output.text.lower().strip()
        response_normalized = response.lower().strip()

        return output_normalized == response_normalized

    async def _pass_n(self, response: str, outputs: RequestOutput) -> bool:
        """
        Return whether any of the outputs match the response
        """
        response_normalized = response.lower().strip()

        for output in outputs.outputs:
            output_normalized = output.text.lower().strip()
            if response_normalized == output_normalized:
                return True

        return False
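
For context, a minimal usage sketch of the judge as defined in this diff. The `judge_service` object and how it is obtained are assumptions; the only requirement implied by the code above is that `judge_service.generate.choose(prompt=...)` awaits to a vLLM `RequestOutput`.

```python
# Hypothetical usage sketch (not part of this PR).
from forge.data.judge import EvaluationMethodology, LLMJudge


async def score_candidate(judge_service, prompt: str, candidate: str) -> bool:
    judge = LLMJudge(
        judge_model=judge_service,  # assumed ServiceInterface yielding RequestOutput
        methodology=EvaluationMethodology.PASS_N,
    )
    # True if any judge sample exactly matches the candidate after lower()/strip().
    return await judge.evaluate_response(prompt, candidate)
    # Call this from an async rollout loop, e.g. `ok = await score_candidate(...)`.
```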

@@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List
from unittest.mock import AsyncMock, Mock, patch

import pytest
from forge.controller.service.interface import ServiceInterface

from forge.data.judge import EvaluationMethodology, LLMJudge


# Mock classes to simulate VLLM RequestOutput structure
@dataclass
class MockCompletionOutput:
    text: str


@dataclass
class MockRequestOutput:
    outputs: List[MockCompletionOutput]


class TestLLMJudge:
    @pytest.fixture
    def mock_service(self):
        """Create a mock ServiceInterface for testing."""
        service = Mock(spec=ServiceInterface)
        service.generate = AsyncMock()
        return service

    @pytest.fixture
    def judge_majority(self, mock_service):
        return LLMJudge(
            judge_model=mock_service, methodology=EvaluationMethodology.MAJORITY
        )

    @pytest.fixture
    def judge_first_sample(self, mock_service):
        """Create an LLMJudge with FIRST_SAMPLE methodology."""
        return LLMJudge(
            judge_model=mock_service, methodology=EvaluationMethodology.FIRST_SAMPLE
        )

    @pytest.fixture
    def judge_pass_n(self, mock_service):
        """Create an LLMJudge with PASS_N methodology."""
        return LLMJudge(
            judge_model=mock_service, methodology=EvaluationMethodology.PASS_N
        )

    @pytest.mark.asyncio
    async def test_majority_vote_true_case(self, judge_majority):
        mock_outputs = [
            MockCompletionOutput(text="yes"),  # matches
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="YES"),  # matches (case insensitive)
            MockCompletionOutput(text="yes "),  # matches (stripped)
            MockCompletionOutput(text="maybe"),  # doesn't match
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(
            judge_majority, "_generate", return_value=mock_request_output
        ):
            result = await judge_majority.evaluate_response("What is 2+2?", "yes")
            assert result is True

    @pytest.mark.asyncio
    async def test_majority_vote_false_case(self, judge_majority):
        mock_outputs = [
            MockCompletionOutput(text="yes"),  # matches
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="maybe"),  # doesn't match
            MockCompletionOutput(text="YES"),  # matches (case insensitive)
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(
            judge_majority, "_generate", return_value=mock_request_output
        ):
            result = await judge_majority.evaluate_response("What is 2+2?", "yes")
            assert result is False

    @pytest.mark.asyncio
    async def test_first_sample_true_case(self, judge_first_sample):
        mock_outputs = [
            MockCompletionOutput(text="YES"),  # matches (case insensitive)
            MockCompletionOutput(text="no"),  # doesn't matter
            MockCompletionOutput(text="maybe"),  # doesn't matter
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(
            judge_first_sample, "_generate", return_value=mock_request_output
        ):
            result = await judge_first_sample.evaluate_response("What is 2+2?", "yes")
            assert result is True

    @pytest.mark.asyncio
    async def test_first_sample_false_case(self, judge_first_sample):
        mock_outputs = [
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="yes"),  # doesn't matter
            MockCompletionOutput(text="YES"),  # doesn't matter
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(
            judge_first_sample, "_generate", return_value=mock_request_output
        ):
            result = await judge_first_sample.evaluate_response("What is 2+2?", "yes")
            assert result is False

    @pytest.mark.asyncio
    async def test_pass_n_true_case(self, judge_pass_n):
        mock_outputs = [
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="maybe"),  # doesn't match
            MockCompletionOutput(text="YES"),  # matches (case insensitive)
            MockCompletionOutput(text="no"),  # doesn't match
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(judge_pass_n, "_generate", return_value=mock_request_output):
            result = await judge_pass_n.evaluate_response("What is 2+2?", "yes")
            assert result is True

    @pytest.mark.asyncio
    async def test_pass_n_false_case(self, judge_pass_n):
        mock_outputs = [
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="maybe"),  # doesn't match
            MockCompletionOutput(text="four"),  # doesn't match
            MockCompletionOutput(text="nope"),  # doesn't match
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(judge_pass_n, "_generate", return_value=mock_request_output):
            result = await judge_pass_n.evaluate_response("What is 2+2?", "yes")
            assert result is False

    @pytest.mark.asyncio
    async def test_case_insensitive_and_whitespace_handling(self, judge_majority):
        mock_outputs = [
            MockCompletionOutput(text="YES"),  # matches
            MockCompletionOutput(text=" yes "),  # matches (with whitespace)
            MockCompletionOutput(text="Yes"),  # matches
            MockCompletionOutput(text="no"),  # doesn't match
            MockCompletionOutput(text="NO"),  # doesn't match
        ]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(
            judge_majority, "_generate", return_value=mock_request_output
        ):
            result = await judge_majority.evaluate_response("What is 2+2?", " YES ")
            assert result is True

    @pytest.mark.asyncio
    async def test_empty_outputs_handling(self, judge_majority):
        """Test handling of empty outputs list."""
        mock_outputs = []
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(
            judge_majority, "_generate", return_value=mock_request_output
        ):
            result = await judge_majority.evaluate_response("What is 2+2?", "yes")
            assert result is False  # 0 out of 0 match, which is not > 0//2 = 0

    @pytest.mark.asyncio
    async def test_unknown_evaluation_methodology(self, mock_service):
        """Test that unknown evaluation methodology raises ValueError."""
        judge = LLMJudge(judge_model=mock_service, methodology="INVALID")

        mock_outputs = [MockCompletionOutput(text="yes")]
        mock_request_output = MockRequestOutput(outputs=mock_outputs)

        with patch.object(judge, "_generate", return_value=mock_request_output):
            with pytest.raises(
                ValueError, match="Unknown evaluation methodology: INVALID"
            ):
                await judge.evaluate_response("What is 2+2?", "yes")
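
The review nit about empty outputs is only covered for MAJORITY above; with FIRST_SAMPLE, the current `_first_sample` indexes `outputs.outputs[0]`, so an empty list raises IndexError rather than returning False. A hypothetical test documenting that behavior, written as an extra method inside `TestLLMJudge` and reusing the fixtures and mocks from this diff, could look like:

```python
    @pytest.mark.asyncio
    async def test_first_sample_empty_outputs(self, judge_first_sample):
        """Hypothetical test, not part of the PR: current behavior on empty outputs."""
        mock_request_output = MockRequestOutput(outputs=[])

        with patch.object(
            judge_first_sample, "_generate", return_value=mock_request_output
        ):
            # _first_sample indexes outputs.outputs[0], so an empty list raises today.
            with pytest.raises(IndexError):
                await judge_first_sample.evaluate_response("What is 2+2?", "yes")
```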

Comment: If LLMJudge here is evaluating responses, i.e. pass@1, majority, first sample, etc., aren't these metrics rather than judges?

Reply: Aren't metrics just for logging? The Judges can be used as part of the generation evaluation (analogous to Rewards).

Reply: Yeah, "metric" is sort of a loaded term, unfortunately. Maybe it's EvaluationMetric?