Commit 6762f06

Add the ability to write Evals
1 parent 3f12254 commit 6762f06

3 files changed: +85 -0 lines changed
Lines changed: 7 additions & 0 deletions
@@ -1,8 +1,15 @@
+from .evals import Eval, Metric, TestCase, TestOk, TestError
 from .models import Model
 
 __doc__ = ""
 
 __all__ = [
+    # from .evals
+    "Eval",
+    "Metric",
+    "TestCase",
+    "TestOk",
+    "TestError",
     # from .models
     "Model"
 ]
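
With the re-export above, the eval types become importable from the package root as well as from the evals subpackage. A minimal sketch; the package's own name does not appear in this diff, so "mypkg" below is a placeholder:

    # "mypkg" stands in for the package whose __init__.py is edited above.
    from mypkg import Eval, Metric, TestCase, TestOk, TestError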
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+from .eval import Eval, Metric, TestCase, TestOk, TestError
+
+__all__ = [
+    "Eval",
+    "Metric",
+    "TestCase",
+    "TestOk",
+    "TestError",
+]
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+from dataclasses import dataclass, field
+from typing import Protocol, Any, Callable, Dict, List, Optional, Union
+
+
+class Metric(Protocol):
+    @property
+    def name(self) -> str: ...
+
+    def calculate(self, expected_output: Any, observed_output: Any) -> float: ...
+
+    def aggregate(self, values: List[float]) -> float: ...
+
+
+@dataclass
+class TestCase:
+    input: Any
+    expected_output: Any
+
+
+@dataclass
+class TestOk:
+    test_case: TestCase
+    output: Any
+    metrics: Dict[str, float] = field(default_factory=dict)
+
+
+@dataclass
+class TestError:
+    test_case: TestCase
+    error: str
+
+
+class Eval:
+    def run(
+        self, subject: Callable, cases: List[TestCase], metrics: List[Metric]
+    ) -> tuple[Dict[str, float], List[Union[TestOk, TestError]]]:
+        if not callable(subject):
+            raise ValueError("subject is not callable")
+
+        if not cases:
+            raise ValueError("test case list is empty")
+
+        if not metrics:
+            raise ValueError("list of metrics is empty")
+
+        results = []
+        for case in cases:
+            try:
+                observed_output = subject(case.input)
+
+                # Calculate all metrics for this test case
+                calculated = {}
+                for m in metrics:
+                    calculated[m.name] = m.calculate(case.expected_output, observed_output)
+
+                results.append(TestOk(test_case=case, output=observed_output, metrics=calculated))
+            except Exception as e:
+                results.append(TestError(test_case=case, error=str(e)))
+
+        # Aggregate each metric across all test cases
+        aggregated = {}
+        for m in metrics:
+            values = []
+            for result in results:
+                if isinstance(result, TestOk) and m.name in result.metrics:
+                    values.append(result.metrics[m.name])
+            aggregated[m.name] = m.aggregate(values) if values else float("nan")
+
+        return aggregated, results
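
For context, here is a minimal usage sketch of the API this commit adds. The ExactMatch metric, the shout subject, and the test data are invented for illustration, and "mypkg" again stands in for the package name, which this diff does not show:

    from typing import List

    # "mypkg" is a placeholder import path; ExactMatch and shout are
    # hypothetical examples, not part of this commit.
    from mypkg import Eval, TestCase, TestOk


    class ExactMatch:
        """Structural Metric implementation: 1.0 on exact match, else 0.0."""

        @property
        def name(self) -> str:
            return "exact_match"

        def calculate(self, expected_output, observed_output) -> float:
            return 1.0 if expected_output == observed_output else 0.0

        def aggregate(self, values: List[float]) -> float:
            # Mean score across the test cases that ran without error.
            return sum(values) / len(values)


    def shout(text: str) -> str:
        # The "subject" under evaluation: any callable taking one input.
        return text.upper()


    cases = [
        TestCase(input="hi", expected_output="HI"),
        TestCase(input="ok", expected_output="OK!"),  # expectation deliberately wrong
    ]

    aggregated, results = Eval().run(subject=shout, cases=cases, metrics=[ExactMatch()])
    print(aggregated)  # {'exact_match': 0.5}
    for result in results:
        print("ok" if isinstance(result, TestOk) else "error", result.test_case.input)

Exceptions raised inside the subject are captured as TestError entries rather than aborting the run, and a metric with no successful cases aggregates to NaN.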
