@@ -265,6 +265,8 @@ async def evaluate(
         retry_evaluators: RetryConfig | None = None,
         *,
         task_name: str | None = None,
+        metadata: dict[str, Any] | None = None,
+        tags: Sequence[str] | None = None,
     ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.

@@ -283,6 +285,8 @@ async def evaluate(
             retry_evaluators: Optional retry configuration for evaluator execution.
             task_name: Optional override to the name of the task being executed, otherwise the name of the task
                 function will be used.
+            metadata: Optional dict of experiment metadata.
+            tags: Optional sequence of logfire tags.

         Returns:
             A report containing the results of the evaluation.
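
For illustration only, a minimal usage sketch of the two new keyword arguments; the dataset, task, and values below are hypothetical and not part of this change:

```python
import asyncio

from pydantic_evals import Case, Dataset

# Hypothetical dataset and task, used only to show the new keyword arguments.
dataset = Dataset(cases=[Case(name='addition', inputs={'a': 1, 'b': 2}, expected_output=3)])


async def add(inputs: dict[str, int]) -> int:
    return inputs['a'] + inputs['b']


async def main() -> None:
    report = await dataset.evaluate(
        add,
        metadata={'model': 'gpt-4o', 'prompt_version': 2},  # attached to the experiment span
        tags=['nightly', 'regression'],  # forwarded as logfire span tags
    )
    report.print()


asyncio.run(main())
```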
@@ -294,14 +298,19 @@ async def evaluate(

         limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()

+        extra_attributes: dict[str, Any] = {'gen_ai.operation.name': 'experiment'}
+        if metadata is not None:
+            extra_attributes['metadata'] = metadata
         with (
             logfire_span(
                 'evaluate {name}',
                 name=name,
                 task_name=task_name,
                 dataset_name=self.name,
                 n_cases=len(self.cases),
-                **{'gen_ai.operation.name': 'experiment'},  # pyright: ignore[reportArgumentType]
+                metadata=metadata,
+                **extra_attributes,
+                _tags=tags,
             ) as eval_span,
             progress_bar or nullcontext(),
         ):
@@ -342,10 +351,16 @@ async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name
                 span_id=span_id,
                 trace_id=trace_id,
             )
-            if (averages := report.averages()) is not None and averages.assertions is not None:
-                experiment_metadata = {'n_cases': len(self.cases), 'averages': averages}
-                eval_span.set_attribute('logfire.experiment.metadata', experiment_metadata)
-                eval_span.set_attribute('assertion_pass_rate', averages.assertions)
+            full_experiment_metadata: dict[str, Any] = {'n_cases': len(self.cases)}
+            if metadata is not None:
+                full_experiment_metadata['metadata'] = metadata
+            if tags is not None:
+                full_experiment_metadata['tags'] = tags
+            if (averages := report.averages()) is not None:
+                full_experiment_metadata['averages'] = averages
+                if averages.assertions is not None:
+                    eval_span.set_attribute('assertion_pass_rate', averages.assertions)
+            eval_span.set_attribute('logfire.experiment.metadata', full_experiment_metadata)
         return report

     def evaluate_sync(
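
As a standalone illustration of the attribute logic added above (a sketch only; `build_experiment_metadata` is a hypothetical helper, not part of the change, and the example values are made up):

```python
from collections.abc import Sequence
from typing import Any


def build_experiment_metadata(
    n_cases: int,
    metadata: dict[str, Any] | None = None,
    tags: Sequence[str] | None = None,
    averages: dict[str, float] | None = None,
) -> dict[str, Any]:
    # Mirrors the new code path: optional keys are added only when a value exists,
    # so 'logfire.experiment.metadata' never carries explicit None entries.
    # In the real code, `averages` stands in for the object returned by report.averages().
    full_experiment_metadata: dict[str, Any] = {'n_cases': n_cases}
    if metadata is not None:
        full_experiment_metadata['metadata'] = metadata
    if tags is not None:
        full_experiment_metadata['tags'] = tags
    if averages is not None:
        full_experiment_metadata['averages'] = averages
    return full_experiment_metadata


# prints: {'n_cases': 3, 'metadata': {'model': 'gpt-4o'}, 'tags': ['nightly']}
print(build_experiment_metadata(3, metadata={'model': 'gpt-4o'}, tags=['nightly']))
```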