
Commit 2c3a218

Record experiment metadata
1 parent f96dfe4

File tree: 1 file changed (+20 −5 lines)


pydantic_evals/pydantic_evals/dataset.py

Lines changed: 20 additions & 5 deletions
@@ -265,6 +265,8 @@ async def evaluate(
         retry_evaluators: RetryConfig | None = None,
         *,
         task_name: str | None = None,
+        metadata: dict[str, Any] | None = None,
+        tags: Sequence[str] | None = None,
     ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.
 
@@ -283,6 +285,8 @@ async def evaluate(
             retry_evaluators: Optional retry configuration for evaluator execution.
             task_name: Optional override to the name of the task being executed, otherwise the name of the task
                 function will be used.
+            metadata: Optional dict of experiment metadata.
+            tags: Optional sequence of logfire tags.
 
         Returns:
             A report containing the results of the evaluation.
@@ -294,14 +298,19 @@
 
         limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()
 
+        extra_attributes: dict[str, Any] = {'gen_ai.operation.name': 'experiment'}
+        if metadata is not None:
+            extra_attributes['metadata'] = metadata
         with (
             logfire_span(
                 'evaluate {name}',
                 name=name,
                 task_name=task_name,
                 dataset_name=self.name,
                 n_cases=len(self.cases),
-                **{'gen_ai.operation.name': 'experiment'},  # pyright: ignore[reportArgumentType]
+                metadata=metadata,
+                **extra_attributes,
+                _tags=tags,
             ) as eval_span,
             progress_bar or nullcontext(),
         ):
@@ -342,10 +351,16 @@ async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name
                 span_id=span_id,
                 trace_id=trace_id,
             )
-            if (averages := report.averages()) is not None and averages.assertions is not None:
-                experiment_metadata = {'n_cases': len(self.cases), 'averages': averages}
-                eval_span.set_attribute('logfire.experiment.metadata', experiment_metadata)
-                eval_span.set_attribute('assertion_pass_rate', averages.assertions)
+            full_experiment_metadata: dict[str, Any] = {'n_cases': len(self.cases)}
+            if metadata is not None:
+                full_experiment_metadata['metadata'] = metadata
+            if tags is not None:
+                full_experiment_metadata['tags'] = tags
+            if (averages := report.averages()) is not None:
+                full_experiment_metadata['averages'] = averages
+                if averages.assertions is not None:
+                    eval_span.set_attribute('assertion_pass_rate', averages.assertions)
+            eval_span.set_attribute('logfire.experiment.metadata', full_experiment_metadata)
             return report
 
     def evaluate_sync(
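
For reference, a minimal usage sketch of the new parameters. The task, dataset, and metadata values below are hypothetical; only the `metadata=` and `tags=` keyword arguments come from this commit:

import asyncio

from pydantic_evals import Case, Dataset


# Hypothetical task; Dataset.evaluate runs it against each case.
async def answer(question: str) -> str:
    return 'Paris' if 'France' in question else 'unknown'


dataset = Dataset(
    cases=[Case(name='capital', inputs='What is the capital of France?', expected_output='Paris')]
)

report = asyncio.run(
    dataset.evaluate(
        answer,
        # New in this commit: recorded on the evaluation span and, together
        # with n_cases and averages, in 'logfire.experiment.metadata'.
        metadata={'model': 'gpt-4o', 'prompt_version': '3'},  # hypothetical values
        # New in this commit: forwarded to logfire as span tags via _tags.
        tags=['nightly'],
    )
)
report.print()

With these arguments, the span emitted for the run carries the user-supplied metadata and tags alongside the computed n_cases and averages, so experiment runs can be distinguished and compared after the fact.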
