Skip to content

Commit 8dd2aad

Browse files
authored
Update alphafold3.py (#176)
1 parent 2331ed5 commit 8dd2aad

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6083,8 +6083,12 @@ def forward(
60836083
is_nucleotide = is_rna | is_dna
60846084
is_polymer = is_protein | is_rna | is_dna
60856085

6086-
is_any_nucleotide_pair = to_pairwise_mask(is_nucleotide)
6087-
is_any_polymer_pair = to_pairwise_mask(is_polymer)
6086+
is_any_nucleotide_pair = einx.logical_and(
6087+
'... i, ... j -> ... i j', torch.ones_like(is_nucleotide), is_nucleotide
6088+
)
6089+
is_any_polymer_pair = einx.logical_and(
6090+
'... i, ... j -> ... i j', torch.ones_like(is_polymer), is_polymer
6091+
)
60886092

60896093
inclusion_radius = torch.where(
60906094
is_any_nucleotide_pair,
@@ -6094,7 +6098,11 @@ def forward(
60946098

60956099
is_token_center_atom = torch.zeros_like(atom_pos[..., 0], dtype=torch.bool)
60966100
is_token_center_atom[torch.arange(batch_size).unsqueeze(1), molecule_atom_indices] = True
6097-
is_any_token_center_atom_pair = to_pairwise_mask(is_token_center_atom)
6101+
is_any_token_center_atom_pair = einx.logical_and(
6102+
'... i, ... j -> ... i j',
6103+
torch.ones_like(is_token_center_atom),
6104+
is_token_center_atom,
6105+
)
60986106

60996107
# compute masks, avoiding self term
61006108

0 commit comments

Comments
 (0)