Skip to content

Commit cb230fe

Browse files
authored
Merge pull request #28 from modelscope/feature/zero-shot-evaluation
Feature/zero shot evaluation
2 parents 791ebdc + cdd37c3 commit cb230fe

File tree

15 files changed

+3544
-2
lines changed

15 files changed

+3544
-2
lines changed
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# -*- coding: utf-8 -*-
2+
"""CLI entry point for zero-shot evaluation.
3+
4+
Usage:
5+
python -m cookbooks.zero_shot_evaluation --config config.yaml
6+
python -m cookbooks.zero_shot_evaluation --config config.yaml --save
7+
python -m cookbooks.zero_shot_evaluation --config config.yaml --queries_file queries.json --save
8+
"""
9+
10+
import asyncio
11+
import json
12+
from pathlib import Path
13+
from typing import List, Optional
14+
15+
import fire
16+
from loguru import logger
17+
18+
from cookbooks.zero_shot_evaluation.schema import GeneratedQuery, load_config
19+
from cookbooks.zero_shot_evaluation.zero_shot_pipeline import ZeroShotPipeline
20+
21+
22+
def _load_queries_from_file(queries_file: str) -> List[GeneratedQuery]:
    """Read pre-generated queries from a JSON file.

    Args:
        queries_file: Path to a JSON file holding a list of query dicts.

    Returns:
        The parsed queries as ``GeneratedQuery`` objects.
    """
    with open(queries_file, "r", encoding="utf-8") as fp:
        raw_items = json.load(fp)

    loaded = [GeneratedQuery(**entry) for entry in raw_items]
    logger.info(f"Loaded {len(loaded)} queries from {queries_file}")
    return loaded
29+
30+
31+
async def _run_evaluation(
    config_path: str,
    output_dir: Optional[str] = None,
    queries_file: Optional[str] = None,
    save: bool = False,
    resume: bool = True,
) -> None:
    """Drive one full evaluation run.

    Args:
        config_path: Path to configuration file
        output_dir: Output directory (overrides the value from config)
        queries_file: Optional JSON file of pre-generated queries; when
            given, the query-generation step is skipped
        save: Whether to persist results to file
        resume: Whether to resume from an existing checkpoint
    """
    config = load_config(config_path)

    # CLI-supplied directory takes precedence over the configured one.
    if output_dir:
        config.output.output_dir = output_dir

    # Pre-generated queries (if any) short-circuit query generation.
    queries = _load_queries_from_file(queries_file) if queries_file else None

    pipeline = ZeroShotPipeline(config=config, resume=resume)
    result = await pipeline.evaluate(queries=queries)

    if save:
        pipeline.save_results(result, output_dir)
62+
63+
64+
def main(
    config: str,
    output_dir: Optional[str] = None,
    queries_file: Optional[str] = None,
    save: bool = False,
    fresh: bool = False,
) -> None:
    """Zero-shot evaluation CLI with checkpoint support.

    Args:
        config: Path to YAML configuration file
        output_dir: Output directory for results
        queries_file: Path to pre-generated queries JSON (skip query generation)
        save: Whether to save results to file
        fresh: Start fresh, ignore any existing checkpoint

    Examples:
        # Normal run (auto-resumes from checkpoint)
        python -m cookbooks.zero_shot_evaluation --config config.yaml --save

        # Use pre-generated queries
        python -m cookbooks.zero_shot_evaluation --config config.yaml --queries_file queries.json --save

        # Start fresh, ignore checkpoint
        python -m cookbooks.zero_shot_evaluation --config config.yaml --fresh --save
    """
    # Validate inputs up front with guard clauses; bail out with a log
    # message rather than raising, since this is the CLI boundary.
    cfg_path = Path(config)
    if not cfg_path.exists():
        logger.error(f"Config file not found: {config}")
        return

    if queries_file and not Path(queries_file).exists():
        logger.error(f"Queries file not found: {queries_file}")
        return

    logger.info(f"Starting zero-shot evaluation with config: {config}")
    if queries_file:
        logger.info(f"Using pre-generated queries from: {queries_file}")

    if fresh:
        logger.info("Starting fresh (ignoring checkpoint)")
    else:
        logger.info("Resume mode enabled (will continue from checkpoint if exists)")

    asyncio.run(
        _run_evaluation(str(cfg_path), output_dir, queries_file, save, resume=not fresh)
    )
