@@ -98,6 +98,7 @@ class _DatasetModel(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid'):
 
     # $schema is included to avoid validation fails from the `$schema` key, see `_add_json_schema` below for context
     json_schema_path: str | None = Field(default=None, alias='$schema')
+    name: str | None = None
     cases: list[_CaseModel[InputsT, OutputT, MetadataT]]
     evaluators: list[EvaluatorSpec] = Field(default_factory=list)
 
@@ -218,6 +219,8 @@ async def main():
     ```
     """
 
+    name: str | None = None
+    """Optional name of the dataset."""
     cases: list[Case[InputsT, OutputT, MetadataT]]
     """List of test cases in the dataset."""
     evaluators: list[Evaluator[InputsT, OutputT, MetadataT]] = []
@@ -226,12 +229,14 @@ async def main():
     def __init__(
         self,
         *,
+        name: str | None = None,
         cases: Sequence[Case[InputsT, OutputT, MetadataT]],
         evaluators: Sequence[Evaluator[InputsT, OutputT, MetadataT]] = (),
     ):
         """Initialize a new dataset with test cases and optional evaluators.
 
         Args:
+            name: Optional name for the dataset.
             cases: Sequence of test cases to include in the dataset.
             evaluators: Optional sequence of evaluators to apply to all cases in the dataset.
         """
@@ -244,10 +249,12 @@ def __init__(
             case_names.add(case.name)
 
         super().__init__(
+            name=name,
             cases=cases,
             evaluators=list(evaluators),
         )
 
+    # TODO in v2: Make everything not required keyword-only
     async def evaluate(
         self,
         task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT],
@@ -256,6 +263,8 @@ async def evaluate(
         progress: bool = True,
         retry_task: RetryConfig | None = None,
         retry_evaluators: RetryConfig | None = None,
+        *,
+        task_name: str | None = None,
     ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.
 
@@ -265,28 +274,38 @@ async def evaluate(
         Args:
             task: The task to evaluate. This should be a callable that takes the inputs of the case
                 and returns the output.
-            name: The name of the task being evaluated; this is used to identify the task in the report.
-                If omitted, the name of the task function will be used.
+            name: The name of the experiment being run; this is used to identify the experiment in the report.
+                If omitted, the task_name will be used; if that is not specified, the name of the task function is used.
             max_concurrency: The maximum number of concurrent evaluations of the task to allow.
                 If None, all cases will be evaluated concurrently.
             progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
             retry_task: Optional retry configuration for the task execution.
             retry_evaluators: Optional retry configuration for evaluator execution.
+            task_name: Optional override for the name of the task being executed; otherwise the name of the task
+                function will be used.
 
         Returns:
             A report containing the results of the evaluation.
         """
-        name = name or get_unwrapped_function_name(task)
+        task_name = task_name or get_unwrapped_function_name(task)
+        name = name or task_name
         total_cases = len(self.cases)
         progress_bar = Progress() if progress else None
 
         limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()
 
         with (
-            logfire_span('evaluate {name}', name=name, n_cases=len(self.cases)) as eval_span,
+            logfire_span(
+                'evaluate {name}',
+                name=name,
+                task_name=task_name,
+                dataset_name=self.name,
+                n_cases=len(self.cases),
+                **{'gen_ai.operation.name': 'experiment'},  # pyright: ignore[reportArgumentType]
+            ) as eval_span,
             progress_bar or nullcontext(),
         ):
-            task_id = progress_bar.add_task(f'Evaluating {name}', total=total_cases) if progress_bar else None
+            task_id = progress_bar.add_task(f'Evaluating {task_name}', total=total_cases) if progress_bar else None
 
             async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
                 async with limiter:
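
With these changes, `task_name` falls back to the task function's name, and `name` (the experiment name) falls back to `task_name`. A sketch of the resulting call shapes, assuming the hypothetical `dataset` above and an illustrative task:

```python
async def answer_question(question: str) -> str:  # illustrative task
    return 'Paris'

# No names given: task_name and name both resolve to 'answer_question'.
report = await dataset.evaluate(answer_question)

# Only task_name given: the experiment name falls back to 'qa-task'.
report = await dataset.evaluate(answer_question, task_name='qa-task')

# Both given: the span is named 'experiment-42'; the task is 'qa-task'.
report = await dataset.evaluate(answer_question, name='experiment-42', task_name='qa-task')
```
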
@@ -357,7 +376,7 @@ def evaluate_sync(
         return get_event_loop().run_until_complete(
             self.evaluate(
                 task,
-                name=name,
+                task_name=name,
                 max_concurrency=max_concurrency,
                 progress=progress,
                 retry_task=retry_task,
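
Note that `evaluate_sync` now forwards its `name` argument as `task_name`, so a call like `dataset.evaluate_sync(answer_question, name='qa-task')` names the task, and the experiment name is then derived from it via the fallback above.
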
@@ -474,7 +493,7 @@ def from_file(
 
         raw = Path(path).read_text()
         try:
-            return cls.from_text(raw, fmt=fmt, custom_evaluator_types=custom_evaluator_types)
+            return cls.from_text(raw, fmt=fmt, custom_evaluator_types=custom_evaluator_types, default_name=path.stem)
         except ValidationError as e:  # pragma: no cover
             raise ValueError(f'{path} contains data that does not match the schema for {cls.__name__}:\n{e}.') from e
 
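
As a result, a dataset loaded from a file now defaults its name to the file's stem when the serialized data doesn't carry one. A sketch, with a hypothetical path:

```python
# 'capital_cities.yaml' has no top-level `name:` key,
# so the loaded dataset is named after the file stem.
dataset = Dataset.from_file('capital_cities.yaml')
assert dataset.name == 'capital_cities'
```
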
@@ -484,6 +503,8 @@ def from_text(
         contents: str,
         fmt: Literal['yaml', 'json'] = 'yaml',
         custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
+        *,
+        default_name: str | None = None,
     ) -> Self:
         """Load a dataset from a string.
 
@@ -492,6 +513,7 @@ def from_text(
             fmt: Format of the content. Must be either 'yaml' or 'json'.
             custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset.
                 These are additional evaluators beyond the default ones.
+            default_name: Default name of the dataset, to be used if not specified in the serialized contents.
 
         Returns:
             A new Dataset instance parsed from the string.
@@ -501,24 +523,27 @@ def from_text(
         """
         if fmt == 'yaml':
             loaded = yaml.safe_load(contents)
-            return cls.from_dict(loaded, custom_evaluator_types)
+            return cls.from_dict(loaded, custom_evaluator_types, default_name=default_name)
         else:
             dataset_model_type = cls._serialization_type()
             dataset_model = dataset_model_type.model_validate_json(contents)
-            return cls._from_dataset_model(dataset_model, custom_evaluator_types)
+            return cls._from_dataset_model(dataset_model, custom_evaluator_types, default_name)
 
     @classmethod
     def from_dict(
         cls,
         data: dict[str, Any],
         custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
+        *,
+        default_name: str | None = None,
     ) -> Self:
         """Load a dataset from a dictionary.
 
         Args:
             data: Dictionary representation of the dataset.
             custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset.
                 These are additional evaluators beyond the default ones.
+            default_name: Default name of the dataset, to be used if not specified in the data.
 
         Returns:
             A new Dataset instance created from the dictionary.
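
A sketch of the two ways a name can now reach `from_text`: embedded in the serialized contents, or via the new `default_name` fallback. The YAML payloads are illustrative, and the `Dataset` is assumed to be suitably parametrized:

```python
named_yaml = """\
name: geography
cases:
  - name: france
    inputs: France
    expected_output: Paris
"""
ds = Dataset.from_text(named_yaml)
assert ds.name == 'geography'  # taken from the contents

anonymous_yaml = """\
cases:
  - name: france
    inputs: France
    expected_output: Paris
"""
ds = Dataset.from_text(anonymous_yaml, default_name='fallback')
assert ds.name == 'fallback'  # taken from default_name
```
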
@@ -528,19 +553,21 @@ def from_dict(
         """
         dataset_model_type = cls._serialization_type()
         dataset_model = dataset_model_type.model_validate(data)
-        return cls._from_dataset_model(dataset_model, custom_evaluator_types)
+        return cls._from_dataset_model(dataset_model, custom_evaluator_types, default_name)
 
     @classmethod
     def _from_dataset_model(
         cls,
         dataset_model: _DatasetModel[InputsT, OutputT, MetadataT],
         custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
+        default_name: str | None = None,
     ) -> Self:
         """Create a Dataset from a _DatasetModel.
 
         Args:
             dataset_model: The _DatasetModel to convert.
             custom_evaluator_types: Custom evaluator classes to register for deserialization.
+            default_name: Default name of the dataset, to be used if the value is `None` in the provided model.
 
         Returns:
             A new Dataset instance created from the _DatasetModel.
@@ -577,7 +604,9 @@ def _from_dataset_model(
             cases.append(row)
         if errors:
             raise ExceptionGroup(f'{len(errors)} error(s) loading evaluators from registry', errors[:3])
-        result = cls(cases=cases)
+        result = cls(name=dataset_model.name, cases=cases)
+        if result.name is None:
+            result.name = default_name
         result.evaluators = dataset_evaluators
         return result
 
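
The net effect is a simple precedence rule for the dataset's name, roughly:

```python
# Precedence: name in the serialized data, then default_name, else None.
result.name = dataset_model.name if dataset_model.name is not None else default_name
```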