Commit abd54ae

gradient accumulation
1 parent 4b61f49 commit abd54ae

1 file changed: 56 additions, 0 deletions
@@ -0,0 +1,56 @@
import torch
from torch import nn
from torch.nn import Sequential
from torch.optim import Adam
from torch.nn.functional import l1_loss
from copy import deepcopy

# Small toy model: 32 -> 64 -> 16 with a sigmoid output.
model = Sequential(
    nn.Linear(32, 64, bias=True),
    nn.LeakyReLU(),
    nn.Linear(64, 16),
    nn.Sigmoid(),
)

# One full batch of 32 random samples and matching random targets.
inputs = torch.randn(32, 32)
targets = torch.randn(32, 16)


def gather_grad(model):
    """Collect a copy of every parameter gradient, keyed by parameter name."""
    grads = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            grads[name] = param.grad.clone()
    return grads


def grad_batch(model):
    """Reference: gradients from a single backward pass over the full batch."""
    optimizer = Adam(model.parameters(), lr=1e-4)
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = l1_loss(outputs, targets)
    loss.backward()
    return gather_grad(model)


def grad_batch_accum(model, num_chunks):
    """Gradient accumulation: split the batch into num_chunks pieces, run
    backward on each piece, and let the gradients sum up in .grad."""
    optimizer = Adam(model.parameters(), lr=1e-4)
    optimizer.zero_grad()
    assert inputs.shape[0] % num_chunks == 0
    chunk_size = inputs.shape[0] // num_chunks
    split_inputs = torch.split(inputs, chunk_size)
    split_targets = torch.split(targets, chunk_size)
    for idx, chunk_in in enumerate(split_inputs):
        chunk_out = model(chunk_in)
        # Divide by the number of chunks so the accumulated sum equals the
        # mean loss over the full batch.
        loss = l1_loss(chunk_out, split_targets[idx]) / num_chunks
        loss.backward()  # gradients accumulate across iterations
    return gather_grad(model)


# Compare full-batch gradients against accumulation over 8, 16, and 32 chunks.
# deepcopy gives every run identical starting weights.
grad_b = grad_batch(deepcopy(model))
grad_ba_8 = grad_batch_accum(deepcopy(model), 8)
grad_ba_16 = grad_batch_accum(deepcopy(model), 16)
grad_ba_32 = grad_batch_accum(deepcopy(model), 32)

# Print the largest absolute difference per parameter; the values should be
# close to zero, up to floating-point error.
for name in grad_b.keys():
    gb = grad_b[name]
    print(name)
    for gba in [grad_ba_8[name], grad_ba_16[name], grad_ba_32[name]]:
        diff = (gb - gba).abs().max()
        print(f'\t{diff.item():.9f}')
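The script above only verifies that accumulated gradients match full-batch gradients; in practice, gradient accumulation is used to emulate a larger effective batch size when memory is limited. Below is a minimal sketch of how the same idea might sit inside a training loop; the names `dataloader`, `accum_steps`, and the loop structure are illustrative assumptions, not part of the commit.

# Sketch only: gradient accumulation inside a training loop.
# Assumes `dataloader` yields (x, y) mini-batches; accum_steps is the number
# of mini-batches whose gradients are summed before one optimizer update.
accum_steps = 4
optimizer = Adam(model.parameters(), lr=1e-4)

optimizer.zero_grad()
for step, (x, y) in enumerate(dataloader):
    out = model(x)
    loss = l1_loss(out, y) / accum_steps  # scale so the sum approximates the full-batch mean
    loss.backward()                       # gradients keep accumulating in .grad
    if (step + 1) % accum_steps == 0:
        optimizer.step()                  # one update per accum_steps mini-batches
        optimizer.zero_grad()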
