1010from loguru import logger
1111from pydantic import BaseModel , Field
1212
13- from cookbooks .zero_shot_evaluation .core . schema import GeneratedQuery
13+ from cookbooks .zero_shot_evaluation .schema import GeneratedQuery
1414
1515
1616class EvaluationStage (str , Enum ):
1717 """Evaluation pipeline stages."""
18-
18+
1919 NOT_STARTED = "not_started"
2020 QUERIES_GENERATED = "queries_generated"
2121 RESPONSES_COLLECTED = "responses_collected"
@@ -25,16 +25,16 @@ class EvaluationStage(str, Enum):
2525
2626class CheckpointData (BaseModel ):
2727 """Checkpoint data model."""
28-
28+
2929 stage : EvaluationStage = Field (default = EvaluationStage .NOT_STARTED )
3030 created_at : str = Field (default_factory = lambda : datetime .now ().isoformat ())
3131 updated_at : str = Field (default_factory = lambda : datetime .now ().isoformat ())
32-
32+
3333 # Data files
3434 queries_file : Optional [str ] = None
3535 responses_file : Optional [str ] = None
3636 rubrics_file : Optional [str ] = None
37-
37+
3838 # Progress tracking
3939 total_queries : int = 0
4040 collected_responses : int = 0
@@ -44,32 +44,32 @@ class CheckpointData(BaseModel):
4444
4545class CheckpointManager :
4646 """Manage evaluation checkpoints for resume capability."""
47-
47+
4848 CHECKPOINT_FILE = "checkpoint.json"
4949 QUERIES_FILE = "queries.json"
5050 RESPONSES_FILE = "responses.json"
5151 RUBRICS_FILE = "rubrics.json"
52-
52+
def __init__(self, output_dir: str):
    """Set up the manager and make sure its storage directory exists.

    Args:
        output_dir: Directory to store checkpoint files. It is created,
            together with any missing parent directories, if absent.
    """
    root = Path(output_dir)
    self.output_dir = root
    root.mkdir(parents=True, exist_ok=True)
    # Nothing is cached in memory yet.
    self._checkpoint: Optional[CheckpointData] = None
62-
62+
@property
def checkpoint_path(self) -> Path:
    """Location of the checkpoint file inside ``output_dir``."""
    return self.output_dir.joinpath(self.CHECKPOINT_FILE)
