Commit c542d4b: Checkpointing and retrieval (#13)

* implemented checkpointing and retrieval
* fixed scheduler and random state dict
* ensure the master process creates the ckpt folder
* minor fixes

1 parent: 81f9d6e

5 files changed: 77 additions & 8 deletions

src/config/base.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -45,4 +45,6 @@ def parse_args(base_parser, args, namespace):
     # Distributed args
     parser.add_argument('--distributed_backend', default=None, type=str, required=False,
                         choices=distributed.registered_backends()) # distributed backend type
+    parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False)
+
     return parser.parse_args(args, namespace)
```

src/config/sparse.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -47,4 +47,6 @@ def parse_args(base_parser, args, namespace):
     # Distributed args
     parser.add_argument('--distributed_backend', default=None, type=str, required=False,
                         choices=distributed.registered_backends()) # distributed backend type
+    parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False)
+
     return parser.parse_args(args, namespace)
```
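Both config files only register the flag; nothing is saved unless it is set. A minimal standalone sketch (plain argparse, not the repo's parse_args) of how the default behaves:

```python
# Minimal sketch: the flag defaults to None, so the periodic-save branch in the
# training loops (extra_args.save_checkpoint_freq is not None and ...) stays off
# unless a frequency is passed explicitly.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False)

print(parser.parse_args([]).save_checkpoint_freq)  # None -> no periodic checkpoints
print(parser.parse_args(['--save_checkpoint_freq', '1000']).save_checkpoint_freq)  # 1000 -> save every 1000 iterations
```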

src/main.py

Lines changed: 23 additions & 1 deletion
```diff
@@ -109,9 +109,31 @@ def main(args):
     if not os.path.exists(ckpt_path):
         if distributed_backend.is_master_process():
             os.makedirs(ckpt_path)
+        distributed_backend.sync()
     elif os.path.isfile(os.path.join(ckpt_path, "summary.json")): # the experiment was already completed
         print(f"Already found experiment '{ckpt_path}'.\nSkipping.")
         sys.exit(0)
+    itr = 0
+    rng_state_dict = None
+    checkpoints = [file for file in os.listdir(ckpt_path) if 'ckpt_' in file]
+    if checkpoints:
+        last_ckpt_path = sorted(checkpoints)[-1]
+        print(f"Training interrupted, resuming from {last_ckpt_path}")
+        checkpoint = torch.load(os.path.join(ckpt_path, last_ckpt_path))
+        model_state_dict = {k.replace("_orig_mod.", ""): v for k, v in checkpoint['model'].items()}
+        # FIXME checkpoints from compiled model have _orig_mod keyword
+
+        optimizer_state_dict = checkpoint['optimizer']
+        rng_state_dict = {
+            module: checkpoint[module] for module in ["cpu_rng_state", "gpu_rng_state", "numpy_rng_state", "py_rng_state"]
+        }
+
+        model.load_state_dict(model_state_dict)
+        opt.load_state_dict(optimizer_state_dict)
+        itr = checkpoint['itr']
+        if scheduler is not None:
+            scheduler_state_dict = checkpoint['scheduler']
+            scheduler.load_state_dict(scheduler_state_dict)
 
     if args.model == 'base': # all train functions have the same interface
         train = train_base
@@ -125,7 +147,7 @@ def main(args):
     stats = train(model, opt, data, args.data_seed, scheduler, args.iterations, args.acc_steps, args.batch_size, args.sequence_length,
                   eval_freq=args.eval_freq,
                   distributed_backend=distributed_backend,
-                  ckpt_path=f"{ckpt_path}/ckpt.pt", extra_args=args)
+                  ckpt_path=f"{ckpt_path}/ckpt.pt", itr=itr, rng_state_dict=rng_state_dict, extra_args=args)
 
     args.device = None
     args.dtype = None
```
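The retrieval logic above does three things when the run folder already contains ckpt_* files: pick the most recent one, strip the "_orig_mod." prefix that torch.compile adds to parameter names, and hand the restored iteration counter and RNG states to the train function. A standalone sketch of the same flow (hypothetical helper names, not functions from this repo; it also sorts by the numeric suffix rather than lexicographically as the diff does):

```python
# Hypothetical helpers mirroring the resume path in src/main.py (not part of the repo).
import os
import re
import torch

def find_latest_checkpoint(ckpt_dir):
    """Return the newest periodic checkpoint file name, or None for a fresh run."""
    ckpts = [f for f in os.listdir(ckpt_dir) if f.startswith("ckpt_")]
    if not ckpts:
        return None
    # Pick by iteration number, so ckpt_10000.pt beats ckpt_9000.pt.
    return max(ckpts, key=lambda f: int(re.search(r"\d+", f).group()))

def strip_compile_prefix(state_dict):
    """torch.compile wraps the module, so saved keys look like '_orig_mod.layer.weight'."""
    return {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

def load_for_resume(ckpt_dir, model, opt, scheduler=None):
    """Restore model/optimizer/scheduler and return (itr, rng_state_dict) for the train loop."""
    latest = find_latest_checkpoint(ckpt_dir)
    if latest is None:
        return 0, None
    checkpoint = torch.load(os.path.join(ckpt_dir, latest))
    model.load_state_dict(strip_compile_prefix(checkpoint["model"]))
    opt.load_state_dict(checkpoint["optimizer"])
    if scheduler is not None and checkpoint.get("scheduler") is not None:
        scheduler.load_state_dict(checkpoint["scheduler"])
    rng_state_dict = {key: checkpoint[key] for key in
                      ["cpu_rng_state", "gpu_rng_state", "numpy_rng_state", "py_rng_state"]}
    return checkpoint["itr"], rng_state_dict
```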

src/optim/base.py

Lines changed: 25 additions & 4 deletions
```diff
@@ -7,15 +7,17 @@
 import time
 import itertools
 import copy
-
+import random
+import os
+import numpy as np
 from .utils import eval, get_batch, save_checkpoint
 
 
-def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args):
+def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args, itr=0, rng_state_dict=None):
     device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
     type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
         device_type=device_type, dtype=torch.bfloat16) # extra_args.dtype)
-    itr, substep, best_val_loss, text_table = 0, 0, float('inf'), None # best_val_loss not used atm, early stopping not recommended but possible
+    substep, best_val_loss, text_table = 0, float('inf'), None # best_val_loss not used atm, early stopping not recommended but possible
 
     data["train"], train_sampler = get_dataloader(
         data["train"],
@@ -49,7 +51,13 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba
 
     t0 = time.time()
     train_epochs = 0
+    if not rng_state_dict is None:
+        torch.set_rng_state(rng_state_dict["cpu_rng_state"])
+        torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"])
+        np.random.set_state(rng_state_dict["numpy_rng_state"])
+        random.setstate(rng_state_dict["py_rng_state"])
     while itr < iterations:
+
         for microstep_idx in range(acc_steps): # gradient accumulation
             x, y = get_batch(data_train_iter, device=extra_args.device)
             with type_ctx:
@@ -122,7 +130,20 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba
 
             model.train()
             t0 = time.time()
-
+        if distributed_backend.is_master_process():
+            if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0:
+                print(f"saving checkpoint to {ckpt_path}/ckpt_{itr}.pt")
+                save_checkpoint(distributed_backend=distributed_backend,
+                                model=model,
+                                opt=opt,
+                                scheduler=scheduler,
+                                itr=itr,
+                                cpu_rng_state=torch.get_rng_state(),
+                                gpu_rng_state=torch.cuda.get_rng_state(),
+                                numpy_rng_state=np.random.get_state(),
+                                py_rng_state=random.getstate(),
+                                ckpt_path=os.path.join(os.path.dirname(ckpt_path), f"ckpt_{itr}.pt"))
+
     if distributed_backend.is_master_process():
         print(f"saving checkpoint to {ckpt_path}")
         save_checkpoint(distributed_backend=distributed_backend,
```
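Restoring all four RNG streams (torch CPU, torch CUDA, NumPy, Python's random) is what lets a resumed run reproduce the batch order and any stochastic regularization of an uninterrupted one. A small self-contained round trip using the same APIs (the CUDA state is made optional here; the diff always restores it):

```python
# Sketch: capture the RNG states the checkpoint stores, draw some numbers,
# restore the states, and check that the draws repeat exactly.
import random
import numpy as np
import torch

def capture_rng_states():
    states = {
        "cpu_rng_state": torch.get_rng_state(),
        "numpy_rng_state": np.random.get_state(),
        "py_rng_state": random.getstate(),
    }
    if torch.cuda.is_available():  # the diff assumes a GPU; here it is optional
        states["gpu_rng_state"] = torch.cuda.get_rng_state()
    return states

def restore_rng_states(states):
    torch.set_rng_state(states["cpu_rng_state"])
    np.random.set_state(states["numpy_rng_state"])
    random.setstate(states["py_rng_state"])
    if "gpu_rng_state" in states:
        torch.cuda.set_rng_state(states["gpu_rng_state"])

saved = capture_rng_states()
first = (torch.rand(3), np.random.rand(3), random.random())
restore_rng_states(saved)
second = (torch.rand(3), np.random.rand(3), random.random())
assert torch.equal(first[0], second[0]) and first[2] == second[2]
assert np.allclose(first[1], second[1])
print("resumed draws match the original draws")
```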

src/optim/sparse.py

Lines changed: 25 additions & 3 deletions
```diff
@@ -5,15 +5,16 @@
 import wandb
 import time
 import copy
-
+import numpy as np
+import random
 from .utils import eval_sparse, get_batch, eval_sweep_dropk, save_checkpoint
 
 
-def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args):
+def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args, itr=0, rng_state_dict=None):
     device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
     type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
         device_type=device_type, dtype=torch.bfloat16) # extra_args.dtype)
-    itr, substep, best_val_loss, text_table, sparsity_plot = 0, 0, float('inf'), None, None # best_val_loss not used atm, early stopping not recommended but possible
+    substep, best_val_loss, text_table, sparsity_plot = 0, float('inf'), None, None # best_val_loss not used atm, early stopping not recommended but possible
     data["train"] = get_dataloader(
         data["train"],
         sequence_length=sequence_length,
@@ -42,7 +43,14 @@ def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps,
     model.train()
 
     t0 = time.time()
+    if not rng_state_dict is None:
+        torch.set_rng_state(rng_state_dict["cpu_rng_state"])
+        torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"])
+        np.random.set_state(rng_state_dict["numpy_rng_state"])
+        random.setstate(rng_state_dict["py_rng_state"])
+
     while itr < iterations:
+
         for microstep_idx in range(acc_steps): # gradient accumulation
             x, y = get_batch(data_train_iter, device=extra_args.device)
             with type_ctx:
@@ -129,6 +137,20 @@ def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps,
 
             model.train()
             t0 = time.time()
+        if distributed_backend.is_master_process():
+            if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0:
+                print(f"saving checkpoint to {ckpt_path}/ckpt_{itr}.pt")
+                save_checkpoint(distributed_backend=distributed_backend,
+                                model=model,
+                                opt=opt,
+                                scheduler=scheduler,
+                                itr=itr,
+                                cpu_rng_state=torch.get_rng_state(),
+                                gpu_rng_state=torch.cuda.get_rng_state(),
+                                numpy_rng_state=np.random.get_state(),
+                                py_rng_state=random.getstate(),
+                                ckpt_path=f"{ckpt_path}/ckpt_{itr}.pt")
+
 
     if distributed_backend.is_master_process():
         print(f"saving checkpoint to {ckpt_path}")
```
