Skip to content

Commit e2e8e46

Browse files
committed
add distogram loss off pairwise trunk for starters
1 parent 1936f1e commit e2e8e46

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2011,6 +2011,12 @@ def __init__(
20112011
self.loss_confidence_weight = loss_confidence_weight
20122012
self.loss_diffusion_weight = loss_diffusion_weight
20132013

2014+
self.register_buffer('dummy', torch.tensor(0), persistent = False)
2015+
2016+
@property
2017+
def device(self):
2018+
return self.dummy.device
2019+
20142020
@typecheck
20152021
def forward(
20162022
self,
@@ -2048,6 +2054,7 @@ def forward(
20482054
w = self.atoms_per_window
20492055

20502056
mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')
2057+
pairwise_mask = einx.logical_and('b i, b j -> b i j', mask, mask)
20512058

20522059
# init recycled single and pairwise
20532060

@@ -2107,6 +2114,22 @@ def forward(
21072114
# otherwise will sample the atomic coordinates
21082115

21092116
labels = (distance_labels, pae_labels, pde_labels, plddt_labels, resolved_labels)
2110-
return_loss = any([*filter(exists, labels)])
2117+
return_loss = any([*map(exists, labels)])
2118+
2119+
if not return_loss:
2120+
return torch.randn((*atom_inputs.shape[:2], 3), device = self.device)
2121+
2122+
# calculate all logits and losses
2123+
2124+
ignore = self.ignore_index
2125+
2126+
if exists(distance_labels):
2127+
distance_labels = torch.where(pairwise_mask, distance_labels, ignore)
2128+
distogram_logits = self.distogram_head(pairwise)
2129+
distogram_loss = F.cross_entropy(distogram_logits, distance_labels, ignore_index = ignore)
2130+
2131+
loss = (
2132+
distogram_loss * self.loss_distogram_weight
2133+
)
21112134

2112-
return torch.tensor(0.)
2135+
return loss

tests/test_readme.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,10 +246,13 @@ def test_alphafold3():
246246

247247
msa = torch.randn(2, 7, seq_len, 64)
248248

249+
distance_labels = torch.randint(0, 38, (2, seq_len, seq_len))
250+
249251
alphafold3 = Alphafold3(
250252
dim_atom_inputs = 77,
251253
dim_additional_residue_feats = 33,
252-
dim_template_feats = 44
254+
dim_template_feats = 44,
255+
num_dist_bins = 38
253256
)
254257

255258
loss = alphafold3(
@@ -260,7 +263,8 @@ def test_alphafold3():
260263
additional_residue_feats = additional_residue_feats,
261264
msa = msa,
262265
templates = template_feats,
263-
template_mask = template_mask
266+
template_mask = template_mask,
267+
distance_labels = distance_labels
264268
)
265269

266-
print(loss)
270+
loss.backward()

0 commit comments

Comments
 (0)