@@ -208,24 +208,42 @@ def hard_validate_atom_indices_ascending(
208208@typecheck
209209def get_indices_three_closest_atom_pos (
210210 atom_pos : Float ['... n d' ],
211+ mask : Bool ['... n' ] | None = None
211212) -> Int ['... 3' ]:
212213
213214 prec_dims , device = atom_pos .shape [:- 2 ], atom_pos .device
214215 num_atoms , has_batch = atom_pos .shape [- 2 ], atom_pos .ndim == 3
215216
216- if num_atoms < 3 :
217+ if not exists ( mask ) and num_atoms < 3 :
217218 return atom_pos .new_full ((* prec_dims , 3 ), - 1 ).long ()
218219
219220 if not has_batch :
220221 atom_pos = rearrange (atom_pos , '... -> 1 ...' )
221222
223+ if exists (mask ):
224+ mask = rearrange (mask , '... -> 1 ...' )
225+
226+ # figure out which set of atoms are less than 3 for masking out later
227+
228+ if exists (mask ):
229+ insufficient_atom_mask = mask .sum (dim = - 1 ) < 3
230+
231+ # get distances between all atoms
232+
222233 atom_dist = torch .cdist (atom_pos , atom_pos )
223234
224235 # mask out the distance to self
225236
226237 eye = torch .eye (num_atoms , device = device , dtype = torch .bool )
227238
228- atom_dist .masked_fill_ (eye , 1e4 )
239+ mask_value = 1e4
240+ atom_dist .masked_fill_ (eye , mask_value )
241+
242+ # take care of padding
243+
244+ if exists (mask ):
245+ pair_mask = einx .logical_and ('... i, ... j -> ... i j' , mask , mask )
246+ atom_dist .masked_fill_ (~ pair_mask , mask_value )
229247
230248 # will use topk on the negative of the distance
231249
@@ -245,6 +263,11 @@ def get_indices_three_closest_atom_pos(
245263 best_two_atom_neighbors [..., 1 ],
246264 ), 'b *' )
247265
266+ # mask out
267+
268+ if exists (mask ):
269+ three_atom_indices = einx .where ('..., ... three, -> ... three' , ~ insufficient_atom_mask , three_atom_indices , - 1 )
270+
248271 if not has_batch :
249272 three_atom_indices = rearrange (three_atom_indices , '1 ... -> ...' )
250273
@@ -261,11 +284,12 @@ def get_angle_between_edges(
261284@typecheck
262285def get_frames_from_atom_pos (
263286 atom_pos : Float ['... n d' ],
287+ mask : Bool ['... n' ] | None = None ,
264288 filter_colinear_pos : bool = False ,
265289 is_colinear_angle_thres : float = 25. # they use 25 degrees as a way of filtering out invalid frames
266290) -> Int ['... 3' ]:
267291
268- frames = get_indices_three_closest_atom_pos (atom_pos )
292+ frames = get_indices_three_closest_atom_pos (atom_pos , mask = mask )
269293
270294 if not filter_colinear_pos :
271295 return frames
0 commit comments