Skip to content

Commit ad23ecc

Browse files
committed
hook up the smooth lddt loss end2end, thanks to @joseph-c-kim
1 parent 4233f89 commit ad23ecc

File tree

3 files changed

+44
-17
lines changed

3 files changed

+44
-17
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Getting a fair number of emails. You can chat with me about this work <a href="h
88

99
## Appreciation
1010

11-
- <a href="https://github.com/joseph-c-kim">Joseph</a> for contributing the relative positional encoding module!
11+
- <a href="https://github.com/joseph-c-kim">Joseph</a> for contributing the Relative Positional Encoding and the Smooth LDDT Loss!
1212

1313
## Install
1414

alphafold3_pytorch/alphafold3.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from collections import namedtuple
2424

2525
import torch
26-
from torch import nn
26+
from torch import nn, sigmoid
2727
from torch import Tensor
2828
import torch.nn.functional as F
2929

@@ -80,15 +80,15 @@ def unpack_one(t, ps, pattern):
8080

8181
# Loss functions
8282

83-
def smoothlddtloss(
83+
@typecheck
84+
def calc_smooth_lddt_loss(
8485
denoised: Float['b m 3'],
8586
ground_truth: Float['b m 3'],
8687
is_rna_per_atom: Float['b m'],
8788
is_dna_per_atom: Float['b m']
88-
) -> Float['b']:
89-
from torch import sigmoid
89+
) -> Float[' b']:
9090

91-
m = is_rna_per_atom.shape[-1]
91+
m, device = is_rna_per_atom.shape[-1], denoised.device
9292

9393
dx_denoised = torch.cdist(denoised, denoised)
9494
dx_gt = torch.cdist(ground_truth, ground_truth)
@@ -102,10 +102,11 @@ def smoothlddtloss(
102102
mask = einx.multiply('b i, b j -> b i j', is_nuc, is_nuc)
103103
c = (dx_gt < 30) * mask + (dx_gt < 15) * (1 - mask)
104104

105-
num = einx.sum('b [...]', c * eps * (1 - torch.eye(m))) / (m**2 - m)
106-
den = einx.sum('b [...]', c * (1 - torch.eye(m))) / (m**2 - m)
107-
108-
return 1 - num/den
105+
eye = torch.eye(m, device = device)
106+
num = einx.sum('b [...]', c * eps * (1 - eye)) / (m**2 - m)
107+
den = einx.sum('b [...]', c * (1 - eye)) / (m**2 - m)
108+
109+
return 1. - num/den
109110

110111
# linear and outer sum
111112
# for single repr -> pairwise pattern throughout this architecture
@@ -1699,6 +1700,8 @@ def forward(
16991700
normalized_atom_pos: Float['b m 3'],
17001701
atom_mask: Bool['b m'],
17011702
return_denoised_pos = False,
1703+
additional_residue_feats: Float['b n rf'] | None = None,
1704+
add_smooth_lddt_loss = False,
17021705
**network_condition_kwargs
17031706
) -> Float[''] | Tuple[Float[''], Float['b m 3']]:
17041707

@@ -1726,6 +1729,22 @@ def forward(
17261729

17271730
loss = losses.mean()
17281731

1732+
if add_smooth_lddt_loss:
1733+
assert exists(additional_residue_feats)
1734+
w = self.net.atoms_per_window
1735+
1736+
is_dna, is_rna = additional_residue_feats[..., 7], additional_residue_feats[..., 8]
1737+
atom_is_dna, atom_is_rna = tuple(repeat(t, 'b n -> b (n w)', w = w) for t in (is_dna, is_rna))
1738+
1739+
smooth_lddt_loss = calc_smooth_lddt_loss(
1740+
denoised,
1741+
normalized_atom_pos,
1742+
atom_is_dna,
1743+
atom_is_rna
1744+
).mean()
1745+
1746+
loss = loss + smooth_lddt_loss
1747+
17291748
if not return_denoised_pos:
17301749
return loss
17311750

@@ -2354,7 +2373,13 @@ def forward(
23542373
# otherwise, noise and make it learn to denoise
23552374

23562375
if exists(atom_pos):
2357-
diffusion_loss, denoised_atom_pos = self.edm(atom_pos, return_denoised_pos = True, **diffusion_cond)
2376+
diffusion_loss, denoised_atom_pos = self.edm(
2377+
atom_pos,
2378+
additional_residue_feats = additional_residue_feats,
2379+
add_smooth_lddt_loss = True,
2380+
return_denoised_pos = True,
2381+
**diffusion_cond
2382+
)
23582383

23592384
# calculate all logits and losses
23602385

tests/test_af3.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,23 @@
1919
Alphafold3,
2020
)
2121

22-
from alphafold3_pytorch.alphafold3 import smoothlddtloss
22+
from alphafold3_pytorch.alphafold3 import (
23+
calc_smooth_lddt_loss
24+
)
2325

24-
def test_smoothlddtloss():
26+
def test_calc_smooth_lddt_loss():
2527
denoised = torch.randn(8, 100, 3)
2628
ground_truth = torch.randn(8, 100, 3)
27-
is_rna_per_atom = torch.randint(0, 2, (8, 100))
28-
is_dna_per_atom = torch.randint(0, 2, (8, 100))
29+
is_rna_per_atom = torch.randint(0, 2, (8, 100)).float()
30+
is_dna_per_atom = torch.randint(0, 2, (8, 100)).float()
2931

30-
loss = smoothlddtloss(
32+
loss = calc_smooth_lddt_loss(
3133
denoised,
3234
ground_truth,
3335
is_rna_per_atom,
3436
is_dna_per_atom
3537
)
36-
38+
3739
assert torch.all(loss <= 1) and torch.all(loss >= 0)
3840

3941
def test_pairformer():

0 commit comments

Comments (0)