Skip to content

Commit 1f3d6d7

Browse files
committed
feat: add zero-shot evaluation cookbook
- Add query generator for generating evaluation queries - Add response collector for collecting model responses - Add rubric generator for creating evaluation rubrics - Add evaluator for running evaluations - Add checkpoint support for resumable evaluations - Add example configurations
1 parent ac0a508 commit 1f3d6d7

File tree

10 files changed

+2159
-0
lines changed

10 files changed

+2159
-0
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# -*- coding: utf-8 -*-
"""Zero-Shot Evaluation module for comparing models and agent pipelines.

Re-exports the public API (evaluator, pipeline components, config loader,
and schema models) from the ``core`` subpackage so callers can import
everything from the package root.

Usage:
    # CLI
    python -m cookbooks.zero_shot_evaluation --config config.yaml

    # Python
    from cookbooks.zero_shot_evaluation import ZeroShotEvaluator
    evaluator = ZeroShotEvaluator.from_config("config.yaml")
    result = await evaluator.evaluate()
"""

from cookbooks.zero_shot_evaluation.core.config import load_config
from cookbooks.zero_shot_evaluation.core.evaluator import EvaluationResult, ZeroShotEvaluator
from cookbooks.zero_shot_evaluation.core.query_generator import QueryGenerator
from cookbooks.zero_shot_evaluation.core.response_collector import ResponseCollector
from cookbooks.zero_shot_evaluation.core.rubric_generator import RubricGenerator
from cookbooks.zero_shot_evaluation.core.schema import (
    EvaluationConfig,
    OpenAIEndpoint,
    QueryGenerationConfig,
    TaskConfig,
    ZeroShotConfig,
)

# Explicit public API; also controls `from cookbooks.zero_shot_evaluation import *`.
__all__ = [
    # Config
    "ZeroShotConfig",
    "TaskConfig",
    "OpenAIEndpoint",
    "QueryGenerationConfig",
    "EvaluationConfig",
    "load_config",
    # Components
    "QueryGenerator",
    "ResponseCollector",
    "RubricGenerator",
    # Evaluator
    "ZeroShotEvaluator",
    "EvaluationResult",
]
43+
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# -*- coding: utf-8 -*-
2+
"""CLI entry point for zero-shot evaluation.
3+
4+
Usage:
5+
python -m cookbooks.zero_shot_evaluation --config config.yaml
6+
python -m cookbooks.zero_shot_evaluation --config config.yaml --save
7+
python -m cookbooks.zero_shot_evaluation --config config.yaml --queries_file queries.json --save
8+
"""
9+
10+
import asyncio
11+
import json
12+
from pathlib import Path
13+
from typing import List, Optional
14+
15+
import fire
16+
from loguru import logger
17+
18+
from cookbooks.zero_shot_evaluation.core.config import load_config
19+
from cookbooks.zero_shot_evaluation.core.evaluator import ZeroShotEvaluator
20+
from cookbooks.zero_shot_evaluation.core.schema import GeneratedQuery
21+
22+
23+
def _load_queries_from_file(queries_file: str) -> List[GeneratedQuery]:
    """Read pre-generated evaluation queries from a JSON file.

    Args:
        queries_file: Path to a JSON file holding a list of serialized
            ``GeneratedQuery`` objects.

    Returns:
        The deserialized queries, in file order.
    """
    raw_items = json.loads(Path(queries_file).read_text(encoding="utf-8"))
    queries = []
    for item in raw_items:
        queries.append(GeneratedQuery(**item))
    logger.info(f"Loaded {len(queries)} queries from {queries_file}")
    return queries
