Skip to content

Commit 1cde46f

Browse files
committed
further clarity with einx.get_at
1 parent 2b826e6 commit 1cde46f

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3360,9 +3360,8 @@ def compute_full_complex_metric(
33603360
indices = repeat_consecutive_with_lens(indices, molecule_atom_lens)
33613361
valid_indices = repeat_consecutive_with_lens(valid_indices, molecule_atom_lens)
33623362

3363-
expand_indices = indices.unsqueeze(-1).expand(-1, -1, is_molecule_types.shape[-1])
33643363
# broadcast is_molecule_types to atom
3365-
atom_is_molecule_types = torch.gather(is_molecule_types, 1, expand_indices) * valid_indices[..., None]
3364+
atom_is_molecule_types = einx.get_at('b [n] is_type, b m -> b m is_type', is_molecule_types, indices) * valid_indices[..., None]
33663365

33673366
confidence_score = self.compute_confidence_score(
33683367
confidence_head_logits, asym_id, has_frame, multimer_mode=True

0 commit comments

Comments
 (0)