Skip to content

Commit 4233f89

Browse files
authored
Merge pull request #9 from joseph-c-kim/smoothlddt
Implemented SmoothLDDTLoss
2 parents e956fc2 + 956da77 commit 4233f89

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,35 @@ def pack_one(t, pattern):
7878
def unpack_one(t, ps, pattern):
7979
return unpack(t, ps, pattern)[0]
8080

81+
# Loss functions
82+
83+
def smoothlddtloss(
84+
denoised: Float['b m 3'],
85+
ground_truth: Float['b m 3'],
86+
is_rna_per_atom: Float['b m'],
87+
is_dna_per_atom: Float['b m']
88+
) -> Float['b']:
89+
from torch import sigmoid
90+
91+
m = is_rna_per_atom.shape[-1]
92+
93+
dx_denoised = torch.cdist(denoised, denoised)
94+
dx_gt = torch.cdist(ground_truth, ground_truth)
95+
96+
ddx = torch.abs(dx_gt - dx_denoised)
97+
eps = 0.25 * (
98+
sigmoid(0.5 - ddx) + sigmoid(1 - ddx) + sigmoid(2 - ddx) + sigmoid(4 - ddx)
99+
)
100+
101+
is_nuc = is_rna_per_atom + is_dna_per_atom
102+
mask = einx.multiply('b i, b j -> b i j', is_nuc, is_nuc)
103+
c = (dx_gt < 30) * mask + (dx_gt < 15) * (1 - mask)
104+
105+
num = einx.sum('b [...]', c * eps * (1 - torch.eye(m))) / (m**2 - m)
106+
den = einx.sum('b [...]', c * (1 - torch.eye(m))) / (m**2 - m)
107+
108+
return 1 - num/den
109+
81110
# linear and outer sum
82111
# for single repr -> pairwise pattern throughout this architecture
83112

tests/test_af3.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,23 @@
1919
Alphafold3,
2020
)
2121

22+
from alphafold3_pytorch.alphafold3 import smoothlddtloss
23+
24+
def test_smoothlddtloss():
25+
denoised = torch.randn(8, 100, 3)
26+
ground_truth = torch.randn(8, 100, 3)
27+
is_rna_per_atom = torch.randint(0, 2, (8, 100))
28+
is_dna_per_atom = torch.randint(0, 2, (8, 100))
29+
30+
loss = smoothlddtloss(
31+
denoised,
32+
ground_truth,
33+
is_rna_per_atom,
34+
is_dna_per_atom
35+
)
36+
37+
assert torch.all(loss <= 1) and torch.all(loss >= 0)
38+
2239
def test_pairformer():
2340
single = torch.randn(2, 16, 384)
2441
pairwise = torch.randn(2, 16, 16, 128)

0 commit comments

Comments
 (0)