66-
66+
6767 def load (self ) -> Optional [CheckpointData ]:
6868 """Load existing checkpoint if available."""
6969 if not self .checkpoint_path .exists ():
7070 logger .info ("No checkpoint found, starting fresh" )
7171 return None
72-
72+
7373 try :
7474 with open (self .checkpoint_path , "r" , encoding = "utf-8" ) as f :
7575 data = json .load (f )
@@ -79,87 +79,87 @@ def load(self) -> Optional[CheckpointData]:
7979 except Exception as e :
8080 logger .warning (f"Failed to load checkpoint: { e } " )
8181 return None
82-
82+
def save(self, checkpoint: CheckpointData) -> None:
    """Persist ``checkpoint`` to the checkpoint file.

    The ``updated_at`` field is refreshed and the checkpoint becomes the
    manager's in-memory copy before it is written out. The JSON is first
    written to a temporary sibling file and then renamed over the real
    checkpoint, so an interrupted or failed write cannot leave a truncated
    checkpoint file behind.

    Args:
        checkpoint: Checkpoint state to serialize. Its ``updated_at``
            field is modified in place.

    Raises:
        OSError: If the temporary file cannot be written or renamed.
    """
    checkpoint.updated_at = datetime.now().isoformat()
    self._checkpoint = checkpoint

    target = self.checkpoint_path
    tmp_path = target.with_name(target.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(checkpoint.model_dump(), f, indent=2, ensure_ascii=False)
        # Path.replace overwrites an existing target in a single rename.
        tmp_path.replace(target)
    finally:
        # The temporary file only still exists if the write or rename failed.
        if tmp_path.exists():
            tmp_path.unlink()

    logger.debug(f"Checkpoint saved: stage={checkpoint.stage.value}")
92-
92+
def save_queries(self, queries: List[GeneratedQuery]) -> str:
    """Write the generated queries to the queries file as a JSON array.

    Args:
        queries: Queries to persist; each one is serialized through its
            ``model_dump()`` method.

    Returns:
        The path of the written file, as a string.
    """
    file_path = self.output_dir.joinpath(self.QUERIES_FILE)

    with file_path.open("w", encoding="utf-8") as handle:
        payload = [query.model_dump() for query in queries]
        json.dump(payload, handle, indent=2, ensure_ascii=False)

    logger.info(f"Saved {len(queries)} queries to {file_path}")
    return str(file_path)
102-
102+
def load_queries(self) -> List[GeneratedQuery]:
    """Read previously saved queries back from the queries file.

    Returns:
        One ``GeneratedQuery`` per stored entry, or an empty list when the
        queries file does not exist.
    """
    file_path = self.output_dir.joinpath(self.QUERIES_FILE)
    if file_path.exists():
        with file_path.open("r", encoding="utf-8") as handle:
            raw_items = json.load(handle)

        # Each stored entry is expanded into keyword arguments.
        queries = [GeneratedQuery(**fields) for fields in raw_items]
        logger.info(f"Loaded {len(queries)} queries from {file_path}")
        return queries

    return []
116-
116+
def save_responses(self, responses: List[Dict[str, Any]]) -> str:
    """Write the collected responses to the responses file as JSON.

    Args:
        responses: Response records to persist.

    Returns:
        The path of the written file, as a string.
    """
    file_path = self.output_dir.joinpath(self.RESPONSES_FILE)

    with file_path.open("w", encoding="utf-8") as handle:
        json.dump(responses, handle, indent=2, ensure_ascii=False)

    logger.info(f"Saved {len(responses)} responses to {file_path}")
    return str(file_path)
126-
126+
def load_responses(self) -> List[Dict[str, Any]]:
    """Read previously saved responses back from the responses file.

    Returns:
        The decoded JSON content of the responses file, or an empty list
        when the file does not exist.
    """
    file_path = self.output_dir.joinpath(self.RESPONSES_FILE)
    if file_path.exists():
        with file_path.open("r", encoding="utf-8") as handle:
            responses = json.load(handle)

        logger.info(f"Loaded {len(responses)} responses from {file_path}")
        return responses

    return []
139-
139+
def save_rubrics(self, rubrics: List[str]) -> str:
    """Write the generated rubrics to the rubrics file as a JSON array.

    Args:
        rubrics: Rubric strings to persist.

    Returns:
        The path of the written file, as a string.
    """
    file_path = self.output_dir.joinpath(self.RUBRICS_FILE)

    with file_path.open("w", encoding="utf-8") as handle:
        json.dump(rubrics, handle, indent=2, ensure_ascii=False)

    logger.info(f"Saved {len(rubrics)} rubrics to {file_path}")
    return str(file_path)
149-
149+
def load_rubrics(self) -> List[str]:
    """Read previously saved rubrics back from the rubrics file.

    Returns:
        The decoded JSON content of the rubrics file, or an empty list
        when the file does not exist.
    """
    file_path = self.output_dir.joinpath(self.RUBRICS_FILE)
    if file_path.exists():
        with file_path.open("r", encoding="utf-8") as handle:
            rubrics = json.load(handle)

        logger.info(f"Loaded {len(rubrics)} rubrics from {file_path}")
        return rubrics

    return []
162-
162+
163163 def update_stage (
164164 self ,
165165 stage : EvaluationStage ,
@@ -168,14 +168,14 @@ def update_stage(
168168 """Update checkpoint stage and save."""
169169 if self ._checkpoint is None :
170170 self ._checkpoint = CheckpointData ()
171-
171+
172172 self ._checkpoint .stage = stage
173173 for key , value in kwargs .items ():
174174 if hasattr (self ._checkpoint , key ):
175175 setattr (self ._checkpoint , key , value )
176-
176+
177177 self .save (self ._checkpoint )
178-
178+
179179 def clear (self ) -> None :
180180 """Clear all checkpoint data."""
181181 for file_name in [
@@ -187,7 +187,6 @@ def clear(self) -> None:
187187 file_path = self .output_dir / file_name
188188 if file_path .exists ():
189189 file_path .unlink ()
190-
190+
191191 self ._checkpoint = None
192192 logger .info ("Checkpoint cleared" )
193-
0 commit comments