30+
31+
32+
async def _run_evaluation(
    config_path: str,
    output_dir: Optional[str] = None,
    queries_file: Optional[str] = None,
    save: bool = False,
    resume: bool = True,
) -> None:
    """Run the evaluation pipeline end to end.

    Args:
        config_path: Path to configuration file.
        output_dir: Output directory (overrides the value in the config).
        queries_file: Path to pre-generated queries JSON file; when given,
            query generation is skipped.
        save: Whether to save results to file.
        resume: Whether to resume from an existing checkpoint.
    """
    cfg = load_config(config_path)

    # CLI-provided output directory takes precedence over the config file.
    if output_dir:
        cfg.output.output_dir = output_dir

    # Pre-generated queries (if supplied) bypass the generation stage.
    pregenerated = _load_queries_from_file(queries_file) if queries_file else None

    runner = ZeroShotEvaluator(config=cfg, resume=resume)
    outcome = await runner.evaluate(queries=pregenerated)

    if save:
        runner.save_results(outcome, output_dir)
63+
64+
65+
def main(
    config: str,
    output_dir: Optional[str] = None,
    queries_file: Optional[str] = None,
    save: bool = False,
    fresh: bool = False,
) -> None:
    """Zero-shot evaluation CLI with checkpoint support.

    Args:
        config: Path to YAML configuration file
        output_dir: Output directory for results
        queries_file: Path to pre-generated queries JSON (skip query generation)
        save: Whether to save results to file
        fresh: Start fresh, ignore any existing checkpoint

    Raises:
        SystemExit: With a non-zero status when the config or queries
            file does not exist.

    Examples:
        # Normal run (auto-resumes from checkpoint)
        python -m cookbooks.zero_shot_evaluation --config config.yaml --save

        # Use pre-generated queries
        python -m cookbooks.zero_shot_evaluation --config config.yaml --queries_file queries.json --save

        # Start fresh, ignore checkpoint
        python -m cookbooks.zero_shot_evaluation --config config.yaml --fresh --save
    """
    config_path = Path(config)
    if not config_path.exists():
        logger.error(f"Config file not found: {config}")
        # Fix: a bare `return` exits with status 0, so shell scripts and CI
        # could not detect the failure. Exit non-zero instead.
        raise SystemExit(1)

    if queries_file:
        queries_path = Path(queries_file)
        if not queries_path.exists():
            logger.error(f"Queries file not found: {queries_file}")
            raise SystemExit(1)

    logger.info(f"Starting zero-shot evaluation with config: {config}")
    if queries_file:
        logger.info(f"Using pre-generated queries from: {queries_file}")
    if fresh:
        logger.info("Starting fresh (ignoring checkpoint)")
    else:
        logger.info("Resume mode enabled (will continue from checkpoint if exists)")

    # `fresh` is the user-facing flag; the pipeline takes its inverse.
    asyncio.run(_run_evaluation(str(config_path), output_dir, queries_file, save, resume=not fresh))
111+
112+
113+
# Entry point for `python -m cookbooks.zero_shot_evaluation`; fire maps
# CLI flags onto main()'s keyword arguments.
if __name__ == "__main__":
    fire.Fire(main)
115+
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# -*- coding: utf-8 -*-
"""Core modules for zero-shot evaluation.

Re-exports the checkpoint manager, configuration loader, evaluator,
pipeline components, and schema models so callers can import them
directly from ``cookbooks.zero_shot_evaluation.core``.
"""

from cookbooks.zero_shot_evaluation.core.checkpoint import (
    CheckpointManager,
    EvaluationStage,
)
from cookbooks.zero_shot_evaluation.core.config import load_config
from cookbooks.zero_shot_evaluation.core.evaluator import EvaluationResult, ZeroShotEvaluator
from cookbooks.zero_shot_evaluation.core.query_generator import QueryGenerator
from cookbooks.zero_shot_evaluation.core.response_collector import ResponseCollector
from cookbooks.zero_shot_evaluation.core.rubric_generator import RubricGenerator
from cookbooks.zero_shot_evaluation.core.schema import (
    EvaluationConfig,
    GeneratedQuery,
    OpenAIEndpoint,
    QueryGenerationConfig,
    TaskConfig,
    ZeroShotConfig,
)

