Skip to content

Commit 3cbe4dc

Browse files
committed
Add metadata for experiments
1 parent 2c3a218 commit 3cbe4dc

File tree

3 files changed

+494
-11
lines changed

3 files changed

+494
-11
lines changed

pydantic_evals/pydantic_evals/dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name
348348
name=name,
349349
cases=cases,
350350
failures=failures,
351+
experiment_metadata=metadata,
351352
span_id=span_id,
352353
trace_id=trace_id,
353354
)

pydantic_evals/pydantic_evals/reporting/__init__.py

Lines changed: 126 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
from typing import Any, Generic, Literal, Protocol, cast
88

99
from pydantic import BaseModel, TypeAdapter
10-
from rich.console import Console
10+
from rich.console import Console, Group, RenderableType
11+
from rich.panel import Panel
1112
from rich.table import Table
13+
from rich.text import Text
1214
from typing_extensions import TypedDict, TypeVar
1315

1416
from pydantic_evals._utils import UNSET, Unset
@@ -196,6 +198,8 @@ class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
196198
failures: list[ReportCaseFailure[InputsT, OutputT, MetadataT]] = field(default_factory=list)
197199
"""The failures in the report. These are cases where task execution raised an exception."""
198200

201+
experiment_metadata: dict[str, Any] | None = None
202+
"""Metadata associated with the specific experiment represented by this report."""
199203
trace_id: str | None = None
200204
"""The trace ID of the evaluation."""
201205
span_id: str | None = None
@@ -261,7 +265,6 @@ def render(
261265
duration_config=duration_config,
262266
include_reasons=include_reasons,
263267
)
264-
Console(file=io_file)
265268
return io_file.getvalue()
266269

