5 | 5 | import wandb |
6 | 6 | import time |
7 | 7 | import copy |
8 | | - |
| 8 | +import numpy as np |
| 9 | +import random |
9 | 10 | from .utils import eval_sparse, get_batch, eval_sweep_dropk, save_checkpoint |
10 | 11 |
11 | 12 |
12 | | -def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args): |
| 13 | +def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args, itr=0, rng_state_dict=None): |
13 | 14 | device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu' |
14 | 15 | type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast( |
15 | 16 | device_type=device_type, dtype=torch.bfloat16) # extra_args.dtype) |
16 | | - itr, substep, best_val_loss, text_table, sparsity_plot = 0, 0, float('inf'), None, None # best_val_loss not used atm, early stopping not recommended but possible |
| 17 | + substep, best_val_loss, text_table, sparsity_plot = 0, float('inf'), None, None # best_val_loss not used atm, early stopping not recommended but possible |
17 | 18 | data["train"] = get_dataloader( |
18 | 19 | data["train"], |
19 | 20 | sequence_length=sequence_length, |
@@ -42,7 +43,14 @@ def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps, |
42 | 43 | model.train() |
43 | 44 |
44 | 45 | t0 = time.time() |
| 46 | + if rng_state_dict is not None: |
| 47 | + torch.set_rng_state(rng_state_dict["cpu_rng_state"]) |
| 48 | + torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"]) |
| 49 | + np.random.set_state(rng_state_dict["numpy_rng_state"]) |
| 50 | + random.setstate(rng_state_dict["py_rng_state"]) |
| 51 | + |
45 | 52 | while itr < iterations: |
| 53 | + |
46 | 54 | for microstep_idx in range(acc_steps): # gradient accumulation |
47 | 55 | x, y = get_batch(data_train_iter, device=extra_args.device) |
48 | 56 | with type_ctx: |
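Note on the RNG-restore block added in this hunk: it calls `torch.cuda.set_rng_state` unconditionally, which assumes a CUDA device is present. Below is a minimal sketch of the save/restore round trip these states support, with the CUDA stream guarded for CPU-only runs; the helper names are illustrative and not part of this patch.

```python
import random

import numpy as np
import torch


def capture_rng_states():
    # Snapshot every RNG stream the training loop draws from.
    states = {
        "cpu_rng_state": torch.get_rng_state(),
        "numpy_rng_state": np.random.get_state(),
        "py_rng_state": random.getstate(),
    }
    if torch.cuda.is_available():
        states["gpu_rng_state"] = torch.cuda.get_rng_state()
    return states


def restore_rng_states(states):
    # Put every stream back exactly where the snapshot left it, so data
    # shuffling, dropout masks, etc. replay identically after a resume.
    torch.set_rng_state(states["cpu_rng_state"])
    np.random.set_state(states["numpy_rng_state"])
    random.setstate(states["py_rng_state"])
    if torch.cuda.is_available() and "gpu_rng_state" in states:
        torch.cuda.set_rng_state(states["gpu_rng_state"])
```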
@@ -129,6 +137,20 @@ def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps, |
129 | 137 |
130 | 138 | model.train() |
131 | 139 | t0 = time.time() |
| 140 | + if distributed_backend.is_master_process(): |
| 141 | + if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0: |
| 142 | + print(f"saving checkpoint to {ckpt_path}/ckpt_{itr}.pt") |
| 143 | + save_checkpoint(distributed_backend=distributed_backend, |
| 144 | + model=model, |
| 145 | + opt=opt, |
| 146 | + scheduler=scheduler, |
| 147 | + itr=itr, |
| 148 | + cpu_rng_state=torch.get_rng_state(), |
| 149 | + gpu_rng_state=torch.cuda.get_rng_state(), |
| 150 | + numpy_rng_state=np.random.get_state(), |
| 151 | + py_rng_state=random.getstate(), |
| 152 | + ckpt_path=f"{ckpt_path}/ckpt_{itr}.pt") |
| 153 | + |
132 | 154 |
133 | 155 | if distributed_backend.is_master_process(): |
134 | 156 | print(f"saving checkpoint to {ckpt_path}") |
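The diff only shows the save side. Below is a hedged sketch of the matching resume path, assuming the checkpoint file stores its fields under the same names as the `save_checkpoint` keyword arguments above; the real layout lives in `.utils.save_checkpoint`, which is not shown here.

```python
import torch


def load_resume_state(path):
    # Illustrative only: read back the iteration counter and RNG snapshot so
    # they can be passed to train_sparse(..., itr=..., rng_state_dict=...).
    # Key names are assumptions based on the save_checkpoint kwargs above.
    ckpt = torch.load(path, map_location="cpu")
    rng_state_dict = {
        "cpu_rng_state": ckpt["cpu_rng_state"],
        "gpu_rng_state": ckpt["gpu_rng_state"],
        "numpy_rng_state": ckpt["numpy_rng_state"],
        "py_rng_state": ckpt["py_rng_state"],
    }
    return ckpt["itr"], rng_state_dict


# Hypothetical usage when restarting a run (checkpoint filename made up here):
# itr, rng_state_dict = load_resume_state(f"{ckpt_path}/ckpt_5000.pt")
# train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps,
#              batch_size, sequence_length, eval_freq, ckpt_path,
#              distributed_backend, extra_args,
#              itr=itr, rng_state_dict=rng_state_dict)
```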