
Commit ebd86fa

Fix some issues with non-serializable inputs in evals (#1333)
1 parent a5a471e · commit ebd86fa

6 files changed (+225, -50 lines changed)


pydantic_evals/pydantic_evals/dataset.py

Lines changed: 2 additions & 4 deletions
@@ -26,7 +26,7 @@
 import yaml
 from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, ValidationError, model_serializer
 from pydantic._internal import _typing_extra
-from pydantic_core import to_json, to_jsonable_python
+from pydantic_core import to_json
 from pydantic_core.core_schema import SerializationInfo, SerializerFunctionWrapHandler
 from typing_extensions import NotRequired, Self, TypedDict, TypeVar
 
@@ -907,11 +907,9 @@ async def _run_task_and_evaluators(
     span_id = f'{context.span_id:016x}'
     fallback_duration = time.time() - t0
 
-    report_inputs = to_jsonable_python(case.inputs)
-
     return ReportCase(
         name=report_case_name,
-        inputs=report_inputs,
+        inputs=case.inputs,
         metadata=case.metadata,
         expected_output=case.expected_output,
         output=scoring_context.output,
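Why this change fixes the bug: to_jsonable_python raises as soon as it hits a value it cannot convert to JSON (a type object, for instance), so constructing the ReportCase crashed before the report was ever rendered. Storing case.inputs unchanged defers conversion to render time, which already falls back to repr (as the new test's table output below shows). A minimal sketch of the old failure mode, using an illustrative dataclass that mirrors the regression test added in this commit:

from dataclasses import dataclass

from pydantic_core import to_jsonable_python


@dataclass
class MyInputs:
    result_type: type  # a plain class object, not JSON-serializable


try:
    # Effectively what ReportCase construction did before this commit.
    to_jsonable_python(MyInputs(result_type=str))
except Exception as exc:  # pydantic_core raises a serialization error here
    print(f'serialization failed: {exc!r}')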

pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py

Lines changed: 2 additions & 0 deletions
@@ -83,6 +83,8 @@ def _stringify(value: Any) -> str:
     if isinstance(value, str):
         return value
     try:
+        # If the value can be serialized to JSON, use that.
+        # If that behavior is undesirable, the user could manually call repr on the arguments to the judge_* functions
         return to_json(value).decode()
     except Exception:
         return repr(value)
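For context, the patched helper prefers compact JSON and only falls back to repr when serialization fails. A self-contained copy of the function above, with two illustrative calls:

from typing import Any

from pydantic_core import to_json


def _stringify(value: Any) -> str:
    # Same logic as the diff above: JSON where possible, repr otherwise.
    if isinstance(value, str):
        return value
    try:
        return to_json(value).decode()
    except Exception:
        return repr(value)


print(_stringify({'query': 'What is 2+2?'}))  # {"query":"What is 2+2?"}
print(_stringify(str))  # <class 'str'> -- to_json fails on a type object, so repr is used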

pydantic_evals/pydantic_evals/reporting/__init__.py

Lines changed: 57 additions & 0 deletions
@@ -158,12 +158,15 @@ def print(
         width: int | None = None,
         baseline: EvaluationReport | None = None,
         include_input: bool = False,
+        include_metadata: bool = False,
+        include_expected_output: bool = False,
         include_output: bool = False,
         include_durations: bool = True,
         include_total_duration: bool = False,
         include_removed_cases: bool = False,
         include_averages: bool = True,
         input_config: RenderValueConfig | None = None,
+        metadata_config: RenderValueConfig | None = None,
         output_config: RenderValueConfig | None = None,
         score_configs: dict[str, RenderNumberConfig] | None = None,
         label_configs: dict[str, RenderValueConfig] | None = None,
@@ -177,12 +180,15 @@ def print(
         table = self.console_table(
             baseline=baseline,
             include_input=include_input,
+            include_metadata=include_metadata,
+            include_expected_output=include_expected_output,
             include_output=include_output,
             include_durations=include_durations,
             include_total_duration=include_total_duration,
             include_removed_cases=include_removed_cases,
             include_averages=include_averages,
             input_config=input_config,
+            metadata_config=metadata_config,
             output_config=output_config,
             score_configs=score_configs,
             label_configs=label_configs,
@@ -195,12 +201,15 @@ def console_table(
         self,
         baseline: EvaluationReport | None = None,
         include_input: bool = False,
+        include_metadata: bool = False,
+        include_expected_output: bool = False,
         include_output: bool = False,
         include_durations: bool = True,
         include_total_duration: bool = False,
         include_removed_cases: bool = False,
         include_averages: bool = True,
         input_config: RenderValueConfig | None = None,
+        metadata_config: RenderValueConfig | None = None,
         output_config: RenderValueConfig | None = None,
         score_configs: dict[str, RenderNumberConfig] | None = None,
         label_configs: dict[str, RenderValueConfig] | None = None,
@@ -213,12 +222,15 @@ def console_table(
         """
         renderer = EvaluationRenderer(
             include_input=include_input,
+            include_metadata=include_metadata,
+            include_expected_output=include_expected_output,
             include_output=include_output,
             include_durations=include_durations,
             include_total_duration=include_total_duration,
             include_removed_cases=include_removed_cases,
             include_averages=include_averages,
             input_config={**_DEFAULT_VALUE_CONFIG, **(input_config or {})},
+            metadata_config={**_DEFAULT_VALUE_CONFIG, **(metadata_config or {})},
             output_config=output_config or _DEFAULT_VALUE_CONFIG,
             score_configs=score_configs or {},
             label_configs=label_configs or {},
@@ -496,6 +508,8 @@ def render_diff(self, name: str | None, old: T_contra | None, new: T_contra | None
 @dataclass
 class ReportCaseRenderer:
     include_input: bool
+    include_metadata: bool
+    include_expected_output: bool
     include_output: bool
     include_scores: bool
     include_labels: bool
@@ -505,6 +519,7 @@ class ReportCaseRenderer:
     include_total_duration: bool
 
     input_renderer: _ValueRenderer
+    metadata_renderer: _ValueRenderer
     output_renderer: _ValueRenderer
     score_renderers: dict[str, _NumberRenderer]
     label_renderers: dict[str, _ValueRenderer]
@@ -517,6 +532,10 @@ def build_base_table(self, title: str) -> Table:
         table.add_column('Case ID', style='bold')
         if self.include_input:
             table.add_column('Inputs', overflow='fold')
+        if self.include_metadata:
+            table.add_column('Metadata', overflow='fold')
+        if self.include_expected_output:
+            table.add_column('Expected Output', overflow='fold')
         if self.include_output:
             table.add_column('Outputs', overflow='fold')
         if self.include_scores:
@@ -538,6 +557,12 @@ def build_row(self, case: ReportCase) -> list[str]:
         if self.include_input:
             row.append(self.input_renderer.render_value(None, case.inputs) or EMPTY_CELL_STR)
 
+        if self.include_metadata:
+            row.append(self.input_renderer.render_value(None, case.metadata) or EMPTY_CELL_STR)
+
+        if self.include_expected_output:
+            row.append(self.input_renderer.render_value(None, case.expected_output) or EMPTY_CELL_STR)
+
         if self.include_output:
             row.append(self.output_renderer.render_value(None, case.output) or EMPTY_CELL_STR)
 
@@ -565,6 +590,12 @@ def build_aggregate_row(self, aggregate: ReportCaseAggregate) -> list[str]:
         if self.include_input:
             row.append(EMPTY_AGGREGATE_CELL_STR)
 
+        if self.include_metadata:
+            row.append(EMPTY_AGGREGATE_CELL_STR)
+
+        if self.include_expected_output:
+            row.append(EMPTY_AGGREGATE_CELL_STR)
+
         if self.include_output:
             row.append(EMPTY_AGGREGATE_CELL_STR)
 
@@ -598,6 +629,19 @@ def build_diff_row(
             input_diff = self.input_renderer.render_diff(None, baseline.inputs, new_case.inputs) or EMPTY_CELL_STR
             row.append(input_diff)
 
+        if self.include_metadata:
+            metadata_diff = (
+                self.metadata_renderer.render_diff(None, baseline.metadata, new_case.metadata) or EMPTY_CELL_STR
+            )
+            row.append(metadata_diff)
+
+        if self.include_expected_output:
+            expected_output_diff = (
+                self.output_renderer.render_diff(None, baseline.expected_output, new_case.expected_output)
+                or EMPTY_CELL_STR
+            )
+            row.append(expected_output_diff)
+
         if self.include_output:
             output_diff = self.output_renderer.render_diff(None, baseline.output, new_case.output) or EMPTY_CELL_STR
             row.append(output_diff)
@@ -642,6 +686,12 @@ def build_diff_aggregate_row(
         if self.include_input:
             row.append(EMPTY_AGGREGATE_CELL_STR)
 
+        if self.include_metadata:
+            row.append(EMPTY_AGGREGATE_CELL_STR)
+
+        if self.include_expected_output:
+            row.append(EMPTY_AGGREGATE_CELL_STR)
+
         if self.include_output:
             row.append(EMPTY_AGGREGATE_CELL_STR)
 
@@ -777,6 +827,8 @@ class EvaluationRenderer:
 
     # Columns to include
     include_input: bool
+    include_metadata: bool
+    include_expected_output: bool
     include_output: bool
     include_durations: bool
     include_total_duration: bool
@@ -786,6 +838,7 @@ class EvaluationRenderer:
     include_averages: bool
 
    input_config: RenderValueConfig
+    metadata_config: RenderValueConfig
     output_config: RenderValueConfig
     score_configs: dict[str, RenderNumberConfig]
     label_configs: dict[str, RenderValueConfig]
@@ -820,6 +873,7 @@ def _get_case_renderer(
         self, report: EvaluationReport, baseline: EvaluationReport | None = None
     ) -> ReportCaseRenderer:
         input_renderer = _ValueRenderer.from_config(self.input_config)
+        metadata_renderer = _ValueRenderer.from_config(self.metadata_config)
         output_renderer = _ValueRenderer.from_config(self.output_config)
         score_renderers = self._infer_score_renderers(report, baseline)
         label_renderers = self._infer_label_renderers(report, baseline)
@@ -830,6 +884,8 @@ def _get_case_renderer(
 
         return ReportCaseRenderer(
             include_input=self.include_input,
+            include_metadata=self.include_metadata,
+            include_expected_output=self.include_expected_output,
             include_output=self.include_output,
             include_scores=self.include_scores(report, baseline),
             include_labels=self.include_labels(report, baseline),
@@ -838,6 +894,7 @@ def _get_case_renderer(
             include_durations=self.include_durations,
             include_total_duration=self.include_total_duration,
             input_renderer=input_renderer,
+            metadata_renderer=metadata_renderer,
             output_renderer=output_renderer,
             score_renderers=score_renderers,
             label_renderers=label_renderers,
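Taken together, these changes add opt-in Metadata and Expected Output columns alongside the existing Inputs and Outputs columns. A usage sketch, assuming report is an EvaluationReport produced by a dataset evaluation such as my_dataset.evaluate_sync(my_task):

report.print(
    include_input=True,
    include_metadata=True,  # new in this commit
    include_expected_output=True,  # new in this commit
)

# console_table accepts the same flags if you want the rich Table itself:
table = report.console_table(include_metadata=True, include_expected_output=True)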

tests/evals/test_dataset.py

Lines changed: 85 additions & 38 deletions
@@ -13,6 +13,7 @@
 from pydantic import BaseModel
 
 from ..conftest import try_import
+from .utils import render_table
 
 with try_import() as imports_successful:
     from pydantic_evals import Case, Dataset
@@ -342,40 +343,42 @@ async def my_task(inputs: TaskInput) -> TaskOutput:
         return TaskOutput(answer=f'answer to {inputs.query}')
 
     report = await example_dataset.evaluate(my_task)
-    assert report.cases == [
-        ReportCase(
-            name='case1',
-            inputs={'query': 'What is 2+2?'},
-            metadata=TaskMetadata(difficulty='easy', category='general'),
-            expected_output=TaskOutput(answer='4', confidence=1.0),
-            output=TaskOutput(answer='answer to What is 2+2?', confidence=1.0),
-            metrics={'chars': 12},
-            attributes={'is_about_france': False},
-            scores={},
-            labels={},
-            assertions={},
-            task_duration=1.0,
-            total_duration=3.0,
-            trace_id='00000000000000000000000000000001',
-            span_id='0000000000000003',
-        ),
-        ReportCase(
-            name='case2',
-            inputs={'query': 'What is the capital of France?'},
-            metadata=TaskMetadata(difficulty='medium', category='geography'),
-            expected_output=TaskOutput(answer='Paris', confidence=1.0),
-            output=TaskOutput(answer='answer to What is the capital of France?', confidence=1.0),
-            metrics={'chars': 30},
-            attributes={'is_about_france': True},
-            scores={},
-            labels={},
-            assertions={},
-            task_duration=1.0,
-            total_duration=3.0,
-            trace_id='00000000000000000000000000000001',
-            span_id='0000000000000007',
-        ),
-    ]
+    assert report.cases == snapshot(
+        [
+            ReportCase(
+                name='case1',
+                inputs=TaskInput(query='What is 2+2?'),
+                metadata=TaskMetadata(difficulty='easy', category='general'),
+                expected_output=TaskOutput(answer='4', confidence=1.0),
+                output=TaskOutput(answer='answer to What is 2+2?', confidence=1.0),
+                metrics={'chars': 12},
+                attributes={'is_about_france': False},
+                scores={},
+                labels={},
+                assertions={},
+                task_duration=1.0,
+                total_duration=3.0,
+                trace_id='00000000000000000000000000000001',
+                span_id='0000000000000003',
+            ),
+            ReportCase(
+                name='case2',
+                inputs=TaskInput(query='What is the capital of France?'),
+                metadata=TaskMetadata(difficulty='medium', category='geography'),
+                expected_output=TaskOutput(answer='Paris', confidence=1.0),
+                output=TaskOutput(answer='answer to What is the capital of France?', confidence=1.0),
+                metrics={'chars': 30},
+                attributes={'is_about_france': True},
+                scores={},
+                labels={},
+                assertions={},
+                task_duration=1.0,
+                total_duration=3.0,
+                trace_id='00000000000000000000000000000001',
+                span_id='0000000000000007',
+            ),
+        ]
+    )
 
 
 async def test_repeated_name_outputs(example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata]):
@@ -393,7 +396,7 @@ async def my_task(inputs: TaskInput) -> TaskOutput:
         [
             ReportCase(
                 name='case1',
-                inputs={'query': 'What is 2+2?'},
+                inputs=TaskInput(query='What is 2+2?'),
                 metadata=TaskMetadata(difficulty='easy', category='general'),
                 expected_output=TaskOutput(answer='4', confidence=1.0),
                 output=TaskOutput(answer='answer to What is 2+2?', confidence=1.0),
@@ -419,7 +422,7 @@ async def my_task(inputs: TaskInput) -> TaskOutput:
             ),
             ReportCase(
                 name='case2',
-                inputs={'query': 'What is the capital of France?'},
+                inputs=TaskInput(query='What is the capital of France?'),
                 metadata=TaskMetadata(difficulty='medium', category='geography'),
                 expected_output=TaskOutput(answer='Paris', confidence=1.0),
                 output=TaskOutput(answer='answer to What is the capital of France?', confidence=1.0),
@@ -467,7 +470,7 @@ async def my_task(inputs: TaskInput) -> TaskOutput:
         [
             ReportCase(
                 name='case1',
-                inputs={'query': 'What is 2+2?'},
+                inputs=TaskInput(query='What is 2+2?'),
                 metadata=TaskMetadata(difficulty='easy', category='general'),
                 expected_output=TaskOutput(answer='4', confidence=1.0),
                 output=TaskOutput(answer='answer to What is 2+2?', confidence=1.0),
@@ -483,7 +486,7 @@ async def my_task(inputs: TaskInput) -> TaskOutput:
             ),
             ReportCase(
                 name='case2',
-                inputs={'query': 'What is the capital of France?'},
+                inputs=TaskInput(query='What is the capital of France?'),
                 metadata=TaskMetadata(difficulty='medium', category='geography'),
                 expected_output=TaskOutput(answer='Paris', confidence=1.0),
                 output=TaskOutput(answer='answer to What is the capital of France?', confidence=1.0),
@@ -988,3 +991,47 @@ def test_import_generate_dataset():
     from pydantic_evals.generation import generate_dataset
 
     assert generate_dataset
+
+
+def test_evaluate_non_serializable_inputs():
+    @dataclass
+    class MyInputs:
+        result_type: type[str] | type[int]
+
+    my_dataset = Dataset[MyInputs, Any, Any](
+        cases=[
+            Case(
+                name='str',
+                inputs=MyInputs(result_type=str),
+                expected_output='abc',
+            ),
+            Case(
+                name='int',
+                inputs=MyInputs(result_type=int),
+                expected_output=123,
+            ),
+        ],
+    )
+
+    async def my_task(my_inputs: MyInputs) -> int | str:
+        if issubclass(my_inputs.result_type, str):
+            return my_inputs.result_type('abc')
+        else:
+            return my_inputs.result_type(123)
+
+    report = my_dataset.evaluate_sync(task=my_task)
+    assert [c.inputs for c in report.cases] == snapshot([MyInputs(result_type=str), MyInputs(result_type=int)])
+
+    table = report.console_table(include_input=True)
+    assert render_table(table) == snapshot("""\
+                                        Evaluation Summary: my_task
+┏━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
+┃ Case ID  ┃ Inputs                                                                             ┃ Duration ┃
+┡━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
+│ str      │ test_evaluate_non_serializable_inputs.<locals>.MyInputs(result_type=<class 'str'>) │     1.0s │
+├──────────┼────────────────────────────────────────────────────────────────────────────────────┼──────────┤
+│ int      │ test_evaluate_non_serializable_inputs.<locals>.MyInputs(result_type=<class 'int'>) │     1.0s │
+├──────────┼────────────────────────────────────────────────────────────────────────────────────┼──────────┤
+│ Averages │                                                                                    │     1.0s │
+└──────────┴────────────────────────────────────────────────────────────────────────────────────┴──────────┘
+""")