110+
111+
112+
# Expose the CLI through python-fire when executed as a script.
if __name__ == "__main__":
    fire.Fire(main)
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
# -*- coding: utf-8 -*-
2+
"""Checkpoint management for evaluation pipeline."""
3+
4+
import json
5+
from datetime import datetime
6+
from enum import Enum
7+
from pathlib import Path
8+
from typing import Any, Dict, List, Optional
9+
10+
from loguru import logger
11+
from pydantic import BaseModel, Field
12+
13+
from cookbooks.zero_shot_evaluation.schema import GeneratedQuery
14+
15+
16+
class EvaluationStage(str, Enum):
    """Evaluation pipeline stages.

    str-valued so stage names serialize directly into the JSON checkpoint.
    Members are listed in pipeline order, from not-yet-started through
    completed evaluation.
    """

    NOT_STARTED = "not_started"
    QUERIES_GENERATED = "queries_generated"
    RESPONSES_COLLECTED = "responses_collected"
    RUBRICS_GENERATED = "rubrics_generated"
    EVALUATION_COMPLETE = "evaluation_complete"
24+
25+
26+
class CheckpointData(BaseModel):
    """Checkpoint data model.

    Serialized verbatim to checkpoint.json by CheckpointManager, so all
    fields must remain JSON-compatible.
    """

    # Last pipeline stage that completed successfully.
    stage: EvaluationStage = Field(default=EvaluationStage.NOT_STARTED)
    # ISO-8601 timestamps; datetime.now() is naive local time.
    created_at: str = Field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = Field(default_factory=lambda: datetime.now().isoformat())

    # Data files: paths of saved intermediate artifacts, None until written.
    queries_file: Optional[str] = None
    responses_file: Optional[str] = None
    rubrics_file: Optional[str] = None

    # Progress tracking counters.
    total_queries: int = 0
    collected_responses: int = 0
    evaluated_pairs: int = 0
    total_pairs: int = 0
