Skip to content

Commit d52c3c0

Browse files
committed
log the entire loss breakdown
1 parent 340dd87 commit d52c3c0

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,9 +1757,9 @@ def forward(
17571757
# https://arxiv.org/abs/2206.00364
17581758

17591759
class DiffusionLossBreakdown(NamedTuple):
1760-
mse: Float['']
1761-
bond: Float['']
1762-
smooth_lddt: Float['']
1760+
diffusion_mse: Float['']
1761+
diffusion_bond: Float['']
1762+
diffusion_smooth_lddt: Float['']
17631763

17641764
class ElucidatedAtomDiffusionReturn(NamedTuple):
17651765
loss: Float['']
@@ -2618,14 +2618,17 @@ def forward(
26182618
# main class
26192619

26202620
class LossBreakdown(NamedTuple):
2621-
diffusion: Float['']
2621+
total_loss: Float['']
2622+
total_diffusion: Float['']
26222623
distogram: Float['']
26232624
pae: Float['']
26242625
pde: Float['']
26252626
plddt: Float['']
26262627
resolved: Float['']
26272628
confidence: Float['']
2628-
diffusion_loss_breakdown: DiffusionLossBreakdown
2629+
diffusion_mse: Float['']
2630+
diffusion_bond: Float['']
2631+
diffusion_smooth_lddt: Float['']
26292632

26302633
class Alphafold3(Module):
26312634
""" Algorithm 1 """
@@ -3215,14 +3218,15 @@ def forward(
32153218
return loss
32163219

32173220
loss_breakdown = LossBreakdown(
3221+
total_loss = loss,
3222+
total_diffusion = diffusion_loss,
32183223
pae = pae_loss,
32193224
pde = pde_loss,
32203225
plddt = plddt_loss,
32213226
resolved = resolved_loss,
32223227
distogram = distogram_loss,
3223-
diffusion = diffusion_loss,
32243228
confidence = confidence_loss,
3225-
diffusion_loss_breakdown = diffusion_loss_breakdown
3229+
**diffusion_loss_breakdown._asdict()
32263230
)
32273231

32283232
return loss, loss_breakdown

alphafold3_pytorch/trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,14 @@ def __call__(
173173
inputs = next(dl)
174174

175175
with self.fabric.no_backward_sync(self.model, enabled = is_accumulating):
176-
loss = self.model(**inputs)
176+
loss, loss_breakdown = self.model(
177+
**inputs,
178+
return_loss_breakdown = True
179+
)
177180

178181
self.fabric.backward(loss / self.grad_accum_every)
179182

180-
self.log(loss = loss)
183+
self.log(**loss_breakdown._asdict())
181184

182185
self.print(f'loss: {loss.item():.3f}')
183186

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.45"
3+
version = "0.0.46"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments (0)