Skip to content

Commit 3875b27

Browse files
committed
first prepare for missing representative atom for the distogram
1 parent 3a7d62e commit 3875b27

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3375,12 +3375,14 @@ def forward(
33753375
molecule_atom_lens = molecule_atom_lens.masked_fill(~valid_atom_len_mask, 0)
33763376

33773377
if exists(molecule_atom_indices):
3378-
molecule_atom_indices = molecule_atom_indices.masked_fill(~valid_atom_len_mask, 0)
3379-
assert (molecule_atom_indices < molecule_atom_lens)[valid_atom_len_mask].all(), 'molecule_atom_indices cannot have an index that exceeds the length of the atoms for that molecule as given by molecule_atom_lens'
3378+
valid_molecule_atom_mask = molecule_atom_indices >= 0 & valid_atom_len_mask
3379+
molecule_atom_indices = molecule_atom_indices.masked_fill(~valid_molecule_atom_mask, 0)
3380+
assert (molecule_atom_indices < molecule_atom_lens)[valid_molecule_atom_mask].all(), 'molecule_atom_indices cannot have an index that exceeds the length of the atoms for that molecule as given by molecule_atom_lens'
33803381

33813382
if exists(distogram_atom_indices):
3382-
distogram_atom_indices = distogram_atom_indices.masked_fill(~valid_atom_len_mask, 0)
3383-
assert (distogram_atom_indices < molecule_atom_lens)[valid_atom_len_mask].all(), 'distogram_atom_indices cannot have an index that exceeds the length of the atoms for that molecule as given by molecule_atom_lens'
3383+
valid_distogram_mask = distogram_atom_indices >= 0 & valid_atom_len_mask
3384+
distogram_atom_indices = distogram_atom_indices.masked_fill(~valid_distogram_mask, 0)
3385+
assert (distogram_atom_indices < molecule_atom_lens)[valid_distogram_mask].all(), 'distogram_atom_indices cannot have an index that exceeds the length of the atoms for that molecule as given by molecule_atom_lens'
33843386

33853387
assert exists(molecule_atom_lens) or exists(atom_mask)
33863388

@@ -3629,6 +3631,11 @@ def forward(
36293631
dist_from_dist_bins = einx.subtract('b m dist, dist_bins -> b m dist dist_bins', molecule_dist, self.distance_bins).abs()
36303632
distance_labels = dist_from_dist_bins.argmin(dim = -1)
36313633

3634+
# account for representative distogram atom missing from residue (-1 set on distogram_atom_indices field)
3635+
3636+
valid_distogram_mask = einx.logical_and('b i, b j -> b i j', valid_distogram_mask, valid_distogram_mask)
3637+
distance_labels.masked_fill_(~valid_distogram_mask, ignore)
3638+
36323639
if exists(distance_labels):
36333640
distance_labels = torch.where(pairwise_mask, distance_labels, ignore)
36343641
distogram_logits = self.distogram_head(pairwise)
@@ -3660,7 +3667,6 @@ def forward(
36603667
relative_position_encoding,
36613668
additional_molecule_feats,
36623669
is_molecule_types,
3663-
distogram_atom_indices,
36643670
molecule_atom_indices,
36653671
molecule_atom_lens,
36663672
pae_labels,
@@ -3684,7 +3690,6 @@ def forward(
36843690
relative_position_encoding,
36853691
additional_molecule_feats,
36863692
is_molecule_types,
3687-
distogram_atom_indices,
36883693
molecule_atom_indices,
36893694
molecule_atom_lens,
36903695
pae_labels,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.126"
3+
version = "0.1.127"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)