Commit 1caeda0

Improve serialization of LLMJudge and custom evaluators (#1367)
1 parent ea34fc9 commit 1caeda0

File tree: 7 files changed, +105 -21 lines
Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,5 @@
 # `pydantic_evals.evaluators`

 ::: pydantic_evals.evaluators
+
+::: pydantic_evals.evaluators.llm_as_a_judge

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 1 addition & 1 deletion

@@ -150,7 +150,7 @@ class OpenAIModel(Model):
     """

     client: AsyncOpenAI = field(repr=False)
-    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
+    system_prompt_role: OpenAISystemPromptRole | None = field(default=None, repr=False)

     _model_name: OpenAIModelName = field(repr=False)
     _system: str = field(default='openai', repr=False)

pydantic_evals/pydantic_evals/evaluators/common.py

Lines changed: 17 additions & 2 deletions

@@ -155,10 +155,14 @@ def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> bool:

 @dataclass
 class LLMJudge(Evaluator[object, object, object]):
-    """Judge whether the output of a language model meets the criteria of a provided rubric."""
+    """Judge whether the output of a language model meets the criteria of a provided rubric.
+
+    If you do not specify a model, it uses the default model for judging. This starts as 'openai:gpt-4o', but can be
+    overridden by calling [`set_default_judge_model`][pydantic_evals.evaluators.llm_as_a_judge.set_default_judge_model].
+    """

     rubric: str
-    model: models.Model | models.KnownModelName = 'openai:gpt-4o'
+    model: models.Model | models.KnownModelName | None = None
     include_input: bool = False

     async def evaluate(
@@ -175,6 +179,17 @@ async def evaluate(
         grading_output = await judge_output(ctx.output, self.rubric, self.model)
         return EvaluationReason(value=grading_output.pass_, reason=grading_output.reason)

+    def build_serialization_arguments(self):
+        result = super().build_serialization_arguments()
+        # always serialize the model as a string when present; use its name if it's a KnownModelName
+        if (model := result.get('model')) and isinstance(model, models.Model):
+            result['model'] = f'{model.system}:{model.model_name}'
+
+        # Note: this may lead to confusion if you try to serialize-then-deserialize with a custom model.
+        # I expect that is rare enough to be worth not solving yet, but common enough that we probably will want to
+        # solve it eventually. I'm imagining some kind of model registry, but don't want to work out the details yet.
+        return result
+

 @dataclass
 class HasMatchingSpan(Evaluator[object, object, object]):
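
A rough usage sketch of the behaviour introduced above (not part of the diff): with `model` now defaulting to `None`, an `LLMJudge` falls back to the module-level default judge model unless one is passed explicitly. The import paths follow the docstring reference above; the rubric strings and the alternative model name are illustrative assumptions.

from pydantic_evals.evaluators import LLMJudge
from pydantic_evals.evaluators.llm_as_a_judge import set_default_judge_model

# No model given: the judge defers to the default judge model,
# which starts as 'openai:gpt-4o'.
judge = LLMJudge(rubric='The output answers the question politely.')

# Change the default used whenever no explicit model was provided.
set_default_judge_model('openai:gpt-4o-mini')

# An explicit model takes precedence over the default.
pinned = LLMJudge(rubric='The output is valid JSON.', model='openai:gpt-4o')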

pydantic_evals/pydantic_evals/evaluators/evaluator.py

Lines changed: 23 additions & 9 deletions

@@ -223,6 +223,28 @@ def serialize(self, info: SerializationInfo) -> Any:
         Returns:
             A JSON-serializable representation of this evaluator as an EvaluatorSpec.
         """
+        raw_arguments = self.build_serialization_arguments()
+
+        arguments: None | tuple[Any,] | dict[str, Any]
+        if len(raw_arguments) == 0:
+            arguments = None
+        elif len(raw_arguments) == 1:
+            arguments = (next(iter(raw_arguments.values())),)
+        else:
+            arguments = raw_arguments
+        return to_jsonable_python(
+            EvaluatorSpec(name=self.name(), arguments=arguments), context=info.context, serialize_unknown=True
+        )
+
+    def build_serialization_arguments(self) -> dict[str, Any]:
+        """Build the arguments for serialization.
+
+        Evaluators are serialized for inclusion as the "source" in an `EvaluationResult`.
+        If you want to modify how the evaluator is serialized for that or other purposes, you can override this method.
+
+        Returns:
+            A dictionary of arguments to be used during serialization.
+        """
         raw_arguments: dict[str, Any] = {}
         for field in fields(self):
             value = getattr(self, field.name)
@@ -234,12 +256,4 @@ def serialize(self, info: SerializationInfo) -> Any:
             if value == field.default_factory():
                 continue
             raw_arguments[field.name] = value
-
-        arguments: None | tuple[Any,] | dict[str, Any]
-        if len(raw_arguments) == 0:
-            arguments = None
-        elif len(raw_arguments) == 1:
-            arguments = (next(iter(raw_arguments.values())),)
-        else:
-            arguments = raw_arguments
-        return to_jsonable_python(EvaluatorSpec(name=self.name(), arguments=arguments), context=info.context)
+        return raw_arguments
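
The new `build_serialization_arguments` hook is what `LLMJudge` overrides above, and any custom evaluator can do the same. A minimal sketch, assuming `Evaluator` and `EvaluatorContext` are importable from `pydantic_evals.evaluators`; the evaluator itself is hypothetical:

from dataclasses import dataclass
from typing import Any

from pydantic_evals.evaluators import Evaluator, EvaluatorContext


@dataclass
class ContainsSecret(Evaluator[object, object, object]):
    """Hypothetical evaluator, used only to illustrate the hook."""

    secret: str

    def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> bool:
        return self.secret in str(ctx.output)

    def build_serialization_arguments(self) -> dict[str, Any]:
        # Start from the default arguments (the non-default dataclass fields),
        # then redact the secret so it never appears in serialized results.
        result = super().build_serialization_arguments()
        if 'secret' in result:
            result['secret'] = '***'
        return result

Because `serialize` now delegates to this hook, the redacted arguments are what end up as the "source" of an `EvaluationResult`.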

pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py

Lines changed: 27 additions & 7 deletions

@@ -8,7 +8,10 @@

 from pydantic_ai import Agent, models

-__all__ = ('GradingOutput', 'judge_input_output', 'judge_output')
+__all__ = ('GradingOutput', 'judge_input_output', 'judge_output', 'set_default_judge_model')
+
+
+_default_model: models.Model | models.KnownModelName = 'openai:gpt-4o'


 class GradingOutput(BaseModel, populate_by_name=True):
@@ -41,11 +44,15 @@ class GradingOutput(BaseModel, populate_by_name=True):


 async def judge_output(
-    output: Any, rubric: str, model: models.Model | models.KnownModelName = 'openai:gpt-4o'
+    output: Any, rubric: str, model: models.Model | models.KnownModelName | None = None
 ) -> GradingOutput:
-    """Judge the output of a model based on a rubric."""
+    """Judge the output of a model based on a rubric.
+
+    If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o',
+    but this can be changed using the `set_default_judge_model` function.
+    """
     user_prompt = f'<Output>\n{_stringify(output)}\n</Output>\n<Rubric>\n{rubric}\n</Rubric>'
-    return (await _judge_output_agent.run(user_prompt, model=model)).data
+    return (await _judge_output_agent.run(user_prompt, model=model or _default_model)).data


 _judge_input_output_agent = Agent(
@@ -72,11 +79,24 @@ async def judge_output(


 async def judge_input_output(
-    inputs: Any, output: Any, rubric: str, model: models.Model | models.KnownModelName = 'openai:gpt-4o'
+    inputs: Any, output: Any, rubric: str, model: models.Model | models.KnownModelName | None = None
 ) -> GradingOutput:
-    """Judge the output of a model based on the inputs and a rubric."""
+    """Judge the output of a model based on the inputs and a rubric.
+
+    If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o',
+    but this can be changed using the `set_default_judge_model` function.
+    """
     user_prompt = f'<Input>\n{_stringify(inputs)}\n</Input>\n<Output>\n{_stringify(output)}\n</Output>\n<Rubric>\n{rubric}\n</Rubric>'
-    return (await _judge_input_output_agent.run(user_prompt, model=model)).data
+    return (await _judge_input_output_agent.run(user_prompt, model=model or _default_model)).data
+
+
+def set_default_judge_model(model: models.Model | models.KnownModelName) -> None:  # pragma: no cover
+    """Set the default model used for judging.
+
+    This model is used if `None` is passed to the `model` argument of `judge_output` and `judge_input_output`.
+    """
+    global _default_model
+    _default_model = model


 def _stringify(value: Any) -> str:
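
A short usage sketch of the judge functions with the new default-model handling (running it requires credentials for the chosen models; the output strings, rubrics, and the alternative model name are illustrative assumptions):

import asyncio

from pydantic_evals.evaluators.llm_as_a_judge import judge_output, set_default_judge_model


async def main() -> None:
    # No model passed: falls back to the module-level default ('openai:gpt-4o' unless overridden).
    grading = await judge_output('The capital of France is Paris.', 'Answers the question correctly.')
    print(grading.pass_, grading.reason)

    # Later calls that omit the model use the new default instead.
    set_default_judge_model('openai:gpt-4o-mini')
    grading = await judge_output('2 + 2 = 5', 'The arithmetic is correct.')
    print(grading.pass_, grading.reason)


asyncio.run(main())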

tests/evals/test_evaluator_common.py

Lines changed: 2 additions & 2 deletions

@@ -222,10 +222,10 @@ async def test_llm_judge_evaluator(mocker: MockerFixture):
     assert result.value is True
     assert result.reason == 'Test passed'

-    mock_judge_output.assert_called_once_with('Hello world', 'Content contains a greeting', 'openai:gpt-4o')
+    mock_judge_output.assert_called_once_with('Hello world', 'Content contains a greeting', None)

     # Test with input
-    evaluator = LLMJudge(rubric='Output contains input', include_input=True)
+    evaluator = LLMJudge(rubric='Output contains input', include_input=True, model='openai:gpt-4o')
     result = await evaluator.evaluate(ctx)
     assert isinstance(result, EvaluationReason)
     assert result.value is True

tests/evals/test_evaluators.py

Lines changed: 33 additions & 0 deletions

@@ -7,6 +7,11 @@
 from inline_snapshot import snapshot
 from pydantic import BaseModel, TypeAdapter

+from pydantic_ai.messages import ModelMessage, ModelResponse
+from pydantic_ai.models import Model, ModelRequestParameters
+from pydantic_ai.settings import ModelSettings
+from pydantic_ai.usage import Usage
+
 from ..conftest import try_import

 with try_import() as imports_successful:
@@ -108,6 +113,34 @@ async def test_evaluator_spec_serialization():
     assert adapter.dump_python(spec_single_arg, context={'use_short_form': True}) == snapshot({'MyEvaluator': 'value1'})


+async def test_llm_judge_serialization():
+    # Ensure models are serialized based on their system + name when used with LLMJudge
+
+    class MyModel(Model):
+        async def request(
+            self,
+            messages: list[ModelMessage],
+            model_settings: ModelSettings | None,
+            model_request_parameters: ModelRequestParameters,
+        ) -> tuple[ModelResponse, Usage]:
+            raise NotImplementedError
+
+        @property
+        def model_name(self) -> str:
+            return 'my-model'
+
+        @property
+        def system(self) -> str:
+            return 'my-system'
+
+    adapter = TypeAdapter(Evaluator)
+
+    assert adapter.dump_python(LLMJudge(rubric='my rubric', model=MyModel())) == {
+        'name': 'LLMJudge',
+        'arguments': {'model': 'my-system:my-model', 'rubric': 'my rubric'},
+    }
+
+
 async def test_evaluator_call(test_context: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]):
     """Test calling an Evaluator."""
