Skip to content

Commit 82be7d5

Browse files
committed
only learn pae on biomolecules and valid atom indices
1 parent 14e3ae7 commit 82be7d5

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5330,6 +5330,7 @@ def forward(
53305330
valid_distogram_mask = distogram_atom_indices >= 0 & valid_atom_len_mask
53315331
distogram_atom_indices = distogram_atom_indices.masked_fill(~valid_distogram_mask, 0)
53325332

5333+
valid_atom_indices_for_frame = None
53335334
if exists(atom_indices_for_frame):
53345335
valid_atom_indices_for_frame = (atom_indices_for_frame >= 0).all(dim = -1) & valid_atom_len_mask
53355336
atom_indices_for_frame = einx.where('b n, b n three, -> b n three', valid_atom_indices_for_frame, atom_indices_for_frame, 0)
@@ -5679,6 +5680,7 @@ def forward(
56795680
molecule_atom_indices,
56805681
molecule_pos,
56815682
distogram_atom_indices,
5683+
valid_atom_indices_for_frame,
56825684
atom_indices_for_frame,
56835685
molecule_atom_lens,
56845686
pde_labels,
@@ -5705,6 +5707,7 @@ def forward(
57055707
molecule_atom_indices,
57065708
molecule_pos,
57075709
distogram_atom_indices,
5710+
valid_atom_indices_for_frame,
57085711
atom_indices_for_frame,
57095712
molecule_atom_lens,
57105713
pde_labels,
@@ -5776,20 +5779,38 @@ def forward(
57765779
frames, _ = self.rigid_from_three_points(three_atoms)
57775780
pred_frames, _ = self.rigid_from_three_points(pred_three_atoms)
57785781

5782+
# determine mask
5783+
# must be residue or nucleotide with greater than 0 atoms
5784+
5785+
align_error_mask = (
5786+
is_molecule_types[..., IS_BIOMOLECULE_INDICES].any(dim = -1) &
5787+
valid_atom_indices_for_frame
5788+
)
5789+
5790+
if ch_atom_res:
5791+
align_error_mask = batch_repeat_interleave(align_error_mask, molecule_atom_lens)
5792+
57795793
# align error
57805794

57815795
align_error = self.compute_alignment_error(
57825796
denoised_atom_pos if ch_atom_res else denoised_molecule_pos,
57835797
atom_pos if ch_atom_res else molecule_pos,
57845798
pred_frames,
57855799
frames,
5800+
mask = align_error_mask,
57865801
molecule_atom_lens = molecule_atom_lens
57875802
)
57885803

57895804
# calculate pae labels as alignment error binned to 64 (0 - 32A)
57905805

57915806
pae_labels = distance_to_bins(align_error, self.pae_bins)
57925807

5808+
# set ignore index for invalid molecules or frames (todo: figure out what is meant by invalid frame)
5809+
5810+
pair_align_error_mask = to_pairwise_mask(align_error_mask)
5811+
5812+
pae_labels = einx.where('b i j, b i j, -> b i j', pair_align_error_mask, pae_labels, ignore)
5813+
57935814
# confidence head
57945815

57955816
should_call_confidence_head = any([*map(exists, confidence_head_labels)])

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.3.0"
3+
version = "0.3.1"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)