Skip to content

Commit db0edca

Browse files
committed
integrate the relative positional encoding into the pairwise init
1 parent b0ac11d commit db0edca

File tree

2 files changed

+67
-42
lines changed

2 files changed

+67
-42
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ Implementation of <a href="https://www.nature.com/articles/s41586-024-07487-w">A
66

77
Getting a fair number of emails. You can chat with me about this work <a href="https://discord.gg/x6FuzQPQXY">here</a>
88

9+
## Appreciation
10+
11+
- <a href="https://github.com/joseph-c-kim">Joseph</a> for contributing the relative positional encoding module!
12+
913
## Install
1014

1115
```bash

alphafold3_pytorch/alphafold3.py

Lines changed: 63 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -841,81 +841,79 @@ class RelativePositionEncoding(Module):
841841

842842
def __init__(
843843
self,
844+
*,
844845
r_max = 32,
845846
s_max = 2,
846-
out_dim = 128
847+
dim_out = 128
847848
):
848849
super().__init__()
849850
self.r_max = r_max
850851
self.s_max = s_max
851852

852-
input_dim = (2*r_max+2) + (2*r_max+2) + 1 + (2*s_max+2)
853-
self.out_embedder = LinearNoBias(input_dim, out_dim)
854-
853+
dim_input = (2*r_max+2) + (2*r_max+2) + 1 + (2*s_max+2)
854+
self.out_embedder = LinearNoBias(dim_input, dim_out)
855+
855856
@typecheck
856857
def forward(
857858
self,
858859
*,
859860
additional_residue_feats: Float['b n rf']
860861
) -> Float['b n n dp']:
862+
863+
device = additional_residue_feats.device
864+
assert additional_residue_feats.shape[-1] >= 5
865+
866+
res_idx, token_idx, asym_id, entity_id, sym_id = additional_residue_feats[..., :5].unbind(dim = -1)
861867

862-
res_idx = additional_residue_feats[..., 0]
863-
token_idx = additional_residue_feats[..., 1]
864-
asym_id = additional_residue_feats[..., 2]
865-
entity_id = additional_residue_feats[..., 3]
866-
sym_id = additional_residue_feats[..., 4]
867-
868-
diff_res_idx = rearrange(res_idx, 'b n -> b n 1') \
869-
- rearrange(res_idx, 'b n -> b 1 n')
870-
diff_token_idx = rearrange(token_idx, 'b n -> b n 1') \
871-
- rearrange(token_idx, 'b n -> b 1 n')
872-
diff_sym_id = rearrange(sym_id, 'b n -> b n 1') \
873-
- rearrange(sym_id, 'b n -> b 1 n')
874-
mask_same_chain = rearrange(asym_id, 'b n -> b n 1') \
875-
- rearrange(asym_id, 'b n -> b 1 n') == 0
868+
diff_res_idx = einx.subtract('b i, b j -> b i j', res_idx, res_idx)
869+
diff_token_idx = einx.subtract('b i, b j -> b i j', token_idx, token_idx)
870+
diff_sym_id = einx.subtract('b i, b j -> b i j', sym_id, sym_id)
871+
872+
mask_same_chain = einx.subtract('b i, b j -> b i j', asym_id, asym_id) == 0
876873
mask_same_res = diff_res_idx == 0
877-
mask_same_entity = (rearrange(entity_id, 'b n -> b n 1') \
878-
- rearrange(entity_id, 'b n -> b 1 n') == 0).unsqueeze(-1)
874+
mask_same_entity = einx.subtract('b i, b j -> b i j 1', entity_id, entity_id) == 0
879875

880876
d_res = torch.where(
881877
mask_same_chain,
882878
torch.clip(diff_res_idx + self.r_max, 0, 2*self.r_max),
883879
2*self.r_max + 1
884880
)
881+
885882
d_token = torch.where(
886883
mask_same_chain * mask_same_res,
887884
torch.clip(diff_token_idx + self.r_max, 0, 2*self.r_max),
888885
2*self.r_max + 1
889886
)
887+
890888
d_chain = torch.where(
891889
~mask_same_chain,
892890
torch.clip(diff_sym_id + self.s_max, 0, 2*self.s_max),
893891
2*self.s_max + 1
894892
)
895893

896894
def onehot(x, bins):
897-
_, indexes = (x.view(-1, 1) - bins.view(1, -1)).abs().min(dim=1)
898-
indexes = indexes.type(torch.int64).view(-1, 1)
895+
x, packed_shape = pack_one(x, '*')
896+
dist_from_bins = einx.subtract('i, j -> i j', x, bins)
897+
indexes = dist_from_bins.abs().min(dim = 1, keepdim = True).indices
898+
indexes = rearrange(indexes.long(), 'i j -> (i j) 1')
899899
one_hots = torch.zeros(indexes.shape[0], len(bins)).scatter_(1, indexes, 1)
900-
out = rearrange(one_hots, '(b n k) d -> b n k d', n=x.shape[1], k=x.shape[2])
901-
return out
902-
903-
a_rel_pos = onehot(d_res, torch.arange(2*self.r_max + 2))
904-
a_rel_token = onehot(d_token, torch.arange(2*self.r_max + 2))
905-
a_rel_chain = onehot(d_chain, torch.arange(2*self.s_max + 2))
906-
907-
p = self.out_embedder(
908-
torch.cat([
909-
a_rel_pos,
910-
a_rel_token,
911-
mask_same_entity,
912-
a_rel_chain
913-
], dim=-1)
914-
)
915-
916-
return p
917-
918-
900+
return unpack_one(one_hots, packed_shape, '* d')
901+
902+
r_arange = torch.arange(2*self.r_max + 2, device = device)
903+
s_arange = torch.arange(2*self.s_max + 2, device = device)
904+
905+
a_rel_pos = onehot(d_res, r_arange)
906+
a_rel_token = onehot(d_token, r_arange)
907+
a_rel_chain = onehot(d_chain, s_arange)
908+
909+
out, _ = pack((
910+
a_rel_pos,
911+
a_rel_token,
912+
mask_same_entity,
913+
a_rel_chain
914+
), 'b i j *')
915+
916+
return self.out_embedder(out)
919917

920918
class TemplateEmbedder(Module):
921919
""" Algorithm 16 """
@@ -2021,6 +2019,10 @@ def __init__(
20212019
pair_bias_attn_heads = 16,
20222020
dropout_row_prob = 0.25,
20232021
pairwise_block_kwargs = dict()
2022+
),
2023+
relative_position_encoding_kwargs: dict = dict(
2024+
r_max = 32,
2025+
s_max = 2,
20242026
)
20252027
):
20262028
super().__init__()
@@ -2043,6 +2045,15 @@ def __init__(
20432045

20442046
dim_single_inputs = dim_input_embedder_token + dim_additional_residue_feats
20452047

2048+
# relative positional encoding
2049+
# used by pairwise in main alphafold2 trunk
2050+
# and also in the diffusion module separately from alphafold3
2051+
2052+
self.relative_position_encoding = RelativePositionEncoding(
2053+
dim_out = dim_pairwise,
2054+
**relative_position_encoding_kwargs
2055+
)
2056+
20462057
# templates
20472058

20482059
self.template_embedder = TemplateEmbedder(
@@ -2130,6 +2141,8 @@ def forward(
21302141
resolved_labels: Int['b n'] | None = None,
21312142
) -> Float['b m 3'] | Float['']:
21322143

2144+
w = self.atoms_per_window
2145+
21332146
# embed inputs
21342147

21352148
(
@@ -2145,7 +2158,15 @@ def forward(
21452158
additional_residue_feats = additional_residue_feats
21462159
)
21472160

2148-
w = self.atoms_per_window
2161+
# relative positional encoding
2162+
2163+
relative_position_encoding = self.relative_position_encoding(
2164+
additional_residue_feats = additional_residue_feats
2165+
)
2166+
2167+
pairwise_init = pairwise_init + relative_position_encoding
2168+
2169+
# pairwise mask
21492170

21502171
mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')
21512172
pairwise_mask = einx.logical_and('b i, b j -> b i j', mask, mask)

0 commit comments

Comments (0)