Skip to content

Commit e878ad7

Browse files
committed
able to return loss breakdown
1 parent d201355 commit e878ad7

File tree

3 files changed

+29
-13
lines changed

3 files changed

+29
-13
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
Sequential,
3535
)
3636

37-
from typing import Literal, Tuple
37+
from typing import Literal, Tuple, NamedTuple
3838

3939
from alphafold3_pytorch.typing import (
4040
Float,
@@ -2326,13 +2326,14 @@ def forward(
23262326

23272327
# main class
23282328

2329-
LossBreakdown = namedtuple('LossBreakdown', [
2330-
'distogram',
2331-
'pae',
2332-
'pde',
2333-
'plddt',
2334-
'resolved'
2335-
])
2329+
class LossBreakdown(NamedTuple):
2330+
diffusion: Float['']
2331+
distogram: Float['']
2332+
pae: Float['']
2333+
pde: Float['']
2334+
plddt: Float['']
2335+
resolved: Float['']
2336+
confidence: Float['']
23362337

23372338
class Alphafold3(Module):
23382339
""" Algorithm 1 """
@@ -2561,7 +2562,8 @@ def forward(
25612562
pde_labels: Int['b n n'] | None = None,
25622563
plddt_labels: Int['b n'] | None = None,
25632564
resolved_labels: Int['b n'] | None = None,
2564-
) -> Float['b m 3'] | Float['']:
2565+
return_loss_breakdown = False
2566+
) -> Float['b m 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
25652567

25662568
w = self.atoms_per_window
25672569

@@ -2754,4 +2756,17 @@ def forward(
27542756
confidence_loss * self.loss_confidence_weight
27552757
)
27562758

2757-
return loss
2759+
if not return_loss_breakdown:
2760+
return loss
2761+
2762+
loss_breakdown = LossBreakdown(
2763+
pae = pae_loss,
2764+
pde = pde_loss,
2765+
plddt = plddt_loss,
2766+
resolved = resolved_loss,
2767+
distogram = distogram_loss,
2768+
diffusion = diffusion_loss,
2769+
confidence = confidence_loss
2770+
)
2771+
2772+
return loss, loss_breakdown

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.11"
3+
version = "0.0.12"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def test_alphafold3():
380380
),
381381
)
382382

383-
loss = alphafold3(
383+
loss, breakdown = alphafold3(
384384
num_recycling_steps = 2,
385385
atom_inputs = atom_inputs,
386386
atom_mask = atom_mask,
@@ -395,7 +395,8 @@ def test_alphafold3():
395395
pae_labels = pae_labels,
396396
pde_labels = pde_labels,
397397
plddt_labels = plddt_labels,
398-
resolved_labels = resolved_labels
398+
resolved_labels = resolved_labels,
399+
return_loss_breakdown = True
399400
)
400401

401402
loss.backward()

0 commit comments

Comments
 (0)