@@ -153,6 +153,9 @@ def step(self, x):
153153
154154 def train_epoch (self , dl_train ):
155155 self .model .train ()
156+ # Optimizers whose class name contains "ScheduleFree" or "SPlus" are switched to train mode along with the model
157+ if any (keyword in self .optimizer .__class__ .__name__ for keyword in ["ScheduleFree" , "SPlus" ]):
158+ self .optimizer .train ()
156159 train_loss = 0
157160 total_size = 0
158161 for x , y in dl_train :
@@ -170,6 +173,9 @@ def train_epoch(self, dl_train):
170173
171174 def val_epoch (self , dl_val ):
172175 self .model .eval ()
176+ # Optimizers whose class name contains "ScheduleFree" or "SPlus" are switched to eval mode along with the model
177+ if any (keyword in self .optimizer .__class__ .__name__ for keyword in ["ScheduleFree" , "SPlus" ]):
178+ self .optimizer .eval ()
173179 val_loss = 0
174180 total_size = 0
175181 for x , y in dl_val :
@@ -309,6 +315,10 @@ def run(
309315 except optuna .TrialPruned :
310316 wandb .finish ()
311317 raise
318+ except Exception as e :
319+ print (f"Runtime error during training: { e } " )
320+ wandb .finish ()
321+ raise optuna .TrialPruned ()
312322 finally :
313323 # Call trial_finished only once after all seeds are done
314324 if (
@@ -342,6 +352,7 @@ def select_group(project):
342352 groups = [
343353 d for d in os .listdir (runs_path ) if os .path .isdir (os .path .join (runs_path , d ))
344354 ]
355+ groups .sort ()
345356 if not groups :
346357 raise ValueError (f"No run groups found in { runs_path } " )
347358
@@ -354,6 +365,7 @@ def select_seed(project, group_name):
354365 seeds = [
355366 d for d in os .listdir (group_path ) if os .path .isdir (os .path .join (group_path , d ))
356367 ]
368+ seeds .sort ()
357369 if not seeds :
358370 raise ValueError (f"No seeds found in { group_path } " )
359371
0 commit comments