Skip to content

Commit 9edfd2c

Browse files
committed
Refactor evaluation script to use inspect_ai imports
Replaces local type stubs and placeholder functions in tests/inspect-ai/scripts/evaluation.py with direct imports from the inspect_ai package. Updates type annotations to use built-in generics and simplifies sample creation logic. Also adds the script to pyrightconfig.json's extraPaths for type checking.
1 parent 0d3eeb1 commit 9edfd2c

File tree

2 files changed

+9
-37
lines changed

2 files changed

+9
-37
lines changed

pyrightconfig.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"tests/playwright/ai_generated_apps/*/*/app*.py",
1414
"tests/inspect-ai/apps/*/app*.py",
1515
"shiny/pytest/_generate/_main.py",
16+
"tests/inspect-ai/scripts/evaluation.py"
1617
],
1718
"typeCheckingMode": "strict",
1819
"reportImportCycles": "none",

tests/inspect-ai/scripts/evaluation.py

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,11 @@
11
import json
22
from pathlib import Path
3-
from typing import Any, Callable, Dict, List
43

5-
6-
# Type stubs for inspect_ai imports
7-
class Task:
8-
def __init__(
9-
self, dataset: List[Any], solver: Any, scorer: Any, model: Any
10-
) -> None:
11-
self.dataset = dataset
12-
self.solver = solver
13-
self.scorer = scorer
14-
self.model = model
15-
16-
17-
def task(func: Callable[[], Task]) -> Callable[[], Task]:
18-
return func
19-
20-
21-
class Sample:
22-
def __init__(self, input: str, target: str, metadata: Dict[str, Any]) -> None:
23-
self.input = input
24-
self.target = target
25-
self.metadata = metadata
26-
27-
28-
def get_model(model_name: str) -> Any:
29-
pass
30-
31-
32-
def model_graded_qa(instructions: str, grade_pattern: str, model: Any) -> Any:
33-
pass
34-
35-
36-
def generate() -> Any:
37-
pass
4+
from inspect_ai import Task, task
5+
from inspect_ai.dataset import Sample
6+
from inspect_ai.model import get_model
7+
from inspect_ai.scorer import model_graded_qa
8+
from inspect_ai.solver import generate
389

3910

4011
def get_app_specific_instructions(app_name: str) -> str:
@@ -142,7 +113,7 @@ def get_app_specific_instructions(app_name: str) -> str:
142113
return app_instructions.get(app_name, "")
143114

144115

145-
def create_inspect_ai_samples(test_data: Dict[str, Dict[str, Any]]) -> List[Sample]:
116+
def create_inspect_ai_samples(test_data: dict) -> list[Sample]:
146117
"""
147118
Create Inspect AI samples from the generated test data.
148119
@@ -152,7 +123,7 @@ def create_inspect_ai_samples(test_data: Dict[str, Dict[str, Any]]) -> List[Samp
152123
Returns:
153124
List of Sample objects for Inspect AI evaluation
154125
"""
155-
samples: List[Sample] = []
126+
samples = []
156127

157128
for test_name, data in test_data.items():
158129
app_specific_guidance = get_app_specific_instructions(data["app_name"])
@@ -202,7 +173,7 @@ def shiny_test_evaluation() -> Task:
202173
with open(metadata_file, "r") as f:
203174
test_data = json.load(f)
204175

205-
samples: List[Sample] = create_inspect_ai_samples(test_data)
176+
samples = create_inspect_ai_samples(test_data)
206177

207178
scorer = model_graded_qa(
208179
instructions="""

0 commit comments

Comments
 (0)