|
11 | 11 | ds - feature dimension (single) |
12 | 12 | dp - feature dimension (pairwise) |
13 | 13 | dap - feature dimension (atompair) |
14 | | -da - feature dimensino (atom) |
| 14 | +da - feature dimension (atom) |
15 | 15 | t - templates |
16 | | -m - msa |
| 16 | +s - msa |
17 | 17 | """ |
18 | 18 |
|
19 | 19 | from __future__ import annotations |
@@ -822,6 +822,101 @@ def forward( |
822 | 822 |
|
823 | 823 | # embedding related |
824 | 824 |
|
| 825 | +""" |
| 826 | +additional_residue_feats: [*, rf]: |
| 827 | + 0: residue_index |
| 828 | + 1: token_index |
| 829 | + 2: asym_id |
| 830 | + 3: entity_id |
| 831 | + 4: sym_id |
| 832 | + 5: restype (must be one hot encoded to 32) |
| 833 | + 6: is_protein |
| 834 | + 7: is_rna |
| 835 | + 8: is_dna |
| 836 | + 9: is_ligand |
| 837 | +""" |
| 838 | + |
| 839 | +class RelativePositionEncoding(Module): |
| 840 | + """ Algorithm 3 """ |
| 841 | + |
| 842 | + def __init__( |
| 843 | + self, |
| 844 | + r_max = 32, |
| 845 | + s_max = 2, |
| 846 | + out_dim = 128 |
| 847 | + ): |
| 848 | + super().__init__() |
| 849 | + self.r_max = r_max |
| 850 | + self.s_max = s_max |
| 851 | + |
| 852 | + input_dim = (2*r_max+2) + (2*r_max+2) + 1 + (2*s_max+2) |
| 853 | + self.out_embedder = LinearNoBias(input_dim, out_dim) |
| 854 | + |
| 855 | + @typecheck |
| 856 | + def forward( |
| 857 | + self, |
| 858 | + *, |
| 859 | + additional_residue_feats: Float['b n rf'] |
| 860 | + ) -> Float['b n n dp']: |
| 861 | + |
| 862 | + res_idx = additional_residue_feats[..., 0] |
| 863 | + token_idx = additional_residue_feats[..., 1] |
| 864 | + asym_id = additional_residue_feats[..., 2] |
| 865 | + entity_id = additional_residue_feats[..., 3] |
| 866 | + sym_id = additional_residue_feats[..., 4] |
| 867 | + |
| 868 | + diff_res_idx = rearrange(res_idx, 'b n -> b n 1') \ |
| 869 | + - rearrange(res_idx, 'b n -> b 1 n') |
| 870 | + diff_token_idx = rearrange(token_idx, 'b n -> b n 1') \ |
| 871 | + - rearrange(token_idx, 'b n -> b 1 n') |
| 872 | + diff_sym_id = rearrange(sym_id, 'b n -> b n 1') \ |
| 873 | + - rearrange(sym_id, 'b n -> b 1 n') |
| 874 | + mask_same_chain = rearrange(asym_id, 'b n -> b n 1') \ |
| 875 | + - rearrange(asym_id, 'b n -> b 1 n') == 0 |
| 876 | + mask_same_res = diff_res_idx == 0 |
| 877 | + mask_same_entity = (rearrange(entity_id, 'b n -> b n 1') \ |
| 878 | + - rearrange(entity_id, 'b n -> b 1 n') == 0).unsqueeze(-1) |
| 879 | + |
| 880 | + d_res = torch.where( |
| 881 | + mask_same_chain, |
| 882 | + torch.clip(diff_res_idx + self.r_max, 0, 2*self.r_max), |
| 883 | + 2*self.r_max + 1 |
| 884 | + ) |
| 885 | + d_token = torch.where( |
| 886 | + mask_same_chain * mask_same_res, |
| 887 | + torch.clip(diff_token_idx + self.r_max, 0, 2*self.r_max), |
| 888 | + 2*self.r_max + 1 |
| 889 | + ) |
| 890 | + d_chain = torch.where( |
| 891 | + ~mask_same_chain, |
| 892 | + torch.clip(diff_sym_id + self.s_max, 0, 2*self.s_max), |
| 893 | + 2*self.s_max + 1 |
| 894 | + ) |
| 895 | + |
| 896 | + def onehot(x, bins): |
| 897 | + _, indexes = (x.view(-1, 1) - bins.view(1, -1)).abs().min(dim=1) |
| 898 | + indexes = indexes.type(torch.int64).view(-1, 1) |
| 899 | + one_hots = torch.zeros(indexes.shape[0], len(bins)).scatter_(1, indexes, 1) |
| 900 | + out = rearrange(one_hots, '(b n k) d -> b n k d', n=x.shape[1], k=x.shape[2]) |
| 901 | + return out |
| 902 | + |
| 903 | + a_rel_pos = onehot(d_res, torch.arange(2*self.r_max + 2)) |
| 904 | + a_rel_token = onehot(d_token, torch.arange(2*self.r_max + 2)) |
| 905 | + a_rel_chain = onehot(d_chain, torch.arange(2*self.s_max + 2)) |
| 906 | + |
| 907 | + p = self.out_embedder( |
| 908 | + torch.cat([ |
| 909 | + a_rel_pos, |
| 910 | + a_rel_token, |
| 911 | + mask_same_entity, |
| 912 | + a_rel_chain |
| 913 | + ], dim=-1) |
| 914 | + ) |
| 915 | + |
| 916 | + return p |
| 917 | + |
| 918 | + |
| 919 | + |
825 | 920 | class TemplateEmbedder(Module): |
826 | 921 | """ Algorithm 16 """ |
827 | 922 |
|
|
0 commit comments