Skip to content

Commit fe30db7

Browse files
committed
refactor(zero_shot_evaluation): restructure module layout and add new components
- Move core modules to top-level directory (checkpoint, schema, query_generator, response_collector, zero_shot_pipeline)
- Add pairwise_analyzer to openjudge/analyzer
- Add rubric_generator to openjudge/generator
- Update imports and exports in __init__.py files
- Simplify core/__init__.py to re-export from parent module
1 parent 31ca6ad commit fe30db7

File tree

17 files changed

+1784
-434
lines changed

17 files changed

+1784
-434
lines changed

cookbooks/zero_shot_evaluation/__init__.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,47 @@
11
# -*- coding: utf-8 -*-
2-
"""Zero-Shot Evaluation module for comparing models and agent pipelines.
2+
"""Core modules for zero-shot evaluation.
33
4-
Usage:
5-
# CLI
6-
python -m cookbooks.zero_shot_evaluation --config config.yaml
4+
This package contains the core components for the zero-shot evaluation pipeline:
5+
- ZeroShotPipeline: End-to-end evaluation pipeline
6+
- QueryGenerator: Test query generation
7+
- ResponseCollector: Response collection from endpoints
78
8-
# Python
9-
from cookbooks.zero_shot_evaluation import ZeroShotEvaluator
10-
evaluator = ZeroShotEvaluator.from_config("config.yaml")
11-
result = await evaluator.evaluate()
9+
Note: RubricGenerator has been moved to openjudge.generator module for better reusability.
10+
Note: Checkpoint management is integrated into ZeroShotPipeline.
1211
"""
1312

14-
from cookbooks.zero_shot_evaluation.core.config import load_config
15-
from cookbooks.zero_shot_evaluation.core.evaluator import EvaluationResult, ZeroShotEvaluator
16-
from cookbooks.zero_shot_evaluation.core.query_generator import QueryGenerator
17-
from cookbooks.zero_shot_evaluation.core.response_collector import ResponseCollector
18-
from cookbooks.zero_shot_evaluation.core.rubric_generator import RubricGenerator
19-
from cookbooks.zero_shot_evaluation.core.schema import (
13+
from cookbooks.zero_shot_evaluation.query_generator import QueryGenerator
14+
from cookbooks.zero_shot_evaluation.response_collector import ResponseCollector
15+
from cookbooks.zero_shot_evaluation.schema import (
2016
EvaluationConfig,
17+
GeneratedQuery,
2118
OpenAIEndpoint,
2219
QueryGenerationConfig,
2320
TaskConfig,
2421
ZeroShotConfig,
22+
load_config,
23+
)
24+
from cookbooks.zero_shot_evaluation.zero_shot_pipeline import (
25+
EvaluationResult,
26+
EvaluationStage,
27+
ZeroShotPipeline,
2528
)
2629

2730
__all__ = [
2831
# Config
29-
"ZeroShotConfig",
30-
"TaskConfig",
31-
"OpenAIEndpoint",
32-
"QueryGenerationConfig",
33-
"EvaluationConfig",
3432
"load_config",
33+
# Pipeline
34+
"ZeroShotPipeline",
35+
"EvaluationResult",
36+
"EvaluationStage",
3537
# Components
3638
"QueryGenerator",
3739
"ResponseCollector",
38-
"RubricGenerator",
39-
# Evaluator
40-
"ZeroShotEvaluator",
41-
"EvaluationResult",
40+
# Schema
41+
"EvaluationConfig",
42+
"GeneratedQuery",
43+
"OpenAIEndpoint",
44+
"QueryGenerationConfig",
45+
"TaskConfig",
46+
"ZeroShotConfig",
4247
]
43-

cookbooks/zero_shot_evaluation/__main__.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
import fire
1616
from loguru import logger
1717

18-
from cookbooks.zero_shot_evaluation.core.config import load_config
19-
from cookbooks.zero_shot_evaluation.core.evaluator import ZeroShotEvaluator
20-
from cookbooks.zero_shot_evaluation.core.schema import GeneratedQuery
18+
from cookbooks.zero_shot_evaluation.schema import GeneratedQuery, load_config
19+
from cookbooks.zero_shot_evaluation.zero_shot_pipeline import ZeroShotPipeline
2120

2221

2322
def _load_queries_from_file(queries_file: str) -> List[GeneratedQuery]:
@@ -55,11 +54,11 @@ async def _run_evaluation(
5554
if queries_file:
5655
queries = _load_queries_from_file(queries_file)
5756

58-
evaluator = ZeroShotEvaluator(config=config, resume=resume)
59-
result = await evaluator.evaluate(queries=queries)
57+
pipeline = ZeroShotPipeline(config=config, resume=resume)
58+
result = await pipeline.evaluate(queries=queries)
6059

6160
if save:
62-
evaluator.save_results(result, output_dir)
61+
pipeline.save_results(result, output_dir)
6362

6463

6564
def main(
@@ -81,10 +80,10 @@ def main(
8180
Examples:
8281
# Normal run (auto-resumes from checkpoint)
8382
python -m cookbooks.zero_shot_evaluation --config config.yaml --save
84-
83+
8584
# Use pre-generated queries
8685
python -m cookbooks.zero_shot_evaluation --config config.yaml --queries_file queries.json --save
87-
86+
8887
# Start fresh, ignore checkpoint
8988
python -m cookbooks.zero_shot_evaluation --config config.yaml --fresh --save
9089
"""
@@ -106,10 +105,9 @@ def main(
106105
logger.info("Starting fresh (ignoring checkpoint)")
107106
else:
108107
logger.info("Resume mode enabled (will continue from checkpoint if exists)")
109-
108+
110109
asyncio.run(_run_evaluation(str(config_path), output_dir, queries_file, save, resume=not fresh))
111110

112111

113112
if __name__ == "__main__":
114113
fire.Fire(main)
115-

cookbooks/zero_shot_evaluation/core/checkpoint.py renamed to cookbooks/zero_shot_evaluation/checkpoint.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
from loguru import logger
1111
from pydantic import BaseModel, Field
1212

13-
from cookbooks.zero_shot_evaluation.core.schema import GeneratedQuery
13+
from cookbooks.zero_shot_evaluation.schema import GeneratedQuery
1414

1515

1616
class EvaluationStage(str, Enum):
1717
"""Evaluation pipeline stages."""
18-
18+
1919
NOT_STARTED = "not_started"
2020
QUERIES_GENERATED = "queries_generated"
2121
RESPONSES_COLLECTED = "responses_collected"
@@ -25,16 +25,16 @@ class EvaluationStage(str, Enum):
2525

2626
class CheckpointData(BaseModel):
2727
"""Checkpoint data model."""
28-
28+
2929
stage: EvaluationStage = Field(default=EvaluationStage.NOT_STARTED)
3030
created_at: str = Field(default_factory=lambda: datetime.now().isoformat())
3131
updated_at: str = Field(default_factory=lambda: datetime.now().isoformat())
32-
32+
3333
# Data files
3434
queries_file: Optional[str] = None
3535
responses_file: Optional[str] = None
3636
rubrics_file: Optional[str] = None
37-
37+
3838
# Progress tracking
3939
total_queries: int = 0
4040
collected_responses: int = 0
@@ -44,32 +44,32 @@ class CheckpointData(BaseModel):
4444

4545
class CheckpointManager:
4646
"""Manage evaluation checkpoints for resume capability."""
47-
47+
4848
CHECKPOINT_FILE = "checkpoint.json"
4949
QUERIES_FILE = "queries.json"
5050
RESPONSES_FILE = "responses.json"
5151
RUBRICS_FILE = "rubrics.json"
52-
52+
5353
def __init__(self, output_dir: str):
5454
"""Initialize checkpoint manager.
55-
55+
5656
Args:
5757
output_dir: Directory to store checkpoint files
5858
"""
5959
self.output_dir = Path(output_dir)
6060
self.output_dir.mkdir(parents=True, exist_ok=True)
6161
self._checkpoint: Optional[CheckpointData] = None
62-
62+
6363
@property
6464
def checkpoint_path(self) -> Path:
6565
return self.output_dir / self.CHECKPOINT_FILE
66-
66+
6767
def load(self) -> Optional[CheckpointData]:
6868
"""Load existing checkpoint if available."""
6969
if not self.checkpoint_path.exists():
7070
logger.info("No checkpoint found, starting fresh")
7171
return None
72-
72+
7373
try:
7474
with open(self.checkpoint_path, "r", encoding="utf-8") as f:
7575
data = json.load(f)
@@ -79,87 +79,87 @@ def load(self) -> Optional[CheckpointData]:
7979
except Exception as e:
8080
logger.warning(f"Failed to load checkpoint: {e}")
8181
return None
82-
82+
8383
def save(self, checkpoint: CheckpointData) -> None:
8484
"""Save checkpoint to file."""
8585
checkpoint.updated_at = datetime.now().isoformat()
8686
self._checkpoint = checkpoint
87-
87+
8888
with open(self.checkpoint_path, "w", encoding="utf-8") as f:
8989
json.dump(checkpoint.model_dump(), f, indent=2, ensure_ascii=False)
90-
90+
9191
logger.debug(f"Checkpoint saved: stage={checkpoint.stage.value}")
92-
92+
9393
def save_queries(self, queries: List[GeneratedQuery]) -> str:
9494
"""Save generated queries."""
9595
file_path = self.output_dir / self.QUERIES_FILE
96-
96+
9797
with open(file_path, "w", encoding="utf-8") as f:
9898
json.dump([q.model_dump() for q in queries], f, indent=2, ensure_ascii=False)
99-
99+
100100
logger.info(f"Saved {len(queries)} queries to {file_path}")
101101
return str(file_path)
102-
102+
103103
def load_queries(self) -> List[GeneratedQuery]:
104104
"""Load saved queries."""
105105
file_path = self.output_dir / self.QUERIES_FILE
106-
106+
107107
if not file_path.exists():
108108
return []
109-
109+
110110
with open(file_path, "r", encoding="utf-8") as f:
111111
data = json.load(f)
112-
112+
113113
queries = [GeneratedQuery(**item) for item in data]
114114
logger.info(f"Loaded {len(queries)} queries from {file_path}")
115115
return queries
116-
116+
117117
def save_responses(self, responses: List[Dict[str, Any]]) -> str:
118118
"""Save collected responses."""
119119
file_path = self.output_dir / self.RESPONSES_FILE
120-
120+
121121
with open(file_path, "w", encoding="utf-8") as f:
122122
json.dump(responses, f, indent=2, ensure_ascii=False)
123-
123+
124124
logger.info(f"Saved {len(responses)} responses to {file_path}")
125125
return str(file_path)
126-
126+
127127
def load_responses(self) -> List[Dict[str, Any]]:
128128
"""Load saved responses."""
129129
file_path = self.output_dir / self.RESPONSES_FILE
130-
130+
131131
if not file_path.exists():
132132
return []
133-
133+
134134
with open(file_path, "r", encoding="utf-8") as f:
135135
responses = json.load(f)
136-
136+
137137
logger.info(f"Loaded {len(responses)} responses from {file_path}")
138138
return responses
139-
139+
140140
def save_rubrics(self, rubrics: List[str]) -> str:
141141
"""Save generated rubrics."""
142142
file_path = self.output_dir / self.RUBRICS_FILE
143-
143+
144144
with open(file_path, "w", encoding="utf-8") as f:
145145
json.dump(rubrics, f, indent=2, ensure_ascii=False)
146-
146+
147147
logger.info(f"Saved {len(rubrics)} rubrics to {file_path}")
148148
return str(file_path)
149-
149+
150150
def load_rubrics(self) -> List[str]:
151151
"""Load saved rubrics."""
152152
file_path = self.output_dir / self.RUBRICS_FILE
153-
153+
154154
if not file_path.exists():
155155
return []
156-
156+
157157
with open(file_path, "r", encoding="utf-8") as f:
158158
rubrics = json.load(f)
159-
159+
160160
logger.info(f"Loaded {len(rubrics)} rubrics from {file_path}")
161161
return rubrics
162-
162+
163163
def update_stage(
164164
self,
165165
stage: EvaluationStage,
@@ -168,14 +168,14 @@ def update_stage(
168168
"""Update checkpoint stage and save."""
169169
if self._checkpoint is None:
170170
self._checkpoint = CheckpointData()
171-
171+
172172
self._checkpoint.stage = stage
173173
for key, value in kwargs.items():
174174
if hasattr(self._checkpoint, key):
175175
setattr(self._checkpoint, key, value)
176-
176+
177177
self.save(self._checkpoint)
178-
178+
179179
def clear(self) -> None:
180180
"""Clear all checkpoint data."""
181181
for file_name in [
@@ -187,7 +187,6 @@ def clear(self) -> None:
187187
file_path = self.output_dir / file_name
188188
if file_path.exists():
189189
file_path.unlink()
190-
190+
191191
self._checkpoint = None
192192
logger.info("Checkpoint cleared")
193-

0 commit comments

Comments (0)