We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 48373bd commit 1047480Copy full SHA for 1047480
alphafold3_pytorch/utils/model_utils.py
@@ -636,12 +636,12 @@ def get_indices_three_closest_atom_pos(
636
:param mask: The mask to apply.
637
:return: The indices of the three closest atoms to each atom.
638
"""
639
- prec_dims, device = atom_pos.shape[:-2], atom_pos.device
+ atom_dims, device = atom_pos.shape[-3:-1], atom_pos.device
640
num_atoms, has_batch = atom_pos.shape[-2], atom_pos.ndim == 3
641
batch_size = 1 if not has_batch else atom_pos.shape[0]
642
643
- if not exists(mask) and num_atoms < 3:
644
- return atom_pos.new_full((*prec_dims, 3), -1).long()
+ if num_atoms < 3:
+ return atom_pos.new_full((*atom_dims, 3), -1).long()
645
646
if not has_batch:
647
atom_pos = rearrange(atom_pos, "... -> 1 ...")
0 commit comments