43+
44+
45+
class CheckpointManager:
    """Manage evaluation checkpoints for resume capability.

    All state lives as JSON files under ``output_dir``: checkpoint.json
    (stage + progress metadata) plus the intermediate queries / responses /
    rubrics artifacts. An interrupted run can be resumed by re-loading
    these files.
    """

    CHECKPOINT_FILE = "checkpoint.json"
    QUERIES_FILE = "queries.json"
    RESPONSES_FILE = "responses.json"
    RUBRICS_FILE = "rubrics.json"

    def __init__(self, output_dir: str):
        """Initialize checkpoint manager.

        Args:
            output_dir: Directory to store checkpoint files (created if missing)
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # In-memory copy of the most recently loaded/saved checkpoint.
        self._checkpoint: Optional[CheckpointData] = None

    @property
    def checkpoint_path(self) -> Path:
        """Path of the checkpoint metadata file."""
        return self.output_dir / self.CHECKPOINT_FILE

    # ------------------------------------------------------------------
    # Internal JSON helpers (single place for all file I/O)
    # ------------------------------------------------------------------

    def _file_path(self, file_name: str) -> Path:
        """Path of a checkpoint artifact inside the output directory."""
        return self.output_dir / file_name

    def _write_json(self, file_path: Path, payload: Any) -> None:
        """Atomically write *payload* as pretty-printed UTF-8 JSON.

        Writes to a temporary sibling first, then renames over the target,
        so a crash mid-write never leaves a truncated/corrupt file behind.
        """
        tmp_path = file_path.with_suffix(file_path.suffix + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)
        # Same-directory rename; atomic on POSIX filesystems.
        tmp_path.replace(file_path)

    def _read_json(self, file_path: Path) -> Optional[Any]:
        """Read JSON from *file_path*; return None when the file is absent."""
        if not file_path.exists():
            return None
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    # ------------------------------------------------------------------
    # Checkpoint metadata
    # ------------------------------------------------------------------

    def load(self) -> Optional[CheckpointData]:
        """Load existing checkpoint if available.

        Returns:
            The parsed checkpoint, or None when no (valid) checkpoint exists.
        """
        if not self.checkpoint_path.exists():
            logger.info("No checkpoint found, starting fresh")
            return None

        try:
            data = self._read_json(self.checkpoint_path)
            self._checkpoint = CheckpointData(**data)
            logger.info(f"Loaded checkpoint: stage={self._checkpoint.stage.value}")
            return self._checkpoint
        except Exception as e:
            # A corrupt or schema-incompatible checkpoint must not abort the
            # run -- treat it as "no checkpoint" and start fresh.
            logger.warning(f"Failed to load checkpoint: {e}")
            return None

    def save(self, checkpoint: CheckpointData) -> None:
        """Save checkpoint to file, stamping its ``updated_at`` timestamp."""
        checkpoint.updated_at = datetime.now().isoformat()
        self._checkpoint = checkpoint
        self._write_json(self.checkpoint_path, checkpoint.model_dump())
        logger.debug(f"Checkpoint saved: stage={checkpoint.stage.value}")

    # ------------------------------------------------------------------
    # Intermediate artifacts
    # ------------------------------------------------------------------

    def save_queries(self, queries: List[GeneratedQuery]) -> str:
        """Save generated queries; returns the file path as a string."""
        file_path = self._file_path(self.QUERIES_FILE)
        self._write_json(file_path, [q.model_dump() for q in queries])
        logger.info(f"Saved {len(queries)} queries to {file_path}")
        return str(file_path)

    def load_queries(self) -> List[GeneratedQuery]:
        """Load saved queries; returns [] when none were saved."""
        file_path = self._file_path(self.QUERIES_FILE)
        data = self._read_json(file_path)
        if data is None:
            return []

        queries = [GeneratedQuery(**item) for item in data]
        logger.info(f"Loaded {len(queries)} queries from {file_path}")
        return queries

    def save_responses(self, responses: List[Dict[str, Any]]) -> str:
        """Save collected responses; returns the file path as a string."""
        file_path = self._file_path(self.RESPONSES_FILE)
        self._write_json(file_path, responses)
        logger.info(f"Saved {len(responses)} responses to {file_path}")
        return str(file_path)

    def load_responses(self) -> List[Dict[str, Any]]:
        """Load saved responses; returns [] when none were saved."""
        file_path = self._file_path(self.RESPONSES_FILE)
        responses = self._read_json(file_path)
        if responses is None:
            return []

        logger.info(f"Loaded {len(responses)} responses from {file_path}")
        return responses

    def save_rubrics(self, rubrics: List[str]) -> str:
        """Save generated rubrics; returns the file path as a string."""
        file_path = self._file_path(self.RUBRICS_FILE)
        self._write_json(file_path, rubrics)
        logger.info(f"Saved {len(rubrics)} rubrics to {file_path}")
        return str(file_path)

    def load_rubrics(self) -> List[str]:
        """Load saved rubrics; returns [] when none were saved."""
        file_path = self._file_path(self.RUBRICS_FILE)
        rubrics = self._read_json(file_path)
        if rubrics is None:
            return []

        logger.info(f"Loaded {len(rubrics)} rubrics from {file_path}")
        return rubrics

    def update_stage(
        self,
        stage: EvaluationStage,
        **kwargs,
    ) -> None:
        """Update checkpoint stage (plus any CheckpointData fields) and save.

        Unknown keyword names are logged instead of silently dropped, so a
        typo in a field name is noticed rather than losing progress data.
        """
        if self._checkpoint is None:
            self._checkpoint = CheckpointData()

        self._checkpoint.stage = stage
        for key, value in kwargs.items():
            if hasattr(self._checkpoint, key):
                setattr(self._checkpoint, key, value)
            else:
                logger.warning(f"update_stage: ignoring unknown field '{key}'")

        self.save(self._checkpoint)

    def clear(self) -> None:
        """Delete every checkpoint artifact and reset in-memory state."""
        for file_name in [
            self.CHECKPOINT_FILE,
            self.QUERIES_FILE,
            self.RESPONSES_FILE,
            self.RUBRICS_FILE,
        ]:
            file_path = self._file_path(file_name)
            if file_path.exists():
                file_path.unlink()

        self._checkpoint = None
        logger.info("Checkpoint cleared")

0 commit comments

Comments
 (0)