@@ -13,14 +13,15 @@
import inspect
import sys
import time
+ import traceback
import warnings
from collections.abc import Awaitable, Callable, Mapping, Sequence
from contextlib import AsyncExitStack, nullcontext
from contextvars import ContextVar
from dataclasses import dataclass, field
from inspect import iscoroutinefunction
from pathlib import Path
- from typing import Any, Generic, Literal, Union, cast
+ from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast

import anyio
import logfire_api
@@ -40,10 +41,14 @@
from .evaluators._run_evaluator import run_evaluator
from .evaluators.common import DEFAULT_EVALUATORS
from .evaluators.context import EvaluatorContext
+ from .evaluators.evaluator import EvaluatorFailure
from .evaluators.spec import EvaluatorSpec
from .otel import SpanTree
from .otel._context_subtree import context_subtree
- from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate
+ from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate, ReportCaseFailure
+
+ if TYPE_CHECKING:
+     from pydantic_ai.retries import RetryConfig

if sys.version_info < (3, 11):
    from exceptiongroup import ExceptionGroup  # pragma: lax no cover
@@ -74,6 +79,7 @@


_REPORT_CASES_ADAPTER = TypeAdapter(list[ReportCase])
+ _REPORT_CASE_FAILURES_ADAPTER = TypeAdapter(list[ReportCaseFailure])
_REPORT_CASE_AGGREGATE_ADAPTER = TypeAdapter(ReportCaseAggregate)

@@ -248,6 +254,8 @@ async def evaluate(
        name: str | None = None,
        max_concurrency: int | None = None,
        progress: bool = True,
+         retry_task: RetryConfig | None = None,
+         retry_evaluators: RetryConfig | None = None,
    ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
        """Evaluates the test cases in the dataset using the given task.

@@ -262,6 +270,8 @@ async def evaluate(
            max_concurrency: The maximum number of concurrent evaluations of the task to allow.
                If None, all cases will be evaluated concurrently.
            progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
+             retry_task: Optional retry configuration for the task execution.
+             retry_evaluators: Optional retry configuration for evaluator execution.

        Returns:
            A report containing the results of the evaluation.
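Not part of the patch: a minimal usage sketch of the two new parameters. It assumes `RetryConfig` is a mapping of tenacity `retry()` keyword arguments, which is what the `tenacity_retry(**retry)` unpacking in `_run_task` further down implies; the dataset, case, and task here are illustrative placeholders.

```python
import asyncio

from pydantic_evals import Case, Dataset
from tenacity import stop_after_attempt, wait_exponential


async def answer_question(question: str) -> str:  # placeholder task
    return question.upper()


async def main():
    dataset = Dataset(cases=[Case(name='smoke', inputs='hello', expected_output='HELLO')])
    report = await dataset.evaluate(
        answer_question,
        # Assumed shape: tenacity retry() kwargs passed through as RetryConfig.
        retry_task={'stop': stop_after_attempt(3), 'wait': wait_exponential()},
        retry_evaluators={'stop': stop_after_attempt(2)},
    )
    print(report)


asyncio.run(main())
```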
@@ -277,7 +287,9 @@ async def evaluate(

            async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
                async with limiter:
-                     result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators)
+                     result = await _run_task_and_evaluators(
+                         task, case, report_case_name, self.evaluators, retry_task, retry_evaluators
+                     )
                if progress_bar and task_id is not None:  # pragma: no branch
                    progress_bar.update(task_id, advance=1)
                return result
@@ -288,21 +300,35 @@ async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name
            else:
                trace_id = f'{context.trace_id:032x}'
                span_id = f'{context.span_id:016x}'
+             cases_and_failures = await task_group_gather(
+                 [
+                     lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
+                     for i, case in enumerate(self.cases, 1)
+                 ]
+             )
+             cases: list[ReportCase] = []
+             failures: list[ReportCaseFailure] = []
+             for item in cases_and_failures:
+                 if isinstance(item, ReportCase):
+                     cases.append(item)
+                 else:
+                     failures.append(item)
            report = EvaluationReport(
                name=name,
-                 cases=await task_group_gather(
-                     [
-                         lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
-                         for i, case in enumerate(self.cases, 1)
-                     ]
-                 ),
+                 cases=cases,
+                 failures=failures,
                span_id=span_id,
                trace_id=trace_id,
            )
+             # TODO(DavidM): Address the following TODOs before V1...
            # TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel:
            eval_span.set_attribute('cases', _REPORT_CASES_ADAPTER.dump_python(report.cases))
+             # TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel:
+             eval_span.set_attribute('failures', _REPORT_CASE_FAILURES_ADAPTER.dump_python(report.failures))
            # TODO(DavidM): Remove this 'averages' attribute once we compute it in the details panel
-             eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(report.averages()))
+             averages = report.averages()
+             if averages:
+                 eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(averages))
        return report

    def evaluate_sync(
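Not part of the patch: a hedged sketch of how a caller might consume the new `cases`/`failures` split on the returned report, using only the fields exercised above (`report.failures`, `failure.name`, `failure.error_message`, `failure.error_stacktrace`); `report` is assumed to come from `Dataset.evaluate` as in the earlier sketch.

```python
# `report` comes from `await dataset.evaluate(...)` as sketched above.
for failure in report.failures:
    # Each ReportCaseFailure records the case identity plus the captured exception details.
    print(f'{failure.name} failed: {failure.error_message}')
    print(failure.error_stacktrace)

print(f'{len(report.cases)} case(s) succeeded, {len(report.failures)} failed')
```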
@@ -810,38 +836,53 @@ def record_attribute(self, name: str, value: Any) -> None:


async def _run_task(
-     task: Callable[[InputsT], Awaitable[OutputT] | OutputT], case: Case[InputsT, OutputT, MetadataT]
+     task: Callable[[InputsT], Awaitable[OutputT] | OutputT],
+     case: Case[InputsT, OutputT, MetadataT],
+     retry: RetryConfig | None = None,
) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
    """Run a task on a case and return the context for evaluators.

    Args:
        task: The task to run.
        case: The case to run the task on.
+         retry: The retry config to use.

    Returns:
        An EvaluatorContext containing the inputs, actual output, expected output, and metadata.

    Raises:
        Exception: Any exception raised by the task.
    """
-     task_run = _TaskRun()
-     if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
-         raise RuntimeError('A task run has already been entered. Task runs should not be nested')

-     # Note: the current behavior is for task execution errors to just bubble up all the way and kill the evaluation.
-     # Should we handle them for the user in some way? If so, I guess we'd want to do that here.
-     token = _CURRENT_TASK_RUN.set(task_run)
-     try:
-         with _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span:
-             with context_subtree() as span_tree:
+     async def _run_once():
+         task_run_ = _TaskRun()
+         if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
+             raise RuntimeError('A task run has already been entered. Task runs should not be nested')
+
+         token = _CURRENT_TASK_RUN.set(task_run_)
+         try:
+             with (
+                 _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span,
+                 context_subtree() as span_tree_,
+             ):
                t0 = time.perf_counter()
                if iscoroutinefunction(task):
-                     task_output = cast(OutputT, await task(case.inputs))
+                     task_output_ = cast(OutputT, await task(case.inputs))
                else:
-                     task_output = cast(OutputT, await to_thread.run_sync(task, case.inputs))
+                     task_output_ = cast(OutputT, await to_thread.run_sync(task, case.inputs))
                fallback_duration = time.perf_counter() - t0
-     finally:
-         _CURRENT_TASK_RUN.reset(token)
+             duration_ = _get_span_duration(task_span, fallback_duration)
+             return task_run_, task_output_, duration_, span_tree_
+         finally:
+             _CURRENT_TASK_RUN.reset(token)
+
+     if retry:
+         # import from pydantic_ai.retries to trigger more descriptive import error if tenacity is missing
+         from pydantic_ai.retries import retry as tenacity_retry
+
+         _run_once = tenacity_retry(**retry)(_run_once)
+
+     task_run, task_output, duration, span_tree = await _run_once()

    if isinstance(span_tree, SpanTree):  # pragma: no branch
        # Idea for making this more configurable: replace the following logic with a call to a user-provided function
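Not part of the patch: an illustrative note on the wrapping above. Because the config is unpacked with `tenacity_retry(**retry)` (and the comment suggests `pydantic_ai.retries.retry` is a thin wrapper over tenacity), passing a `retry` mapping is roughly equivalent to decorating the inner runner with tenacity directly; the values below are made up.

```python
from tenacity import retry, stop_after_attempt, wait_fixed


# Roughly what `_run_once = tenacity_retry(**retry)(_run_once)` amounts to for
# retry={'stop': stop_after_attempt(3), 'wait': wait_fixed(1)}:
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
async def _run_once():  # stand-in for the inner runner defined in _run_task
    ...
```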
@@ -865,7 +906,7 @@ async def _run_task(
        metadata=case.metadata,
        expected_output=case.expected_output,
        output=task_output,
-         duration=_get_span_duration(task_span, fallback_duration),
+         duration=duration,
        _span_tree=span_tree,
        attributes=task_run.attributes,
        metrics=task_run.metrics,
@@ -877,72 +918,93 @@ async def _run_task_and_evaluators(
    case: Case[InputsT, OutputT, MetadataT],
    report_case_name: str,
    dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
- ) -> ReportCase[InputsT, OutputT, MetadataT]:
+     retry_task: RetryConfig | None,
+     retry_evaluators: RetryConfig | None,
+ ) -> ReportCase[InputsT, OutputT, MetadataT] | ReportCaseFailure[InputsT, OutputT, MetadataT]:
    """Run a task on a case and evaluate the results.

    Args:
        task: The task to run.
        case: The case to run the task on.
        report_case_name: The name to use for this case in the report.
        dataset_evaluators: Evaluators from the dataset to apply to this case.
+         retry_task: The retry config to use for running the task.
+         retry_evaluators: The retry config to use for running the evaluators.

    Returns:
        A ReportCase containing the evaluation results.
    """
-     with _logfire.span(
-         'case: {case_name}',
-         task_name=get_unwrapped_function_name(task),
-         case_name=report_case_name,
-         inputs=case.inputs,
-         metadata=case.metadata,
-         expected_output=case.expected_output,
-     ) as case_span:
-         t0 = time.time()
-         scoring_context = await _run_task(task, case)
-
-         case_span.set_attribute('output', scoring_context.output)
-         case_span.set_attribute('task_duration', scoring_context.duration)
-         case_span.set_attribute('metrics', scoring_context.metrics)
-         case_span.set_attribute('attributes', scoring_context.attributes)
-
-         evaluators = case.evaluators + dataset_evaluators
-         evaluator_outputs: list[EvaluationResult] = []
-         if evaluators:
-             evaluator_outputs_by_task = await task_group_gather(
-                 [lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
-             )
-             evaluator_outputs += [out for outputs in evaluator_outputs_by_task for out in outputs]
-
-         assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
-         case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
-         case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
-         case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
+     trace_id: str | None = None
+     span_id: str | None = None
+     try:
+         with _logfire.span(
+             'case: {case_name}',
+             task_name=get_unwrapped_function_name(task),
+             case_name=report_case_name,
+             inputs=case.inputs,
+             metadata=case.metadata,
+             expected_output=case.expected_output,
+         ) as case_span:
+             context = case_span.context
+             if context is not None:  # pragma: no branch
+                 trace_id = f'{context.trace_id:032x}'
+                 span_id = f'{context.span_id:016x}'

-         context = case_span.context
-         if context is None:  # pragma: no cover
-             trace_id = None
-             span_id = None
-         else:
-             trace_id = f'{context.trace_id:032x}'
-             span_id = f'{context.span_id:016x}'
+             t0 = time.time()
+             scoring_context = await _run_task(task, case, retry_task)
+
+             case_span.set_attribute('output', scoring_context.output)
+             case_span.set_attribute('task_duration', scoring_context.duration)
+             case_span.set_attribute('metrics', scoring_context.metrics)
+             case_span.set_attribute('attributes', scoring_context.attributes)
+
+             evaluators = case.evaluators + dataset_evaluators
+             evaluator_outputs: list[EvaluationResult] = []
+             evaluator_failures: list[EvaluatorFailure] = []
+             if evaluators:
+                 evaluator_outputs_by_task = await task_group_gather(
+                     [lambda ev=ev: run_evaluator(ev, scoring_context, retry_evaluators) for ev in evaluators]
+                 )
+                 for outputs in evaluator_outputs_by_task:
+                     if isinstance(outputs, EvaluatorFailure):
+                         evaluator_failures.append(outputs)
+                     else:
+                         evaluator_outputs.extend(outputs)
+
+             assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
+             case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
+             case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
+             case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
        fallback_duration = time.time() - t0

-     return ReportCase[InputsT, OutputT, MetadataT](
-         name=report_case_name,
-         inputs=case.inputs,
-         metadata=case.metadata,
-         expected_output=case.expected_output,
-         output=scoring_context.output,
-         metrics=scoring_context.metrics,
-         attributes=scoring_context.attributes,
-         scores=scores,
-         labels=labels,
-         assertions=assertions,
-         task_duration=scoring_context.duration,
-         total_duration=_get_span_duration(case_span, fallback_duration),
-         trace_id=trace_id,
-         span_id=span_id,
-     )
+         return ReportCase[InputsT, OutputT, MetadataT](
+             name=report_case_name,
+             inputs=case.inputs,
+             metadata=case.metadata,
+             expected_output=case.expected_output,
+             output=scoring_context.output,
+             metrics=scoring_context.metrics,
+             attributes=scoring_context.attributes,
+             scores=scores,
+             labels=labels,
+             assertions=assertions,
+             task_duration=scoring_context.duration,
+             total_duration=_get_span_duration(case_span, fallback_duration),
+             trace_id=trace_id,
+             span_id=span_id,
+             evaluator_failures=evaluator_failures,
+         )
+     except Exception as exc:
+         return ReportCaseFailure[InputsT, OutputT, MetadataT](
+             name=report_case_name,
+             inputs=case.inputs,
+             metadata=case.metadata,
+             expected_output=case.expected_output,
+             error_message=f'{type(exc).__name__}: {exc}',
+             error_stacktrace=traceback.format_exc(),
+             trace_id=trace_id,
+             span_id=span_id,
+         )


_evaluation_results_adapter = TypeAdapter(Mapping[str, EvaluationResult])