267270
def print(
@@ -297,7 +300,8 @@ def print(
297300
if console is None: # pragma: no branch
298301
console = Console(width=width)
299302

300-
table = self.console_table(
303+
metadata_panel = self._metadata_panel(baseline=baseline)
304+
renderable: RenderableType = self.console_table(
301305
baseline=baseline,
302306
include_input=include_input,
303307
include_metadata=include_metadata,
@@ -316,8 +320,12 @@ def print(
316320
metric_configs=metric_configs,
317321
duration_config=duration_config,
318322
include_reasons=include_reasons,
323+
with_title=not metadata_panel,
319324
)
320-
console.print(table)
325+
# Wrap table with experiment metadata panel if present
326+
if metadata_panel:
327+
renderable = Group(metadata_panel, renderable)
328+
console.print(renderable)
321329
if include_errors and self.failures: # pragma: no cover
322330
failures_table = self.failures_table(
323331
include_input=include_input,
@@ -330,6 +338,7 @@ def print(
330338
)
331339
console.print(failures_table, style='red')
332340

341+
# TODO(DavidM): in v2, change the return type here to RenderableType
333342
def console_table(
334343
self,
335344
baseline: EvaluationReport[InputsT, OutputT, MetadataT] | None = None,
@@ -351,9 +360,11 @@ def console_table(
351360
metric_configs: dict[str, RenderNumberConfig] | None = None,
352361
duration_config: RenderNumberConfig | None = None,
353362
include_reasons: bool = False,
363+
with_title: bool = True,
354364
) -> Table:
355-
"""Return a table containing the data from this report, or the diff between this report and a baseline report.
365+
"""Return a table containing the data from this report.
356366
367+
If a baseline is provided, returns a diff between this report and the baseline report.
357368
Optionally include input and output details.
358369
"""
359370
renderer = EvaluationRenderer(
@@ -378,10 +389,82 @@ def console_table(
378389
include_reasons=include_reasons,
379390
)
380391
if baseline is None:
381-
return renderer.build_table(self)
392+
return renderer.build_table(self, with_title=with_title)
382393
else: # pragma: no cover
383-
return renderer.build_diff_table(self, baseline)
394+
return renderer.build_diff_table(self, baseline, with_title=with_title)
384395

396+
def _metadata_panel(
    self, baseline: EvaluationReport[InputsT, OutputT, MetadataT] | None = None
) -> RenderableType | None:
    """Build a panel describing the experiment metadata, if there is any to show.

    Without a baseline, the panel lists every entry of this report's
    `experiment_metadata`, one `key: value` pair per line.

    With a baseline, the panel shows a per-key diff between the baseline's metadata
    and this report's metadata: unchanged entries are dim, added entries are green and
    prefixed with `+`, removed entries are red and prefixed with `-`, and changed
    entries are yellow and rendered as `key: old → new`.

    Args:
        baseline: Optional baseline report whose metadata is diffed against this report's.

    Returns:
        A `Panel` titled like the corresponding summary/diff table, or `None` when there is
        no metadata to display (so the caller can fall back to a titled table).
    """
    lines_styles: list[tuple[str, str]]

    if baseline is None:
        # Single report: nothing to render unless this report carries metadata.
        if not self.experiment_metadata:
            return None
        title = f'Evaluation Summary: {self.name}'
        lines_styles = [(f'{key}: {value}', 'dim') for key, value in self.experiment_metadata.items()]
    else:
        # Diff report: render a panel if either side carries metadata.
        if not (self.experiment_metadata or baseline.experiment_metadata):
            return None

        diff_name = baseline.name if baseline.name == self.name else f'{baseline.name} → {self.name}'
        title = f'Evaluation Diff: {diff_name}'

        if baseline.experiment_metadata and self.experiment_metadata:
            lines_styles = []
            # Walk the union of keys in sorted order so the output is deterministic.
            all_keys = sorted(set(baseline.experiment_metadata) | set(self.experiment_metadata))
            for key in all_keys:
                baseline_val = baseline.experiment_metadata.get(key)
                report_val = self.experiment_metadata.get(key)
                # NOTE: `.get()` returns None for a missing key, so an explicit None value is
                # rendered the same way as an absent key.
                if baseline_val == report_val:
                    lines_styles.append((f'{key}: {report_val}', 'dim'))
                elif baseline_val is None:
                    lines_styles.append((f'+ {key}: {report_val}', 'green'))
                elif report_val is None:
                    lines_styles.append((f'- {key}: {baseline_val}', 'red'))
                else:
                    lines_styles.append((f'{key}: {baseline_val} → {report_val}', 'yellow'))
        elif self.experiment_metadata:
            # Only the new report has metadata: everything counts as added.
            lines_styles = [(f'+ {k}: {v}', 'green') for k, v in self.experiment_metadata.items()]
        else:
            # Only the baseline has metadata: everything counts as removed.
            assert baseline.experiment_metadata is not None
            lines_styles = [(f'- {k}: {v}', 'red') for k, v in baseline.experiment_metadata.items()]

    # Join the styled lines with unstyled newlines (no trailing newline).
    metadata_text = Text()
    for index, (line, style) in enumerate(lines_styles):
        if index:
            metadata_text.append('\n')
        metadata_text.append(line, style=style)

    return Panel(
        metadata_text,
        title=title,
        title_align='left',
        border_style='dim',
        padding=(0, 1),
        expand=False,
    )
466+
467+
# TODO(DavidM): in v2, change the return type here to RenderableType
385468
def failures_table(
386469
self,
387470
*,
@@ -705,6 +788,7 @@ class ReportCaseRenderer:
705788
metric_renderers: Mapping[str, _NumberRenderer]
706789
duration_renderer: _NumberRenderer
707790

791+
# TODO(DavidM): in v2, change the return type here to RenderableType
708792
def build_base_table(self, title: str) -> Table:
709793
"""Build and return a Rich Table for the diff output."""
710794
table = Table(title=title, show_lines=True)
@@ -731,6 +815,7 @@ def build_base_table(self, title: str) -> Table:
731815
table.add_column('Durations' if self.include_total_duration else 'Duration', justify='right')
732816
return table
733817

818+
# TODO(DavidM): in v2, change the return type here to RenderableType
734819
def build_failures_table(self, title: str) -> Table:
735820
"""Build and return a Rich Table for the failures output."""
736821
table = Table(title=title, show_lines=True)
@@ -1190,9 +1275,22 @@ def _get_case_renderer(
11901275
duration_renderer=duration_renderer,
11911276
)
11921277

1193-
def build_table(self, report: EvaluationReport) -> Table:
1278+
# TODO(DavidM): in v2, change the return type here to RenderableType
1279+
def build_table(self, report: EvaluationReport, *, with_title: bool = True) -> Table:
1280+
"""Build a table for the report.
1281+
1282+
Args:
1283+
report: The evaluation report to render
1284+
with_title: Whether to include the title in the table (default True)
1285+
1286+
Returns:
1287+
A Rich Table object
1288+
"""
11941289
case_renderer = self._get_case_renderer(report)
1195-
table = case_renderer.build_base_table(f'Evaluation Summary: {report.name}')
1290+
1291+
title = f'Evaluation Summary: {report.name}' if with_title else ''
1292+
table = case_renderer.build_base_table(title)
1293+
11961294
for case in report.cases:
11971295
table.add_row(*case_renderer.build_row(case))
11981296

@@ -1203,7 +1301,20 @@ def build_table(self, report: EvaluationReport) -> Table:
12031301

12041302
return table
12051303

1206-
def build_diff_table(self, report: EvaluationReport, baseline: EvaluationReport) -> Table:
1304+
# TODO(DavidM): in v2, change the return type here to RenderableType
1305+
def build_diff_table(
1306+
self, report: EvaluationReport, baseline: EvaluationReport, *, with_title: bool = True
1307+
) -> Table:
1308+
"""Build a diff table comparing report to baseline.
1309+
1310+
Args:
1311+
report: The evaluation report to compare
1312+
baseline: The baseline report to compare against
1313+
with_title: Whether to include the title in the table (default True)
1314+
1315+
Returns:
1316+
A Rich Table object
1317+
"""
12071318
report_cases = report.cases
12081319
baseline_cases = self._baseline_cases_to_include(report, baseline)
12091320

@@ -1228,7 +1339,10 @@ def build_diff_table(self, report: EvaluationReport, baseline: EvaluationReport)
12281339

12291340
case_renderer = self._get_case_renderer(report, baseline)
12301341
diff_name = baseline.name if baseline.name == report.name else f'{baseline.name} → {report.name}'
1231-
table = case_renderer.build_base_table(f'Evaluation Diff: {diff_name}')
1342+
1343+
title = f'Evaluation Diff: {diff_name}' if with_title else ''
1344+
table = case_renderer.build_base_table(title)
1345+
12321346
for baseline_case, new_case in diff_cases:
12331347
table.add_row(*case_renderer.build_diff_row(new_case, baseline_case))
12341348
for case in added_cases:
@@ -1247,6 +1361,7 @@ def build_diff_table(self, report: EvaluationReport, baseline: EvaluationReport)
12471361

12481362
return table
12491363

1364+
# TODO(DavidM): in v2, change the return type here to RenderableType
12501365
def build_failures_table(self, report: EvaluationReport) -> Table:
12511366
case_renderer = self._get_case_renderer(report)
12521367
table = case_renderer.build_failures_table('Case Failures')

0 commit comments

Comments
 (0)