Commit 9354f08

Gracefully handle errors in evals (#2295)
1 parent e854f98 commit 9354f08
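
At a high level, this commit makes case execution errors part of the report instead of fatal: a case whose task raises is captured as a ReportCaseFailure on report.failures (with an error message and stacktrace), while successful cases still populate report.cases. Below is a minimal sketch of that behavior, assuming the public Case/Dataset API and that the ReportCaseFailure fields shown in the dataset.py diff below are accessible as attributes; the task and case names are hypothetical.

from pydantic_evals import Case, Dataset


def sometimes_broken(inputs: str) -> str:
    # Hypothetical task: one case raises, the other succeeds.
    if inputs == 'bad':
        raise ValueError('cannot handle this input')
    return inputs.upper()


dataset = Dataset(cases=[Case(name='good', inputs='ok'), Case(name='bad', inputs='bad')])
report = dataset.evaluate_sync(sometimes_broken)

# The raising case no longer aborts the run; it is collected separately.
print(len(report.cases))  # successful cases
for failure in report.failures:
    print(failure.name, failure.error_message)  # e.g. "bad ValueError: cannot handle this input"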

File tree

11 files changed: +1328 additions, -197 deletions


examples/pydantic_ai_examples/evals/example_03_unit_testing.py

Lines changed: 3 additions & 1 deletion
@@ -29,7 +29,9 @@ def evaluate_dataset():
     report = dataset.evaluate_sync(infer_time_range)
     print(report)

-    assertion_pass_rate = report.averages().assertions
+    averages = report.averages()
+    assert averages is not None
+    assertion_pass_rate = averages.assertions
     assert assertion_pass_rate is not None, 'There should be at least one assertion'
     assert assertion_pass_rate > 0.9, (
         f'The assertion pass rate was {assertion_pass_rate:.1%}; it should be above 90%.'
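
A side effect visible in this example: report.averages() is now treated as possibly returning None (the dataset.py change below only records the 'averages' span attribute when it is truthy), so the test narrows it before reading .assertions. The following is a self-contained, hedged variant of the same pattern; shout, the single case, and the IsInstance evaluator are stand-ins for the example's infer_time_range setup.

from pydantic_evals import Case, Dataset
from pydantic_evals.evaluators import IsInstance


def shout(inputs: str) -> str:
    return inputs.upper()


def test_shout_dataset():
    dataset = Dataset(
        cases=[Case(name='hello', inputs='hello', expected_output='HELLO')],
        evaluators=[IsInstance(type_name='str')],
    )
    report = dataset.evaluate_sync(shout)

    averages = report.averages()
    assert averages is not None  # None would mean no case produced aggregatable results
    assert averages.assertions is not None, 'There should be at least one assertion'
    assert averages.assertions > 0.9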

pydantic_evals/pydantic_evals/dataset.py

Lines changed: 139 additions & 77 deletions
@@ -13,14 +13,15 @@
 import inspect
 import sys
 import time
+import traceback
 import warnings
 from collections.abc import Awaitable, Callable, Mapping, Sequence
 from contextlib import AsyncExitStack, nullcontext
 from contextvars import ContextVar
 from dataclasses import dataclass, field
 from inspect import iscoroutinefunction
 from pathlib import Path
-from typing import Any, Generic, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast

 import anyio
 import logfire_api
@@ -40,10 +41,14 @@
 from .evaluators._run_evaluator import run_evaluator
 from .evaluators.common import DEFAULT_EVALUATORS
 from .evaluators.context import EvaluatorContext
+from .evaluators.evaluator import EvaluatorFailure
 from .evaluators.spec import EvaluatorSpec
 from .otel import SpanTree
 from .otel._context_subtree import context_subtree
-from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate
+from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate, ReportCaseFailure
+
+if TYPE_CHECKING:
+    from pydantic_ai.retries import RetryConfig

 if sys.version_info < (3, 11):
     from exceptiongroup import ExceptionGroup  # pragma: lax no cover
@@ -74,6 +79,7 @@


 _REPORT_CASES_ADAPTER = TypeAdapter(list[ReportCase])
+_REPORT_CASE_FAILURES_ADAPTER = TypeAdapter(list[ReportCaseFailure])
 _REPORT_CASE_AGGREGATE_ADAPTER = TypeAdapter(ReportCaseAggregate)


@@ -248,6 +254,8 @@ async def evaluate(
         name: str | None = None,
         max_concurrency: int | None = None,
         progress: bool = True,
+        retry_task: RetryConfig | None = None,
+        retry_evaluators: RetryConfig | None = None,
     ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.

@@ -262,6 +270,8 @@ async def evaluate(
             max_concurrency: The maximum number of concurrent evaluations of the task to allow.
                 If None, all cases will be evaluated concurrently.
             progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
+            retry_task: Optional retry configuration for the task execution.
+            retry_evaluators: Optional retry configuration for evaluator execution.

         Returns:
             A report containing the results of the evaluation.
@@ -277,7 +287,9 @@ async def evaluate(

         async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
             async with limiter:
-                result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators)
+                result = await _run_task_and_evaluators(
+                    task, case, report_case_name, self.evaluators, retry_task, retry_evaluators
+                )
                 if progress_bar and task_id is not None:  # pragma: no branch
                     progress_bar.update(task_id, advance=1)
                 return result
@@ -288,21 +300,35 @@ async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name
             else:
                 trace_id = f'{context.trace_id:032x}'
                 span_id = f'{context.span_id:016x}'
+            cases_and_failures = await task_group_gather(
+                [
+                    lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
+                    for i, case in enumerate(self.cases, 1)
+                ]
+            )
+            cases: list[ReportCase] = []
+            failures: list[ReportCaseFailure] = []
+            for item in cases_and_failures:
+                if isinstance(item, ReportCase):
+                    cases.append(item)
+                else:
+                    failures.append(item)
             report = EvaluationReport(
                 name=name,
-                cases=await task_group_gather(
-                    [
-                        lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
-                        for i, case in enumerate(self.cases, 1)
-                    ]
-                ),
+                cases=cases,
+                failures=failures,
                 span_id=span_id,
                 trace_id=trace_id,
             )
+            # TODO(DavidM): Address the following TODOs before V1...
             # TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel:
             eval_span.set_attribute('cases', _REPORT_CASES_ADAPTER.dump_python(report.cases))
+            # TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel:
+            eval_span.set_attribute('failures', _REPORT_CASE_FAILURES_ADAPTER.dump_python(report.failures))
             # TODO(DavidM): Remove this 'averages' attribute once we compute it in the details panel
-            eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(report.averages()))
+            averages = report.averages()
+            if averages:
+                eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(averages))
         return report

     def evaluate_sync(
@@ -810,38 +836,53 @@ def record_attribute(self, name: str, value: Any) -> None:


 async def _run_task(
-    task: Callable[[InputsT], Awaitable[OutputT] | OutputT], case: Case[InputsT, OutputT, MetadataT]
+    task: Callable[[InputsT], Awaitable[OutputT] | OutputT],
+    case: Case[InputsT, OutputT, MetadataT],
+    retry: RetryConfig | None = None,
 ) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
     """Run a task on a case and return the context for evaluators.

     Args:
         task: The task to run.
         case: The case to run the task on.
+        retry: The retry config to use.

     Returns:
         An EvaluatorContext containing the inputs, actual output, expected output, and metadata.

     Raises:
         Exception: Any exception raised by the task.
     """
-    task_run = _TaskRun()
-    if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
-        raise RuntimeError('A task run has already been entered. Task runs should not be nested')

-    # Note: the current behavior is for task execution errors to just bubble up all the way and kill the evaluation.
-    # Should we handle them for the user in some way? If so, I guess we'd want to do that here.
-    token = _CURRENT_TASK_RUN.set(task_run)
-    try:
-        with _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span:
-            with context_subtree() as span_tree:
+    async def _run_once():
+        task_run_ = _TaskRun()
+        if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
+            raise RuntimeError('A task run has already been entered. Task runs should not be nested')
+
+        token = _CURRENT_TASK_RUN.set(task_run_)
+        try:
+            with (
+                _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span,
+                context_subtree() as span_tree_,
+            ):
                 t0 = time.perf_counter()
                 if iscoroutinefunction(task):
-                    task_output = cast(OutputT, await task(case.inputs))
+                    task_output_ = cast(OutputT, await task(case.inputs))
                 else:
-                    task_output = cast(OutputT, await to_thread.run_sync(task, case.inputs))
+                    task_output_ = cast(OutputT, await to_thread.run_sync(task, case.inputs))
                 fallback_duration = time.perf_counter() - t0
-    finally:
-        _CURRENT_TASK_RUN.reset(token)
+            duration_ = _get_span_duration(task_span, fallback_duration)
+            return task_run_, task_output_, duration_, span_tree_
+        finally:
+            _CURRENT_TASK_RUN.reset(token)
+
+    if retry:
+        # import from pydantic_ai.retries to trigger more descriptive import error if tenacity is missing
+        from pydantic_ai.retries import retry as tenacity_retry
+
+        _run_once = tenacity_retry(**retry)(_run_once)
+
+    task_run, task_output, duration, span_tree = await _run_once()

     if isinstance(span_tree, SpanTree):  # pragma: no branch
         # Idea for making this more configurable: replace the following logic with a call to a user-provided function
@@ -865,7 +906,7 @@ async def _run_task(
         metadata=case.metadata,
         expected_output=case.expected_output,
         output=task_output,
-        duration=_get_span_duration(task_span, fallback_duration),
+        duration=duration,
         _span_tree=span_tree,
         attributes=task_run.attributes,
         metrics=task_run.metrics,
@@ -877,72 +918,93 @@ async def _run_task_and_evaluators(
     case: Case[InputsT, OutputT, MetadataT],
     report_case_name: str,
     dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
-) -> ReportCase[InputsT, OutputT, MetadataT]:
+    retry_task: RetryConfig | None,
+    retry_evaluators: RetryConfig | None,
+) -> ReportCase[InputsT, OutputT, MetadataT] | ReportCaseFailure[InputsT, OutputT, MetadataT]:
     """Run a task on a case and evaluate the results.

     Args:
         task: The task to run.
         case: The case to run the task on.
         report_case_name: The name to use for this case in the report.
         dataset_evaluators: Evaluators from the dataset to apply to this case.
+        retry_task: The retry config to use for running the task.
+        retry_evaluators: The retry config to use for running the evaluators.

     Returns:
         A ReportCase containing the evaluation results.
     """
-    with _logfire.span(
-        'case: {case_name}',
-        task_name=get_unwrapped_function_name(task),
-        case_name=report_case_name,
-        inputs=case.inputs,
-        metadata=case.metadata,
-        expected_output=case.expected_output,
-    ) as case_span:
-        t0 = time.time()
-        scoring_context = await _run_task(task, case)
-
-        case_span.set_attribute('output', scoring_context.output)
-        case_span.set_attribute('task_duration', scoring_context.duration)
-        case_span.set_attribute('metrics', scoring_context.metrics)
-        case_span.set_attribute('attributes', scoring_context.attributes)
-
-        evaluators = case.evaluators + dataset_evaluators
-        evaluator_outputs: list[EvaluationResult] = []
-        if evaluators:
-            evaluator_outputs_by_task = await task_group_gather(
-                [lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
-            )
-            evaluator_outputs += [out for outputs in evaluator_outputs_by_task for out in outputs]
-
-        assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
-        case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
-        case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
-        case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
+    trace_id: str | None = None
+    span_id: str | None = None
+    try:
+        with _logfire.span(
+            'case: {case_name}',
+            task_name=get_unwrapped_function_name(task),
+            case_name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+        ) as case_span:
+            context = case_span.context
+            if context is not None:  # pragma: no branch
+                trace_id = f'{context.trace_id:032x}'
+                span_id = f'{context.span_id:016x}'

-        context = case_span.context
-        if context is None:  # pragma: no cover
-            trace_id = None
-            span_id = None
-        else:
-            trace_id = f'{context.trace_id:032x}'
-            span_id = f'{context.span_id:016x}'
+            t0 = time.time()
+            scoring_context = await _run_task(task, case, retry_task)
+
+            case_span.set_attribute('output', scoring_context.output)
+            case_span.set_attribute('task_duration', scoring_context.duration)
+            case_span.set_attribute('metrics', scoring_context.metrics)
+            case_span.set_attribute('attributes', scoring_context.attributes)
+
+            evaluators = case.evaluators + dataset_evaluators
+            evaluator_outputs: list[EvaluationResult] = []
+            evaluator_failures: list[EvaluatorFailure] = []
+            if evaluators:
+                evaluator_outputs_by_task = await task_group_gather(
+                    [lambda ev=ev: run_evaluator(ev, scoring_context, retry_evaluators) for ev in evaluators]
+                )
+                for outputs in evaluator_outputs_by_task:
+                    if isinstance(outputs, EvaluatorFailure):
+                        evaluator_failures.append(outputs)
+                    else:
+                        evaluator_outputs.extend(outputs)
+
+            assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
+            case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
+            case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
+            case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
        fallback_duration = time.time() - t0

-    return ReportCase[InputsT, OutputT, MetadataT](
-        name=report_case_name,
-        inputs=case.inputs,
-        metadata=case.metadata,
-        expected_output=case.expected_output,
-        output=scoring_context.output,
-        metrics=scoring_context.metrics,
-        attributes=scoring_context.attributes,
-        scores=scores,
-        labels=labels,
-        assertions=assertions,
-        task_duration=scoring_context.duration,
-        total_duration=_get_span_duration(case_span, fallback_duration),
-        trace_id=trace_id,
-        span_id=span_id,
-    )
+        return ReportCase[InputsT, OutputT, MetadataT](
+            name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+            output=scoring_context.output,
+            metrics=scoring_context.metrics,
+            attributes=scoring_context.attributes,
+            scores=scores,
+            labels=labels,
+            assertions=assertions,
+            task_duration=scoring_context.duration,
+            total_duration=_get_span_duration(case_span, fallback_duration),
+            trace_id=trace_id,
+            span_id=span_id,
+            evaluator_failures=evaluator_failures,
+        )
+    except Exception as exc:
+        return ReportCaseFailure[InputsT, OutputT, MetadataT](
+            name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+            error_message=f'{type(exc).__name__}: {exc}',
+            error_stacktrace=traceback.format_exc(),
+            trace_id=trace_id,
+            span_id=span_id,
+        )


 _evaluation_results_adapter = TypeAdapter(Mapping[str, EvaluationResult])
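
The new retry_task and retry_evaluators parameters are threaded from Dataset.evaluate down to _run_task and run_evaluator, and the tenacity_retry(**retry) call above suggests RetryConfig is a mapping of tenacity keyword arguments. Below is a hedged sketch of passing them to evaluate under exactly that assumption; the flaky task, the case, and the retry settings are hypothetical.

import asyncio

from tenacity import stop_after_attempt, wait_fixed

from pydantic_evals import Case, Dataset


def flaky(inputs: str) -> str:
    # Stands in for a task that occasionally raises a transient error.
    return inputs[::-1]


async def main() -> None:
    dataset = Dataset(cases=[Case(name='reverse', inputs='abc', expected_output='cba')])
    report = await dataset.evaluate(
        flaky,
        retry_task={'stop': stop_after_attempt(3), 'wait': wait_fixed(0.5)},  # task retried up to 3 attempts
        retry_evaluators={'stop': stop_after_attempt(2)},  # evaluator calls retried as well
    )
    print(report)


asyncio.run(main())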

pydantic_evals/pydantic_evals/evaluators/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -10,7 +10,7 @@
     Python,
 )
 from .context import EvaluatorContext
-from .evaluator import EvaluationReason, EvaluationResult, Evaluator, EvaluatorOutput, EvaluatorSpec
+from .evaluator import EvaluationReason, EvaluationResult, Evaluator, EvaluatorFailure, EvaluatorOutput, EvaluatorSpec

 __all__ = (
     # common
@@ -27,6 +27,8 @@
     'EvaluatorContext',
     # evaluator
     'Evaluator',
+    'EvaluationReason',
+    'EvaluatorFailure',
     'EvaluatorOutput',
     'EvaluatorSpec',
     'EvaluationReason',
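
With EvaluatorFailure now exported, callers can distinguish evaluator crashes from ordinary results. Below is a small hedged sketch of consuming the evaluator_failures list that dataset.py attaches to each ReportCase; only the import path and field name shown in this diff are assumed, and the failure objects are merely repr'd because EvaluatorFailure's own attributes are defined outside this excerpt.

from pydantic_evals.evaluators import EvaluatorFailure


def report_evaluator_failures(report) -> int:
    """Print evaluator crashes recorded on each successful case; return how many were found."""
    count = 0
    for case in report.cases:
        for failure in case.evaluator_failures:
            assert isinstance(failure, EvaluatorFailure)
            print(f'{case.name}: evaluator failure: {failure!r}')
            count += 1
    return count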
