Skip to content

Commit 936210c

Browse files
committed
Clip gradients as in the paper, and also make sure not to do gradient sync until the last accumulation step
1 parent 422efba commit 936210c

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,21 @@ def __call__(
158158
steps = 0
159159

160160
while steps < self.num_train_steps:
161-
for _ in range(self.grad_accum_every):
161+
162+
for grad_accum_step in range(self.grad_accum_every):
163+
is_accumulating = grad_accum_step < (self.grad_accum_every - 1)
164+
162165
inputs = next(dl)
163166

164-
loss = self.model(**inputs)
167+
with self.fabric.no_backward_sync(self.model, enabled = is_accumulating):
168+
loss = self.model(**inputs)
165169

166170
self.fabric.backward(loss / self.grad_accum_every)
167171

168172
print(f'loss: {loss.item():.3f}')
169173

174+
self.fabric.clip_gradients(self.model, self.optimizer, max_norm = self.clip_grad_norm)
175+
170176
self.optimizer.step()
171177

172178
if self.is_main:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.39"
3+
version = "0.0.40"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments (0)