Commit 6b27464: add grad accum support
1 parent: 515b2a5
2 files changed: 10 additions, 4 deletions

fms_fsdp/config/training.py (1 addition, 0 deletions)

@@ -40,6 +40,7 @@ class train_config:
     # training spec
     batch_size: int = 2
+    grad_accum_steps: int = 1
     num_steps: int = 1000000
     training_stage: str = "initial"
     learning_rate: float = 3e-4
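
For context (not part of the commit), a quick sketch of what the new grad_accum_steps field means for scale: each optimizer step now consumes batch_size * grad_accum_steps micro-batches per rank, multiplied across data-parallel ranks. The world_size value below is an assumed example, not a repo default.

# Sketch only (not from the repo): effective batch per optimizer step.
batch_size = 2          # per-rank micro-batch, default in train_config
grad_accum_steps = 4    # the new field added by this commit (default 1)
world_size = 8          # assumed number of data-parallel ranks
effective_batch = batch_size * grad_accum_steps * world_size
print(effective_batch)  # 64 sequences folded into each optimizer.step()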

fms_fsdp/utils/train_utils.py (9 additions, 4 deletions)

@@ -80,6 +80,7 @@ def train(
         run["hparams"] = asdict(cfg)
 
     model.train()
+    optimizer.zero_grad()
     ddp_stats = torch.zeros(3).to(local_rank)
 
     start = time.time()
@@ -91,20 +92,24 @@ def train(
         input = input.to(local_rank)
         label = label.to(local_rank)
 
-        optimizer.zero_grad()
         output = model(input)
         output = output.logits if hasattr(output, "logits") else output
         ce_loss = torch.nn.CrossEntropyLoss()
         loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long())
         loss = loss + .0001 * torch.logsumexp(output, dim=-1).pow(2).mean()
+        loss = loss / cfg.grad_accum_steps
 
         loss.backward()
-        ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item()
-        optimizer.step()
+
+        if batch_idx % cfg.grad_accum_steps == 0:
+            ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item()
+            optimizer.step()
+            optimizer.zero_grad()
+            ddp_stats[2] += 1
+
         scheduler.step()
 
         ddp_stats[0] += loss.item()
-        ddp_stats[2] += 1
 
         if profiler:
             profiler.step()
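
For readers unfamiliar with the pattern, below is a minimal, self-contained sketch (not fms_fsdp code) of the same accumulation scheme wired in above: each micro-batch loss is divided by grad_accum_steps, gradients accumulate across backward() calls, and clipping, optimizer.step(), and zero_grad() run once per window. The toy model and data are placeholders, and torch.nn.utils.clip_grad_norm_ stands in for the FSDP model.clip_grad_norm_ method used in the diff.

import torch

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
grad_accum_steps = 4
grad_clip_thresh = 1.0

# Toy data: 8 micro-batches of (input, label) pairs.
batches = [(torch.randn(2, 16), torch.randint(0, 4, (2,))) for _ in range(8)]

optimizer.zero_grad()
for batch_idx, (x, y) in enumerate(batches, start=1):
    logits = model(x)
    # Scale the loss so the accumulated gradient averages over the window.
    loss = torch.nn.functional.cross_entropy(logits, y) / grad_accum_steps
    loss.backward()

    if batch_idx % grad_accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
        optimizer.step()
        optimizer.zero_grad()

Starting batch_idx at 1 in this sketch ensures the first optimizer.step() fires only after a full window of grad_accum_steps micro-batches has been processed.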
