
Commit 6add458

Support evaluating sync tasks (#2150)
1 parent 11d1cde commit 6add458
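
This change lets Dataset.evaluate and Dataset.evaluate_sync accept a plain synchronous task in addition to an async one. A synchronous task is detected with inspect.iscoroutinefunction and executed via anyio.to_thread.run_sync, so it runs in a worker thread instead of blocking the event loop. A minimal usage sketch (illustrative only; the Case/Dataset construction mirrors the test fixtures and is not part of this commit):

from pydantic import BaseModel

from pydantic_evals import Case, Dataset


class TaskInput(BaseModel):
    query: str


class TaskOutput(BaseModel):
    answer: str


def sync_task(inputs: TaskInput) -> TaskOutput:
    # A plain function: no `async def` needed anymore.
    return TaskOutput(answer=inputs.query.upper())


dataset = Dataset(cases=[Case(name='case1', inputs=TaskInput(query='hello'))])

# evaluate_sync() drives the async evaluate() loop; the sync task itself is
# off-loaded to a worker thread by the dispatch added in this commit.
report = dataset.evaluate_sync(sync_task)
print(report.name)  # 'sync_task'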

File tree

pydantic_evals/pydantic_evals/dataset.py
tests/evals/test_dataset.py

2 files changed: +68 -12 lines changed

pydantic_evals/pydantic_evals/dataset.py

Lines changed: 10 additions & 5 deletions
@@ -18,12 +18,14 @@
 from contextlib import AsyncExitStack, nullcontext
 from contextvars import ContextVar
 from dataclasses import dataclass, field
+from inspect import iscoroutinefunction
 from pathlib import Path
 from typing import Any, Callable, Generic, Literal, Union, cast
 
 import anyio
 import logfire_api
 import yaml
+from anyio import to_thread
 from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, ValidationError, model_serializer
 from pydantic._internal import _typing_extra
 from pydantic_core import to_json
@@ -253,7 +255,7 @@ def __init__(
 
     async def evaluate(
         self,
-        task: Callable[[InputsT], Awaitable[OutputT]],
+        task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT],
         name: str | None = None,
         max_concurrency: int | None = None,
         progress: bool = True,
@@ -308,7 +310,7 @@ async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name
 
     def evaluate_sync(
         self,
-        task: Callable[[InputsT], Awaitable[OutputT]],
+        task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT],
         name: str | None = None,
         max_concurrency: int | None = None,
         progress: bool = True,
@@ -811,7 +813,7 @@ def record_attribute(self, name: str, value: Any) -> None:
 
 
 async def _run_task(
-    task: Callable[[InputsT], Awaitable[OutputT]], case: Case[InputsT, OutputT, MetadataT]
+    task: Callable[[InputsT], Awaitable[OutputT] | OutputT], case: Case[InputsT, OutputT, MetadataT]
 ) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
     """Run a task on a case and return the context for evaluators.
 
@@ -836,7 +838,10 @@ async def _run_task(
         with _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span:
             with context_subtree() as span_tree:
                 t0 = time.perf_counter()
-                task_output = await task(case.inputs)
+                if iscoroutinefunction(task):
+                    task_output = cast(OutputT, await task(case.inputs))
+                else:
+                    task_output = cast(OutputT, await to_thread.run_sync(task, case.inputs))
                 fallback_duration = time.perf_counter() - t0
     finally:
         _CURRENT_TASK_RUN.reset(token)
@@ -873,7 +878,7 @@ async def _run_task(
 
 
 async def _run_task_and_evaluators(
-    task: Callable[[InputsT], Awaitable[OutputT]],
+    task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT],
     case: Case[InputsT, OutputT, MetadataT],
     report_case_name: str,
     dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
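
The heart of the change is the dispatch added to _run_task above: coroutine functions are awaited directly, while plain callables are handed to anyio's to_thread.run_sync so the event loop is not blocked while they execute. A self-contained sketch of the same pattern (generic names, not the library's internals verbatim):

import asyncio
from collections.abc import Awaitable, Callable
from inspect import iscoroutinefunction
from typing import TypeVar

from anyio import to_thread

InputsT = TypeVar('InputsT')
OutputT = TypeVar('OutputT')


async def run_task(
    task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT],
    inputs: InputsT,
) -> OutputT:
    """Run an async or sync task without blocking the event loop."""
    if iscoroutinefunction(task):
        # Coroutine function: call it and await the resulting coroutine.
        return await task(inputs)
    # Plain callable: run it in a worker thread managed by anyio.
    return await to_thread.run_sync(task, inputs)


async def main() -> None:
    async def double_async(x: int) -> int:
        return x * 2

    def double_sync(x: int) -> int:
        return x * 2

    # Both kinds of task go through the same entry point.
    print(await run_task(double_async, 21))  # 42
    print(await run_task(double_sync, 21))  # 42


asyncio.run(main())

One side effect of running sync tasks in a thread is that their wall-clock duration varies from run to run, which is why the new test below matches task_duration and total_duration with IsNumber() instead of a fixed snapshot value.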

tests/evals/test_dataset.py

Lines changed: 58 additions & 7 deletions
@@ -1,14 +1,13 @@
 from __future__ import annotations as _annotations
 
-import asyncio
 import json
 import sys
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any
 
 import pytest
-from dirty_equals import HasRepr
+from dirty_equals import HasRepr, IsNumber
 from inline_snapshot import snapshot
 from pydantic import BaseModel
 
@@ -178,21 +177,21 @@ def evaluate(self, ctx: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]):
             }
 
 
-async def test_evaluate(
+async def test_evaluate_async(
     example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata],
     simple_evaluator: type[Evaluator[TaskInput, TaskOutput, TaskMetadata]],
 ):
     """Test evaluating a dataset."""
     example_dataset.add_evaluator(simple_evaluator())
 
-    async def mock_task(inputs: TaskInput) -> TaskOutput:
+    async def mock_async_task(inputs: TaskInput) -> TaskOutput:
         if inputs.query == 'What is 2+2?':
             return TaskOutput(answer='4')
         elif inputs.query == 'What is the capital of France?':
             return TaskOutput(answer='Paris')
         return TaskOutput(answer='Unknown')  # pragma: no cover
 
-    report = await example_dataset.evaluate(mock_task)
+    report = await example_dataset.evaluate(mock_async_task)
 
     assert report is not None
     assert len(report.cases) == 2
@@ -230,6 +229,58 @@ async def mock_task(inputs: TaskInput) -> TaskOutput:
     )
 
 
+async def test_evaluate_sync(
+    example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata],
+    simple_evaluator: type[Evaluator[TaskInput, TaskOutput, TaskMetadata]],
+):
+    """Test evaluating a dataset."""
+    example_dataset.add_evaluator(simple_evaluator())
+
+    def mock_sync_task(inputs: TaskInput) -> TaskOutput:
+        if inputs.query == 'What is 2+2?':
+            return TaskOutput(answer='4')
+        elif inputs.query == 'What is the capital of France?':
+            return TaskOutput(answer='Paris')
+        return TaskOutput(answer='Unknown')  # pragma: no cover
+
+    report = await example_dataset.evaluate(mock_sync_task)
+
+    assert report is not None
+    assert len(report.cases) == 2
+    assert ReportCaseAdapter.dump_python(report.cases[0]) == snapshot(
+        {
+            'assertions': {
+                'correct': {
+                    'name': 'correct',
+                    'reason': None,
+                    'source': {'name': 'SimpleEvaluator', 'arguments': None},
+                    'value': True,
+                }
+            },
+            'attributes': {},
+            'expected_output': {'answer': '4', 'confidence': 1.0},
+            'inputs': {'query': 'What is 2+2?'},
+            'labels': {},
+            'metadata': {'category': 'general', 'difficulty': 'easy'},
+            'metrics': {},
+            'name': 'case1',
+            'output': {'answer': '4', 'confidence': 1.0},
+            'scores': {
+                'confidence': {
+                    'name': 'confidence',
+                    'reason': None,
+                    'source': {'name': 'SimpleEvaluator', 'arguments': None},
+                    'value': 1.0,
+                }
+            },
+            'span_id': '0000000000000003',
+            'task_duration': IsNumber(),  # the runtime behavior is not deterministic due to threading
+            'total_duration': IsNumber(),  # the runtime behavior is not deterministic due to threading
+            'trace_id': '00000000000000000000000000000001',
+        }
+    )
+
+
 async def test_evaluate_with_concurrency(
     example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata],
     simple_evaluator: type[Evaluator[TaskInput, TaskOutput, TaskMetadata]],
@@ -828,8 +879,8 @@ async def test_dataset_evaluate_with_sync_task(example_dataset: Dataset[TaskInpu
     def sync_task(inputs: TaskInput) -> TaskOutput:
         return TaskOutput(answer=inputs.query.upper())
 
-    report = await example_dataset.evaluate(lambda x: asyncio.sleep(0, sync_task(x)))
-    assert report.name == '<lambda>'
+    report = await example_dataset.evaluate(sync_task)
+    assert report.name == 'sync_task'
     assert len(report.cases) == 2
 
 
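
For callers, the workaround removed in the last hunk is no longer needed: instead of smuggling the sync result through asyncio.sleep(0, ...), the sync function can be passed to evaluate() directly, and the report picks up its real name. Roughly, using the names from the test above:

    # Before this commit: wrap the sync task so evaluate() receives an awaitable.
    report = await example_dataset.evaluate(lambda x: asyncio.sleep(0, sync_task(x)))
    assert report.name == '<lambda>'

    # After this commit: pass the sync task as-is.
    report = await example_dataset.evaluate(sync_task)
    assert report.name == 'sync_task'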