# Explicit public API; also controls `from ...core import *`.
__all__ = [
    # Checkpoint
    "CheckpointManager",
    "EvaluationStage",
    # Config
    "load_config",
    # Evaluator
    "ZeroShotEvaluator",
    "EvaluationResult",
    # Components
    "QueryGenerator",
    "ResponseCollector",
    "RubricGenerator",
    # Schema
    "EvaluationConfig",
    "GeneratedQuery",
    "OpenAIEndpoint",
    "QueryGenerationConfig",
    "TaskConfig",
    "ZeroShotConfig",
]
43+
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# -*- coding: utf-8 -*-
2+
"""Checkpoint management for evaluation pipeline."""
3+
4+
import json
5+
from datetime import datetime
6+
from enum import Enum
7+
from pathlib import Path
8+
from typing import Any, Dict, List, Optional
9+
10+
from loguru import logger
11+
from pydantic import BaseModel, Field
12+
13+
from cookbooks.zero_shot_evaluation.core.schema import GeneratedQuery
14+
15+
16+
class EvaluationStage(str, Enum):
    """Evaluation pipeline stages.

    Inherits from ``str`` so stage values serialize directly as plain
    strings in the JSON checkpoint file. Member order mirrors the
    pipeline: queries -> responses -> rubrics -> evaluation.
    """

    NOT_STARTED = "not_started"
    QUERIES_GENERATED = "queries_generated"
    RESPONSES_COLLECTED = "responses_collected"
    RUBRICS_GENERATED = "rubrics_generated"
    EVALUATION_COMPLETE = "evaluation_complete"
24+
25+
26+
class CheckpointData(BaseModel):
    """Checkpoint data model.

    Serialized to ``checkpoint.json`` by ``CheckpointManager``; field names
    are part of the on-disk format.
    """

    # Last completed pipeline stage.
    stage: EvaluationStage = Field(default=EvaluationStage.NOT_STARTED)
    # ISO-8601 timestamps. NOTE(review): datetime.now() is naive local time
    # (no timezone info) — confirm whether UTC-aware timestamps are wanted.
    created_at: str = Field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = Field(default_factory=lambda: datetime.now().isoformat())

    # Data files: paths of persisted per-stage artifacts (None until saved)
    queries_file: Optional[str] = None
    responses_file: Optional[str] = None
    rubrics_file: Optional[str] = None

    # Progress tracking counters
    total_queries: int = 0
    collected_responses: int = 0
    evaluated_pairs: int = 0
    total_pairs: int = 0
