@@ -5330,6 +5330,7 @@ def forward(
53305330 valid_distogram_mask = distogram_atom_indices >= 0 & valid_atom_len_mask
53315331 distogram_atom_indices = distogram_atom_indices .masked_fill (~ valid_distogram_mask , 0 )
53325332
5333+ valid_atom_indices_for_frame = None
53335334 if exists (atom_indices_for_frame ):
53345335 valid_atom_indices_for_frame = (atom_indices_for_frame >= 0 ).all (dim = - 1 ) & valid_atom_len_mask
53355336 atom_indices_for_frame = einx .where ('b n, b n three, -> b n three' , valid_atom_indices_for_frame , atom_indices_for_frame , 0 )
@@ -5679,6 +5680,7 @@ def forward(
56795680 molecule_atom_indices ,
56805681 molecule_pos ,
56815682 distogram_atom_indices ,
5683+ valid_atom_indices_for_frame ,
56825684 atom_indices_for_frame ,
56835685 molecule_atom_lens ,
56845686 pde_labels ,
@@ -5705,6 +5707,7 @@ def forward(
57055707 molecule_atom_indices ,
57065708 molecule_pos ,
57075709 distogram_atom_indices ,
5710+ valid_atom_indices_for_frame ,
57085711 atom_indices_for_frame ,
57095712 molecule_atom_lens ,
57105713 pde_labels ,
@@ -5776,20 +5779,38 @@ def forward(
57765779 frames , _ = self .rigid_from_three_points (three_atoms )
57775780 pred_frames , _ = self .rigid_from_three_points (pred_three_atoms )
57785781
5782+ # determine mask
5783+ # must be residue or nucleotide with greater than 0 atoms
5784+
5785+ align_error_mask = (
5786+ is_molecule_types [..., IS_BIOMOLECULE_INDICES ].any (dim = - 1 ) &
5787+ valid_atom_indices_for_frame
5788+ )
5789+
5790+ if ch_atom_res :
5791+ align_error_mask = batch_repeat_interleave (align_error_mask , molecule_atom_lens )
5792+
57795793 # align error
57805794
57815795 align_error = self .compute_alignment_error (
57825796 denoised_atom_pos if ch_atom_res else denoised_molecule_pos ,
57835797 atom_pos if ch_atom_res else molecule_pos ,
57845798 pred_frames ,
57855799 frames ,
5800+ mask = align_error_mask ,
57865801 molecule_atom_lens = molecule_atom_lens
57875802 )
57885803
57895804 # calculate pae labels as alignment error binned to 64 (0 - 32A)
57905805
57915806 pae_labels = distance_to_bins (align_error , self .pae_bins )
57925807
5808+ # set ignore index for invalid molecules or frames (todo: figure out what is meant by invalid frame)
5809+
5810+ pair_align_error_mask = to_pairwise_mask (align_error_mask )
5811+
5812+ pae_labels = einx .where ('b i j, b i j, -> b i j' , pair_align_error_mask , pae_labels , ignore )
5813+
57935814 # confidence head
57945815
57955816 should_call_confidence_head = any ([* map (exists , confidence_head_labels )])
0 commit comments