Skip to content

Commit 1047480

Browse files
authored
Update model_utils.py (#198)
1 parent 48373bd commit 1047480

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

alphafold3_pytorch/utils/model_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -636,12 +636,12 @@ def get_indices_three_closest_atom_pos(
636636
:param mask: The mask to apply.
637637
:return: The indices of the three closest atoms to each atom.
638638
"""
639-
prec_dims, device = atom_pos.shape[:-2], atom_pos.device
639+
atom_dims, device = atom_pos.shape[-3:-1], atom_pos.device
640640
num_atoms, has_batch = atom_pos.shape[-2], atom_pos.ndim == 3
641641
batch_size = 1 if not has_batch else atom_pos.shape[0]
642642

643-
if not exists(mask) and num_atoms < 3:
644-
return atom_pos.new_full((*prec_dims, 3), -1).long()
643+
if num_atoms < 3:
644+
return atom_pos.new_full((*atom_dims, 3), -1).long()
645645

646646
if not has_batch:
647647
atom_pos = rearrange(atom_pos, "... -> 1 ...")

0 commit comments

Comments
 (0)