
Commit 1e43b47

add ckpt for outer scheduler
1 parent 4ea2735 commit 1e43b47

File tree

2 files changed: +10 -0

open_diloco/ckpt_utils.py

Lines changed: 9 additions & 0 deletions
@@ -40,6 +40,7 @@ def save_checkpoint(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler.LambdaLR,
+    outer_scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
     outer_optimizer: torch.optim.Optimizer | None = None,
     scaler: torch.cuda.amp.GradScaler | None = None,
     loss: float | None = None,
@@ -81,6 +82,8 @@ def save_checkpoint(

     # 2. Save global states
     global_state_dict = {"scheduler": scheduler.state_dict(), "loss": loss if loss is not None else 0}
+    if outer_scheduler is not None:
+        global_state_dict["outer_scheduler"] = outer_scheduler.state_dict()
     if outer_optimizer is not None:
         global_state_dict["outer_optimizer"] = outer_optimizer.state_dict()
     if scaler is not None:
@@ -95,6 +98,7 @@ def load_checkpoint(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
+    outer_scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
     outer_optimizer: torch.optim.Optimizer | None = None,
     scaler: torch.cuda.amp.GradScaler | None = None,
     data_loader: StatefulDataLoader | None = None,
@@ -139,8 +143,13 @@ def load_checkpoint(
     if scheduler is not None:
         scheduler.load_state_dict(global_state_dict["scheduler"])
         optimizer.param_groups[0]["lr"] = scheduler.get_last_lr()[0]
+
     if outer_optimizer is not None:
         outer_optimizer.load_state_dict(global_state_dict["outer_optimizer"])
+    if outer_scheduler is not None:
+        outer_scheduler.load_state_dict(global_state_dict["outer_scheduler"])
+        outer_optimizer.param_groups[0]["lr"] = outer_scheduler.get_last_lr()[0]
+
     if scaler is not None:
         scaler.load_state_dict(global_state_dict["scaler"])
     return global_state_dict["loss"]
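
Why the extra assignment into param_groups after loading? Restoring a scheduler's state does not rewrite the learning rate already sitting in a freshly constructed optimizer, so the diff re-syncs it from get_last_lr(). Below is a minimal, self-contained sketch of that round-trip using only stock PyTorch; the toy model, lr_lambda, and step count are illustrative and not taken from the repo.

import torch

# Toy stand-ins for the outer optimizer/scheduler pair used above.
model = torch.nn.Linear(4, 4)
outer_optimizer = torch.optim.SGD(model.parameters(), lr=0.7)
outer_scheduler = torch.optim.lr_scheduler.LambdaLR(
    outer_optimizer, lr_lambda=lambda step: 1.0 / (step + 1)
)

# Take a few outer steps, then capture state as save_checkpoint does.
for _ in range(3):
    outer_optimizer.step()
    outer_scheduler.step()
global_state_dict = {"outer_scheduler": outer_scheduler.state_dict()}

# On resume, a fresh optimizer starts back at lr=0.7; loading the scheduler
# state and then re-syncing param_groups mirrors the load_checkpoint logic.
resumed_optimizer = torch.optim.SGD(model.parameters(), lr=0.7)
resumed_scheduler = torch.optim.lr_scheduler.LambdaLR(
    resumed_optimizer, lr_lambda=lambda step: 1.0 / (step + 1)
)
resumed_scheduler.load_state_dict(global_state_dict["outer_scheduler"])
resumed_optimizer.param_groups[0]["lr"] = resumed_scheduler.get_last_lr()[0]

assert resumed_optimizer.param_groups[0]["lr"] == outer_scheduler.get_last_lr()[0]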

open_diloco/train_fsdp.py

Lines changed: 1 addition & 0 deletions
@@ -539,6 +539,7 @@ def outer_scheduler_fn(opt):
         model=model,
         optimizer=optimizer.inner_optimizer,
         scheduler=scheduler,
+        outer_scheduler=optimizer.outer_scheduler,
         outer_optimizer=optimizer.state_averager.optimizer,
         loss=loss_batch.item(),
         scaler=scaler,
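
This hunk wires the outer scheduler into the save call in the training loop; the matching load side is not part of this commit. A hypothetical resume invocation might look like the sketch below. Only the keyword arguments are confirmed by the load_checkpoint signature in the diff; the positional resume_path argument and the ckpt_utils alias are assumptions.

# Hypothetical resume call (not in this commit): pass the same outer pair so
# load_checkpoint can restore both schedules. `resume_path` is a placeholder
# for however the script locates the checkpoint on disk.
last_loss = ckpt_utils.load_checkpoint(
    resume_path,
    model=model,
    optimizer=optimizer.inner_optimizer,
    scheduler=scheduler,
    outer_scheduler=optimizer.outer_scheduler,
    outer_optimizer=optimizer.state_averager.optimizer,
    scaler=scaler,
)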
