@@ -139,6 +139,11 @@ def pad_to_len(t, length, value = 0, dim = -1):
139139 zeros = (0 , 0 ) * (- dim - 1 )
140140 return F .pad (t , (* zeros , 0 , max (0 , length - t .shape [dim ])), value = value )
141141
142+ def offset_only_positive (t , offset ):
143+ is_positive = t >= 0
144+ t_offsetted = t + offset
145+ return torch .where (is_positive , t_offsetted , t )
146+
142147def compose (* fns : Callable ):
143148 # for chaining from Alphafold3Input -> MoleculeInput -> AtomInput
144149
@@ -871,9 +876,7 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
871876 additional_token_feats = repeat_interleave (i .additional_token_feats , token_repeats , dim = 0 )
872877 molecule_ids = repeat_interleave (i .molecule_ids , token_repeats )
873878
874- atom_indices_offsets = exclusive_cumsum (atoms_per_molecule )
875- distogram_atom_indices = i .distogram_atom_indices + atom_indices_offsets
876- molecule_atom_indices = i .molecule_atom_indices + atom_indices_offsets
879+ atom_indices_offsets = repeat_interleave (exclusive_cumsum (atoms_per_molecule ), token_repeats , dim = 0 )
877880
878881 distogram_atom_indices = repeat_interleave (i .distogram_atom_indices , token_repeats )
879882 molecule_atom_indices = repeat_interleave (i .molecule_atom_indices , token_repeats )
@@ -1018,10 +1021,6 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
10181021 atom_indices_for_frame = [default (indices , (- 1 , - 1 , - 1 )) for indices in i .atom_indices_for_frame ]
10191022 atom_indices_for_frame = tensor (atom_indices_for_frame )
10201023
1021- atom_indices_for_frame = atom_indices_for_frame + atom_indices_offsets [..., None ]
1022- valid_atom_indices_for_frame = (atom_indices_for_frame >= 0 ).all (dim = - 1 )
1023-
1024- atom_indices_for_frame = einx .where ('n, n c, -> n c' , valid_atom_indices_for_frame , atom_indices_for_frame , - 1 )
10251024 atom_indices_for_frame = repeat_interleave (atom_indices_for_frame , token_repeats , dim = 0 )
10261025
10271026 # handle maybe atompair embeds
@@ -1155,8 +1154,19 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
11551154 "n missing, n -> n missing" , missing_token_indices , distogram_atom_indices
11561155 ).any (dim = - 1 )
11571156
1157+ is_missing_atom_indices_for_frame = einx .equal (
1158+ "n missing, n c -> n c missing" , missing_token_indices , atom_indices_for_frame
1159+ ).any (dim = (- 1 , - 2 ))
1160+
11581161 molecule_atom_indices = molecule_atom_indices .masked_fill (is_missing_molecule_atom , - 1 )
11591162 distogram_atom_indices = distogram_atom_indices .masked_fill (is_missing_distogram_atom , - 1 )
1163+ atom_indices_for_frame = atom_indices_for_frame .masked_fill (is_missing_atom_indices_for_frame [..., None ], - 1 )
1164+
1165+ # offsets for all indices
1166+
1167+ distogram_atom_indices = offset_only_positive (distogram_atom_indices , atom_indices_offsets )
1168+ molecule_atom_indices = offset_only_positive (molecule_atom_indices , atom_indices_offsets )
1169+ atom_indices_for_frame = offset_only_positive (atom_indices_for_frame , atom_indices_offsets [..., None ])
11601170
11611171 # handle atom positions
11621172
0 commit comments