@@ -1,14 +1,13 @@
 from __future__ import annotations as _annotations

-import asyncio
 import json
 import sys
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any

 import pytest
-from dirty_equals import HasRepr
+from dirty_equals import HasRepr, IsNumber
 from inline_snapshot import snapshot
 from pydantic import BaseModel

@@ -178,21 +177,21 @@ def evaluate(self, ctx: EvaluatorContext[TaskInput, TaskOutput, TaskMetadata]):
         }


-async def test_evaluate(
+async def test_evaluate_async(
     example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata],
     simple_evaluator: type[Evaluator[TaskInput, TaskOutput, TaskMetadata]],
 ):
     """Test evaluating a dataset."""
     example_dataset.add_evaluator(simple_evaluator())

-    async def mock_task(inputs: TaskInput) -> TaskOutput:
+    async def mock_async_task(inputs: TaskInput) -> TaskOutput:
         if inputs.query == 'What is 2+2?':
             return TaskOutput(answer='4')
         elif inputs.query == 'What is the capital of France?':
             return TaskOutput(answer='Paris')
         return TaskOutput(answer='Unknown')  # pragma: no cover

-    report = await example_dataset.evaluate(mock_task)
+    report = await example_dataset.evaluate(mock_async_task)

     assert report is not None
     assert len(report.cases) == 2
@@ -230,6 +229,58 @@ async def mock_task(inputs: TaskInput) -> TaskOutput:
     )


+async def test_evaluate_sync(
+    example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata],
+    simple_evaluator: type[Evaluator[TaskInput, TaskOutput, TaskMetadata]],
+):
+    """Test evaluating a dataset."""
+    example_dataset.add_evaluator(simple_evaluator())
+
+    def mock_sync_task(inputs: TaskInput) -> TaskOutput:
+        if inputs.query == 'What is 2+2?':
+            return TaskOutput(answer='4')
+        elif inputs.query == 'What is the capital of France?':
+            return TaskOutput(answer='Paris')
+        return TaskOutput(answer='Unknown')  # pragma: no cover
+
+    report = await example_dataset.evaluate(mock_sync_task)
+
+    assert report is not None
+    assert len(report.cases) == 2
+    assert ReportCaseAdapter.dump_python(report.cases[0]) == snapshot(
+        {
+            'assertions': {
+                'correct': {
+                    'name': 'correct',
+                    'reason': None,
+                    'source': {'name': 'SimpleEvaluator', 'arguments': None},
+                    'value': True,
+                }
+            },
+            'attributes': {},
+            'expected_output': {'answer': '4', 'confidence': 1.0},
+            'inputs': {'query': 'What is 2+2?'},
+            'labels': {},
+            'metadata': {'category': 'general', 'difficulty': 'easy'},
+            'metrics': {},
+            'name': 'case1',
+            'output': {'answer': '4', 'confidence': 1.0},
+            'scores': {
+                'confidence': {
+                    'name': 'confidence',
+                    'reason': None,
+                    'source': {'name': 'SimpleEvaluator', 'arguments': None},
+                    'value': 1.0,
+                }
+            },
+            'span_id': '0000000000000003',
+            'task_duration': IsNumber(),  # the runtime behavior is not deterministic due to threading
+            'total_duration': IsNumber(),  # the runtime behavior is not deterministic due to threading
+            'trace_id': '00000000000000000000000000000001',
+        }
+    )
+
+
 async def test_evaluate_with_concurrency(
     example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata],
     simple_evaluator: type[Evaluator[TaskInput, TaskOutput, TaskMetadata]],
@@ -828,8 +879,8 @@ async def test_dataset_evaluate_with_sync_task(example_dataset: Dataset[TaskInpu
     def sync_task(inputs: TaskInput) -> TaskOutput:
         return TaskOutput(answer=inputs.query.upper())

-    report = await example_dataset.evaluate(lambda x: asyncio.sleep(0, sync_task(x)))
-    assert report.name == '<lambda>'
+    report = await example_dataset.evaluate(sync_task)
+    assert report.name == 'sync_task'
     assert len(report.cases) == 2

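Taken together, the renamed async test, the new sync test, and the simplified `test_dataset_evaluate_with_sync_task` pin down one behaviour: `Dataset.evaluate` accepts either a plain synchronous callable or an async one, is awaited in both cases, and names the report after the task function (hence `report.name == 'sync_task'` without the old `asyncio.sleep` lambda wrapper). A minimal sketch of that usage outside the test fixtures follows; the `Case`/`Dataset` construction and the exact `TaskInput`/`TaskOutput` fields mirror fixtures that are not shown in this diff, so treat them as assumptions rather than the file's actual definitions.

import asyncio

from pydantic import BaseModel
from pydantic_evals import Case, Dataset


class TaskInput(BaseModel):
    query: str


class TaskOutput(BaseModel):
    answer: str
    confidence: float = 1.0


# Assumed stand-in for the `example_dataset` fixture used in the tests above.
dataset = Dataset(cases=[Case(name='case1', inputs=TaskInput(query='What is 2+2?'))])


def sync_task(inputs: TaskInput) -> TaskOutput:
    # A plain synchronous task; the diff's comments attribute the resulting
    # non-deterministic durations to it being run via threading.
    return TaskOutput(answer='4')


async def async_task(inputs: TaskInput) -> TaskOutput:
    return TaskOutput(answer='4')


async def main() -> None:
    # Both call styles are awaited; the report takes its name from the task function.
    sync_report = await dataset.evaluate(sync_task)
    async_report = await dataset.evaluate(async_task)
    assert sync_report.name == 'sync_task'
    assert len(async_report.cases) == 1


asyncio.run(main())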