Skip to content

Commit b0ac11d

Browse files
authored
Merge pull request #6 from joseph-c-kim/rpe
Relative Position Encoding (Algorithm 3)
2 parents 343f6fa + 9339326 commit b0ac11d

File tree

3 files changed

+110
-3
lines changed

3 files changed

+110
-3
lines changed

alphafold3_pytorch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
)
55

66
from alphafold3_pytorch.alphafold3 import (
7+
RelativePositionEncoding,
78
TemplateEmbedder,
89
PreLayerNorm,
910
AdaptiveLayerNorm,
@@ -28,6 +29,7 @@
2829
__all__ = [
2930
Attention,
3031
Attend,
32+
RelativePositionEncoding,
3133
TemplateEmbedder,
3234
PreLayerNorm,
3335
AdaptiveLayerNorm,

alphafold3_pytorch/alphafold3.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
ds - feature dimension (single)
1212
dp - feature dimension (pairwise)
1313
dap - feature dimension (atompair)
14-
da - feature dimensino (atom)
14+
da - feature dimension (atom)
1515
t - templates
16-
m - msa
16+
s - msa
1717
"""
1818

1919
from __future__ import annotations
@@ -822,6 +822,101 @@ def forward(
822822

823823
# embedding related
824824

825+
"""
826+
additional_residue_feats: [*, rf]:
827+
0: residue_index
828+
1: token_index
829+
2: asym_id
830+
3: entity_id
831+
4: sym_id
832+
5: restype (must be one hot encoded to 32)
833+
6: is_protein
834+
7: is_rna
835+
8: is_dna
836+
9: is_ligand
837+
"""
838+
839+
class RelativePositionEncoding(Module):
    """ Algorithm 3 - Relative Position Encoding

    Builds pairwise relative-position features (residue offset, token offset,
    same-entity flag, chain offset) from the per-token residue metadata and
    projects them to the pairwise feature dimension.
    """

    def __init__(
        self,
        r_max = 32,    # clip radius for residue / token index offsets
        s_max = 2,     # clip radius for chain (sym_id) offsets
        out_dim = 128  # pairwise feature dimension (dp)
    ):
        super().__init__()
        self.r_max = r_max
        self.s_max = s_max

        # per offset type: 2*max+1 in-range bins plus one "different chain /
        # residue" overflow bin; plus 1 for the same-entity indicator
        input_dim = (2 * r_max + 2) + (2 * r_max + 2) + 1 + (2 * s_max + 2)
        self.out_embedder = LinearNoBias(input_dim, out_dim)

    @typecheck
    def forward(
        self,
        *,
        additional_residue_feats: Float['b n rf']
    ) -> Float['b n n dp']:

        device = additional_residue_feats.device

        # columns of additional_residue_feats - see module comment above
        res_idx   = additional_residue_feats[..., 0]
        token_idx = additional_residue_feats[..., 1]
        asym_id   = additional_residue_feats[..., 2]
        entity_id = additional_residue_feats[..., 3]
        sym_id    = additional_residue_feats[..., 4]

        # pairwise (i - j) offsets, shape (b, n, n)
        diff_res_idx = rearrange(res_idx, 'b n -> b n 1') \
            - rearrange(res_idx, 'b n -> b 1 n')
        diff_token_idx = rearrange(token_idx, 'b n -> b n 1') \
            - rearrange(token_idx, 'b n -> b 1 n')
        diff_sym_id = rearrange(sym_id, 'b n -> b n 1') \
            - rearrange(sym_id, 'b n -> b 1 n')

        mask_same_chain = (rearrange(asym_id, 'b n -> b n 1') \
            - rearrange(asym_id, 'b n -> b 1 n')) == 0
        mask_same_res = diff_res_idx == 0
        mask_same_entity = ((rearrange(entity_id, 'b n -> b n 1') \
            - rearrange(entity_id, 'b n -> b 1 n')) == 0).unsqueeze(-1)

        # clip offsets into [0, 2*max]; pairs failing the condition get the
        # dedicated overflow bin 2*max + 1
        d_res = torch.where(
            mask_same_chain,
            torch.clip(diff_res_idx + self.r_max, 0, 2 * self.r_max),
            2 * self.r_max + 1
        )
        d_token = torch.where(
            mask_same_chain * mask_same_res,
            torch.clip(diff_token_idx + self.r_max, 0, 2 * self.r_max),
            2 * self.r_max + 1
        )
        d_chain = torch.where(
            ~mask_same_chain,
            torch.clip(diff_sym_id + self.s_max, 0, 2 * self.s_max),
            2 * self.s_max + 1
        )

        def onehot(x, bins):
            # nearest-bin one-hot. allocate on x's device so CUDA inputs work
            # (previously torch.zeros defaulted to CPU and crashed on GPU)
            dist = (x.view(-1, 1) - bins.view(1, -1)).abs()
            indexes = dist.argmin(dim = 1, keepdim = True)  # already int64
            one_hots = torch.zeros(indexes.shape[0], len(bins), device = x.device).scatter_(1, indexes, 1.)
            return rearrange(one_hots, '(b n k) d -> b n k d', n = x.shape[1], k = x.shape[2])

        a_rel_pos = onehot(d_res, torch.arange(2 * self.r_max + 2, device = device))
        a_rel_token = onehot(d_token, torch.arange(2 * self.r_max + 2, device = device))
        a_rel_chain = onehot(d_chain, torch.arange(2 * self.s_max + 2, device = device))

        p = self.out_embedder(
            torch.cat([
                a_rel_pos,
                a_rel_token,
                # cast explicitly rather than relying on cat's bool->float promotion
                mask_same_entity.type(a_rel_pos.dtype),
                a_rel_chain
            ], dim = -1)
        )

        return p
917+
918+
919+
825920
class TemplateEmbedder(Module):
826921
""" Algorithm 16 """
827922

tests/test_readme.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010
DiffusionTransformer,
1111
DiffusionModule,
1212
ElucidatedAtomDiffusion,
13+
RelativePositionEncoding,
1314
TemplateEmbedder,
1415
Attention,
1516
InputFeatureEmbedder,
1617
ConfidenceHead,
1718
DistogramHead,
18-
Alphafold3
19+
Alphafold3,
1920
)
2021

2122
def test_pairformer():
@@ -162,6 +163,15 @@ def test_diffusion_module():
162163
)
163164

164165
assert sampled_atom_pos.shape == noised_atom_pos.shape
166+
167+
def test_relative_position_encoding():
    # batch of 8 sequences, 100 tokens, 10 additional residue features each
    additional_residue_feats = torch.randn(8, 100, 10)

    embedder = RelativePositionEncoding()

    rpe_embed = embedder(
        additional_residue_feats = additional_residue_feats
    )

    # default out_dim is 128 -> pairwise embedding (b, n, n, dp).
    # previously this test asserted nothing and could never fail
    assert rpe_embed.shape == (8, 100, 100, 128)
165175

166176
def test_template_embed():
167177
template_feats = torch.randn(2, 2, 16, 16, 77)

0 commit comments

Comments
 (0)