Skip to content

Commit f7e9d6e

Browse files
committed
Add supports for schedulefree optimizer and SPlus
1 parent 324f3ee commit f7e9d6e

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

util.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,9 @@ def step(self, x):
153153

154154
def train_epoch(self, dl_train):
155155
self.model.train()
156+
# ScheduleFree Optimizer or SPlus
157+
if any(keyword in self.optimizer.__class__.__name__ for keyword in ["ScheduleFree", "SPlus"]):
158+
self.optimizer.train()
156159
train_loss = 0
157160
total_size = 0
158161
for x, y in dl_train:
@@ -170,6 +173,9 @@ def train_epoch(self, dl_train):
170173

171174
def val_epoch(self, dl_val):
172175
self.model.eval()
176+
# ScheduleFree Optimizer or SPlus
177+
if any(keyword in self.optimizer.__class__.__name__ for keyword in ["ScheduleFree", "SPlus"]):
178+
self.optimizer.eval()
173179
val_loss = 0
174180
total_size = 0
175181
for x, y in dl_val:
@@ -309,6 +315,10 @@ def run(
309315
except optuna.TrialPruned:
310316
wandb.finish()
311317
raise
318+
except Exception as e:
319+
print(f"Runtime error during training: {e}")
320+
wandb.finish()
321+
raise optuna.TrialPruned()
312322
finally:
313323
# Call trial_finished only once after all seeds are done
314324
if (
@@ -342,6 +352,7 @@ def select_group(project):
342352
groups = [
343353
d for d in os.listdir(runs_path) if os.path.isdir(os.path.join(runs_path, d))
344354
]
355+
groups.sort()
345356
if not groups:
346357
raise ValueError(f"No run groups found in {runs_path}")
347358

@@ -354,6 +365,7 @@ def select_seed(project, group_name):
354365
seeds = [
355366
d for d in os.listdir(group_path) if os.path.isdir(os.path.join(group_path, d))
356367
]
368+
seeds.sort()
357369
if not seeds:
358370
raise ValueError(f"No seeds found in {group_path}")
359371

0 commit comments

Comments
 (0)