43+
44+
45+
class CheckpointManager:
    """Manage evaluation checkpoints for resume capability.

    Persists the pipeline state (``CheckpointData``) in ``checkpoint.json``
    and the per-stage artifacts (queries, responses, rubrics) as sibling
    JSON files inside ``output_dir``. The four save/load pairs share the
    ``_write_json`` / ``_read_json`` helpers instead of repeating the
    open/dump/load boilerplate.
    """

    CHECKPOINT_FILE = "checkpoint.json"
    QUERIES_FILE = "queries.json"
    RESPONSES_FILE = "responses.json"
    RUBRICS_FILE = "rubrics.json"

    def __init__(self, output_dir: str):
        """Initialize checkpoint manager.

        Args:
            output_dir: Directory to store checkpoint files (created if
                missing, including parents).
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # In-memory copy of the last loaded/saved checkpoint.
        self._checkpoint: Optional[CheckpointData] = None

    @property
    def checkpoint_path(self) -> Path:
        """Full path of the checkpoint metadata file."""
        return self.output_dir / self.CHECKPOINT_FILE

    # ------------------------------------------------------------------
    # Low-level JSON persistence helpers
    # ------------------------------------------------------------------

    def _write_json(self, file_name: str, payload: Any) -> Path:
        """Write *payload* as pretty-printed UTF-8 JSON under output_dir."""
        file_path = self.output_dir / file_name
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)
        return file_path

    def _read_json(self, file_name: str) -> Optional[Any]:
        """Read a JSON artifact; return None when the file does not exist."""
        file_path = self.output_dir / file_name
        if not file_path.exists():
            return None
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    # ------------------------------------------------------------------
    # Checkpoint metadata
    # ------------------------------------------------------------------

    def load(self) -> Optional[CheckpointData]:
        """Load existing checkpoint if available.

        Returns:
            The parsed checkpoint, or None when the file is missing or
            unreadable (a corrupt checkpoint is logged and treated as
            absent rather than aborting the run).
        """
        if not self.checkpoint_path.exists():
            logger.info("No checkpoint found, starting fresh")
            return None

        try:
            with open(self.checkpoint_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            self._checkpoint = CheckpointData(**data)
            logger.info(f"Loaded checkpoint: stage={self._checkpoint.stage.value}")
            return self._checkpoint
        except Exception as e:
            logger.warning(f"Failed to load checkpoint: {e}")
            return None

    def save(self, checkpoint: CheckpointData) -> None:
        """Save checkpoint to file, stamping ``updated_at`` first."""
        checkpoint.updated_at = datetime.now().isoformat()
        self._checkpoint = checkpoint
        self._write_json(self.CHECKPOINT_FILE, checkpoint.model_dump())
        logger.debug(f"Checkpoint saved: stage={checkpoint.stage.value}")

    # ------------------------------------------------------------------
    # Per-stage artifacts
    # ------------------------------------------------------------------

    def save_queries(self, queries: List[GeneratedQuery]) -> str:
        """Save generated queries; return the file path as a string."""
        file_path = self._write_json(
            self.QUERIES_FILE, [q.model_dump() for q in queries]
        )
        logger.info(f"Saved {len(queries)} queries to {file_path}")
        return str(file_path)

    def load_queries(self) -> List[GeneratedQuery]:
        """Load saved queries (empty list when none were saved)."""
        file_path = self.output_dir / self.QUERIES_FILE
        data = self._read_json(self.QUERIES_FILE)
        if data is None:
            return []

        queries = [GeneratedQuery(**item) for item in data]
        logger.info(f"Loaded {len(queries)} queries from {file_path}")
        return queries

    def save_responses(self, responses: List[Dict[str, Any]]) -> str:
        """Save collected responses; return the file path as a string."""
        file_path = self._write_json(self.RESPONSES_FILE, responses)
        logger.info(f"Saved {len(responses)} responses to {file_path}")
        return str(file_path)

    def load_responses(self) -> List[Dict[str, Any]]:
        """Load saved responses (empty list when none were saved)."""
        file_path = self.output_dir / self.RESPONSES_FILE
        responses = self._read_json(self.RESPONSES_FILE)
        if responses is None:
            return []

        logger.info(f"Loaded {len(responses)} responses from {file_path}")
        return responses

    def save_rubrics(self, rubrics: List[str]) -> str:
        """Save generated rubrics; return the file path as a string."""
        file_path = self._write_json(self.RUBRICS_FILE, rubrics)
        logger.info(f"Saved {len(rubrics)} rubrics to {file_path}")
        return str(file_path)

    def load_rubrics(self) -> List[str]:
        """Load saved rubrics (empty list when none were saved)."""
        file_path = self.output_dir / self.RUBRICS_FILE
        rubrics = self._read_json(self.RUBRICS_FILE)
        if rubrics is None:
            return []

        logger.info(f"Loaded {len(rubrics)} rubrics from {file_path}")
        return rubrics

    # ------------------------------------------------------------------
    # Stage management
    # ------------------------------------------------------------------

    def update_stage(
        self,
        stage: EvaluationStage,
        **kwargs,
    ) -> None:
        """Update checkpoint stage (plus any known fields) and save.

        Unknown keyword arguments are silently ignored; only attributes
        already present on ``CheckpointData`` are assigned.
        """
        if self._checkpoint is None:
            self._checkpoint = CheckpointData()

        self._checkpoint.stage = stage
        for key, value in kwargs.items():
            if hasattr(self._checkpoint, key):
                setattr(self._checkpoint, key, value)

        self.save(self._checkpoint)

    def clear(self) -> None:
        """Delete all checkpoint files and reset in-memory state."""
        for file_name in [
            self.CHECKPOINT_FILE,
            self.QUERIES_FILE,
            self.RESPONSES_FILE,
            self.RUBRICS_FILE,
        ]:
            file_path = self.output_dir / file_name
            if file_path.exists():
                file_path.unlink()

        self._checkpoint = None
        logger.info("Checkpoint cleared")
193+

0 commit comments

Comments
 (0)