Merged
2 changes: 2 additions & 0 deletions src/config/base.py
@@ -45,4 +45,6 @@ def parse_args(base_parser, args, namespace):
# Distributed args
parser.add_argument('--distributed_backend', default=None, type=str, required=False,
choices=distributed.registered_backends()) # distributed backend type
parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False)

return parser.parse_args(args, namespace)
2 changes: 2 additions & 0 deletions src/config/sparse.py
@@ -47,4 +47,6 @@ def parse_args(base_parser, args, namespace):
# Distributed args
parser.add_argument('--distributed_backend', default=None, type=str, required=False,
choices=distributed.registered_backends()) # distributed backend type
parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False)

return parser.parse_args(args, namespace)
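
Note on the new flag (both config files add the identical option): leaving `--save_checkpoint_freq` unset keeps the default of `None`, which the training loops below treat as "no periodic checkpoints". A minimal standalone sketch of that argparse behaviour, using a hypothetical throwaway parser rather than the repo's `parse_args`:

```python
import argparse

# Hypothetical parser reproducing only the new option.
parser = argparse.ArgumentParser()
parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False)

print(parser.parse_args([]))                                  # Namespace(save_checkpoint_freq=None) -> periodic saves disabled
print(parser.parse_args(['--save_checkpoint_freq', '2000']))  # Namespace(save_checkpoint_freq=2000) -> save every 2000 iterations
```
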
24 changes: 23 additions & 1 deletion src/main.py
@@ -112,6 +112,28 @@ def main(args):
elif os.path.isfile(os.path.join(ckpt_path, "summary.json")): # the experiment was already completed
print(f"Already found experiment '{ckpt_path}'.\nSkipping.")
sys.exit(0)
itr = 0
rng_state_dict = None
distributed_backend.sync()
checkpoints = [file for file in os.listdir(ckpt_path) if 'ckpt_' in file]
if checkpoints:
last_ckpt_path = sorted(checkpoints, key=lambda name: int(name.split('_')[1].split('.')[0]))[-1]  # sort by iteration number, not lexicographically
print(f"Training interrupted, resuming from {last_ckpt_path}")
checkpoint = torch.load(os.path.join(ckpt_path, last_ckpt_path))
# FIXME checkpoints from compiled model have _orig_mod keyword
model_state_dict = {k.replace("_orig_mod.", ""): v for k, v in checkpoint['model'].items()}

optimizer_state_dict = checkpoint['optimizer']
rng_state_dict = {
module: checkpoint[module] for module in ["cpu_rng_state", "gpu_rng_state", "numpy_rng_state", "py_rng_state"]
}

model.load_state_dict(model_state_dict)
opt.load_state_dict(optimizer_state_dict)
itr = checkpoint['itr']
if scheduler is not None:
scheduler_state_dict = checkpoint['scheduler']
scheduler.load_state_dict(scheduler_state_dict)

if args.model == 'base': # all train functions have the same interface
train = train_base
@@ -125,7 +147,7 @@ def main(args):
stats = train(model, opt, data, args.data_seed, scheduler, args.iterations, args.acc_steps, args.batch_size, args.sequence_length,
eval_freq=args.eval_freq,
distributed_backend=distributed_backend,
ckpt_path=f"{ckpt_path}/ckpt.pt", extra_args=args)
ckpt_path=f"{ckpt_path}/ckpt.pt", itr=itr, rng_state_dict=rng_state_dict, extra_args=args)

args.device = None
args.dtype = None
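Background for the `_orig_mod.` FIXME above: `torch.compile` wraps the network in an `OptimizedModule` that stores the original network as a submodule named `_orig_mod`, so on the PyTorch 2.x releases this code targets, the wrapper's `state_dict` keys carry that prefix. A minimal sketch of the same strip-and-reload round trip, assuming a PyTorch build where `torch.compile` is available:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)   # wrapping is immediate; actual compilation is lazy

ckpt = compiled.state_dict()      # keys such as "_orig_mod.weight", "_orig_mod.bias" (version dependent)
clean = {k.replace("_orig_mod.", ""): v for k, v in ckpt.items()}

model.load_state_dict(clean)      # the stripped keys match the uncompiled module again
```
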
29 changes: 25 additions & 4 deletions src/optim/base.py
@@ -7,15 +7,17 @@
import time
import itertools
import copy

import random
import os
import numpy as np
from .utils import eval, get_batch, save_checkpoint


def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args):
def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args, itr=0, rng_state_dict=None):
device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
device_type=device_type, dtype=torch.bfloat16) # extra_args.dtype)
itr, substep, best_val_loss, text_table = 0, 0, float('inf'), None # best_val_loss not used atm, early stopping not recommended but possible
substep, best_val_loss, text_table = 0, float('inf'), None # best_val_loss not used atm, early stopping not recommended but possible

data["train"], train_sampler = get_dataloader(
data["train"],
@@ -49,7 +51,13 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba

t0 = time.time()
train_epochs = 0
if rng_state_dict is not None:
torch.set_rng_state(rng_state_dict["cpu_rng_state"])
torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"])
np.random.set_state(rng_state_dict["numpy_rng_state"])
random.setstate(rng_state_dict["py_rng_state"])
while itr < iterations:

for microstep_idx in range(acc_steps): # gradient accumulation
x, y = get_batch(data_train_iter, device=extra_args.device)
with type_ctx:
Expand Down Expand Up @@ -122,7 +130,20 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba

model.train()
t0 = time.time()

if distributed_backend.is_master_process():
if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0:
print(f"saving checkpoint to {os.path.join(os.path.dirname(ckpt_path), f'ckpt_{itr}.pt')}")
save_checkpoint(distributed_backend=distributed_backend,
model=model,
opt=opt,
scheduler=scheduler,
itr=itr,
cpu_rng_state=torch.get_rng_state(),
gpu_rng_state=torch.cuda.get_rng_state(),
numpy_rng_state=np.random.get_state(),
py_rng_state=random.getstate(),
ckpt_path=os.path.join(os.path.dirname(ckpt_path), f"ckpt_{itr}.pt"))

if distributed_backend.is_master_process():
print(f"saving checkpoint to {ckpt_path}")
save_checkpoint(distributed_backend=distributed_backend,
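The `save_checkpoint` helper imported from `.utils` is not part of this diff. Judging only from the keyword arguments passed above and the keys read back in `src/main.py`, a plausible minimal sketch (an assumption, not the repository's actual implementation) looks like this:

```python
import torch

def save_checkpoint(distributed_backend, model, opt, scheduler, itr, ckpt_path, **rng_states):
    # rng_states collects the cpu/gpu/numpy/py generator states passed by the callers.
    # distributed_backend is accepted for interface parity; a real implementation would
    # presumably use it to unwrap DDP/compiled wrappers before serializing the model.
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': opt.state_dict(),
        'scheduler': scheduler.state_dict() if scheduler is not None else None,
        'itr': itr,
        **rng_states,
    }
    torch.save(checkpoint, ckpt_path)
```

Whatever the real helper does, the keys it writes have to line up with what the resume branch in `src/main.py` reads: `model`, `optimizer`, `scheduler`, `itr`, and the four `*_rng_state` entries.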
28 changes: 25 additions & 3 deletions src/optim/sparse.py
@@ -5,15 +5,16 @@
import wandb
import time
import copy

import os
import numpy as np
import random
from .utils import eval_sparse, get_batch, eval_sweep_dropk, save_checkpoint


def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args):
def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args, itr=0, rng_state_dict=None):
device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
device_type=device_type, dtype=torch.bfloat16) # extra_args.dtype)
itr, substep, best_val_loss, text_table, sparsity_plot = 0, 0, float('inf'), None, None # best_val_loss not used atm, early stopping not recommended but possible
substep, best_val_loss, text_table, sparsity_plot = 0, float('inf'), None, None # best_val_loss not used atm, early stopping not recommended but possible
data["train"] = get_dataloader(
data["train"],
sequence_length=sequence_length,
@@ -42,7 +43,14 @@ def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps,
model.train()

t0 = time.time()
if rng_state_dict is not None:
torch.set_rng_state(rng_state_dict["cpu_rng_state"])
torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"])
np.random.set_state(rng_state_dict["numpy_rng_state"])
random.setstate(rng_state_dict["py_rng_state"])

while itr < iterations:

for microstep_idx in range(acc_steps): # gradient accumulation
x, y = get_batch(data_train_iter, device=extra_args.device)
with type_ctx:
Expand Down Expand Up @@ -129,6 +137,20 @@ def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps,

model.train()
t0 = time.time()
if distributed_backend.is_master_process():
if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0:
print(f"saving checkpoint to {os.path.join(os.path.dirname(ckpt_path), f'ckpt_{itr}.pt')}")
save_checkpoint(distributed_backend=distributed_backend,
model=model,
opt=opt,
scheduler=scheduler,
itr=itr,
cpu_rng_state=torch.get_rng_state(),
gpu_rng_state=torch.cuda.get_rng_state(),
numpy_rng_state=np.random.get_state(),
py_rng_state=random.getstate(),
ckpt_path=os.path.join(os.path.dirname(ckpt_path), f"ckpt_{itr}.pt"))  # match optim/base.py: write next to ckpt.pt, not under it


if distributed_backend.is_master_process():
print(f"saving checkpoint to {ckpt_path}")
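
Why all four generator states are saved and restored: after a resume, the next draws from `torch`, NumPy, and Python's `random` should match what an uninterrupted run would have produced. A self-contained illustration (not repo code; the CUDA state is omitted so it runs on CPU-only machines):

```python
import random
import numpy as np
import torch

# Snapshot the generator states, as the checkpointing code above does (minus the CUDA state).
state = {
    "cpu_rng_state": torch.get_rng_state(),
    "numpy_rng_state": np.random.get_state(),
    "py_rng_state": random.getstate(),
}

expected = (torch.rand(3), np.random.rand(), random.random())  # what an uninterrupted run draws next

# "Resume": restore the snapshot and draw again.
torch.set_rng_state(state["cpu_rng_state"])
np.random.set_state(state["numpy_rng_state"])
random.setstate(state["py_rng_state"])
resumed = (torch.rand(3), np.random.rand(), random.random())

assert torch.equal(expected[0], resumed[0])
assert expected[1] == resumed[1] and expected[2] == resumed[2]
```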