
Commit 7a5a3bd

properly accumulate loss breakdown
1 parent 8f9e41d commit 7a5a3bd

File tree

1 file changed: +38 -5 lines changed

alphafold3_pytorch/trainer.py

Lines changed: 38 additions & 5 deletions
@@ -52,6 +52,20 @@ def cycle(dataloader: DataLoader):
     for batch in dataloader:
         yield batch

+@typecheck
+def accum_dict(
+    past_losses: dict | None,
+    losses: dict,
+    scale: float = 1.
+):
+    if not exists(past_losses):
+        return losses
+
+    for loss_name in past_losses.keys():
+        past_losses[loss_name] += losses.get(loss_name, 0.) * scale
+
+    return past_losses
+
 def default_lambda_lr_fn(steps):
     # 1000 step warmup
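The new `accum_dict` helper keeps a running dictionary of losses: the first call simply adopts the incoming breakdown, and every later call adds each named loss scaled by `scale`, mutating and returning the running dict. A minimal standalone sketch of that behaviour, with `exists` stubbed in and made-up loss names (the real helper also carries the `@typecheck` decorator and type hints shown in the diff):

    def exists(v):
        return v is not None

    def accum_dict(past_losses, losses, scale = 1.):
        # first call: adopt the incoming breakdown as the running dict
        if not exists(past_losses):
            return losses

        # later calls: add each named loss, scaled
        for loss_name in past_losses.keys():
            past_losses[loss_name] += losses.get(loss_name, 0.) * scale

        return past_losses

    step1 = dict(distogram = 1.2, diffusion = 0.8)   # made-up loss names and values
    step2 = dict(distogram = 1.0, diffusion = 0.6)

    running = accum_dict(None, step1)                 # -> {'distogram': 1.2, 'diffusion': 0.8}
    running = accum_dict(running, step2, scale = 0.5)
    # running -> {'distogram': 1.7, 'diffusion': 1.1}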

@@ -193,6 +207,9 @@ def __call__(

             # gradient accumulation

+            total_loss = 0.
+            train_loss_breakdown = None
+
             for grad_accum_step in range(self.grad_accum_every):
                 is_accumulating = grad_accum_step < (self.grad_accum_every - 1)

@@ -207,15 +224,22 @@ def __call__(
                     return_loss_breakdown = True
                 )

+                # accumulate
+
+                scale = self.grad_accum_every ** -1
+
+                total_loss += loss.item() * scale
+                train_loss_breakdown = accum_dict(train_loss_breakdown, loss_breakdown._asdict(), scale = scale)
+
                 # backwards

                 self.fabric.backward(loss / self.grad_accum_every)

             # log entire loss breakdown

-            self.log(**loss_breakdown._asdict())
+            self.log(**train_loss_breakdown)

-            self.print(f'loss: {loss.item():.3f}')
+            self.print(f'loss: {total_loss:.3f}')

             # clip gradients
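With `scale = self.grad_accum_every ** -1`, summing `loss.item() * scale` over the accumulation window makes `total_loss` the mean micro-batch loss, the same average whose gradient `self.fabric.backward(loss / self.grad_accum_every)` accumulates, and `train_loss_breakdown` gets the same treatment per named loss via `accum_dict`. A small sketch of the bookkeeping with made-up numbers:

    # averaging per-micro-batch losses over one gradient accumulation window
    grad_accum_every = 4
    micro_batch_losses = [2.0, 1.6, 1.8, 1.4]      # made-up values

    scale = grad_accum_every ** -1                 # 0.25
    total_loss = sum(loss * scale for loss in micro_batch_losses)
    # total_loss == 1.7, the mean loss over the window -- the value printed by self.print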
@@ -252,21 +276,30 @@ def __call__(
                 self.ema_model.eval()

                 total_valid_loss = 0.
+                valid_loss_breakdown = None

                 for valid_batch in self.valid_dataloader:
-                    valid_loss, valid_loss_breakdown = self.ema_model(
+                    valid_loss, loss_breakdown = self.ema_model(
                         **valid_batch,
                         return_loss_breakdown = True
                     )

                     valid_batch_size = valid_batch.get('atom_inputs').shape[0]
                     scale = valid_batch_size / self.valid_dataset_size

-                    scaled_valid_loss = valid_loss.item() * scale
-                    total_valid_loss += scaled_valid_loss
+                    total_valid_loss += valid_loss.item() * scale
+                    valid_loss_breakdown = accum_dict(valid_loss_breakdown, loss_breakdown._asdict(), scale = scale)

                 self.print(f'valid loss: {total_valid_loss:.3f}')

+                # prepend valid_ to all losses for logging
+
+                valid_loss_breakdown = {f'valid_{k}':v for k, v in valid_loss_breakdown.items()}
+
+                # log
+
+                self.log(**valid_loss_breakdown)
+
             self.wait()

         print(f'training complete')
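Because `scale = valid_batch_size / self.valid_dataset_size`, each batch's contribution to `total_valid_loss` (and, via `accum_dict`, to every entry of `valid_loss_breakdown`) is weighted by how many examples it contains, so the totals come out as per-example averages over the whole validation set even when the last batch is smaller. A small sketch with made-up numbers:

    # batch-size-weighted averaging over a validation set of 10 examples
    valid_dataset_size = 10
    batches = [(4, 1.0), (4, 2.0), (2, 3.0)]   # (batch size, mean loss of that batch), made up

    total_valid_loss = 0.
    for batch_size, batch_loss in batches:
        scale = batch_size / valid_dataset_size
        total_valid_loss += batch_loss * scale

    # total_valid_loss == 1.8, the per-example mean loss across all 10 examples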
