2 files changed: +10 −4

@@ -40,6 +40,7 @@ class train_config:
 
     # training spec
     batch_size: int = 2
+    grad_accum_steps: int = 1
     num_steps: int = 1000000
     training_stage: str = "initial"
     learning_rate: float = 3e-4
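With this option, the effective per-rank batch size becomes batch_size × grad_accum_steps: at the default batch_size = 2, for example, grad_accum_steps = 8 would give 16 samples per optimizer step (multiplied again by the DDP world size for the global batch). The default of 1 keeps the previous behavior.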
@@ -80,6 +80,7 @@ def train(
     run["hparams"] = asdict(cfg)
 
     model.train()
+    optimizer.zero_grad()
     ddp_stats = torch.zeros(3).to(local_rank)
 
     start = time.time()
@@ -91,20 +92,24 @@ def train(
         input = input.to(local_rank)
         label = label.to(local_rank)
 
-        optimizer.zero_grad()
         output = model(input)
         output = output.logits if hasattr(output, "logits") else output
         ce_loss = torch.nn.CrossEntropyLoss()
         loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long())
         loss = loss + .0001 * torch.logsumexp(output, dim=-1).pow(2).mean()
+        loss = loss / cfg.grad_accum_steps
 
         loss.backward()
-        ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item()
-        optimizer.step()
+
+        if batch_idx % cfg.grad_accum_steps == 0:
+            ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item()
+            optimizer.step()
+            optimizer.zero_grad()
+            ddp_stats[2] += 1
+
         scheduler.step()
 
         ddp_stats[0] += loss.item()
-        ddp_stats[2] += 1
 
         if profiler:
             profiler.step()
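The pattern introduced here is standard gradient accumulation: each micro-batch loss is scaled by 1/grad_accum_steps so the accumulated gradient averages over the window, loss.backward() adds into .grad across micro-batches, and clipping, the optimizer step, and zero_grad() run only once per window. Below is a minimal self-contained sketch of the same pattern with a toy model and random data; torch.nn.utils.clip_grad_norm_ stands in for the sharded model's clip_grad_norm_ used above, grad_accum_steps = 4 is an assumed value, and the batch count is made 1-based so the first optimizer step fires only after a full window.

import torch

model = torch.nn.Linear(8, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=3e-4)
loss_fn = torch.nn.CrossEntropyLoss()
grad_accum_steps = 4  # assumed value, for illustration only

optimizer.zero_grad()
for batch_idx in range(1, 17):    # 1-based count: step fires after a full window
    x = torch.randn(2, 8)         # toy micro-batch (batch_size = 2)
    y = torch.randint(0, 4, (2,))
    loss = loss_fn(model(x), y) / grad_accum_steps  # scale so grads average
    loss.backward()               # accumulates into .grad across micro-batches
    if batch_idx % grad_accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()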