|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -""" |
4 | | -global ein notation: |
5 | | -
|
6 | | -b - batch |
7 | | -ba - batch with augmentation |
8 | | -h - heads |
9 | | -n - residue sequence length |
10 | | -i - residue sequence length (source) |
11 | | -j - residue sequence length (target) |
12 | | -m - atom sequence length |
13 | | -nw - windowed sequence length |
14 | | -d - feature dimension |
15 | | -ds - feature dimension (single) |
16 | | -dp - feature dimension (pairwise) |
17 | | -dap - feature dimension (atompair) |
18 | | -dapi - feature dimension (atompair input) |
19 | | -da - feature dimension (atom) |
20 | | -dai - feature dimension (atom input) |
21 | | -t - templates |
22 | | -s - msa |
23 | | -r - registers |
24 | | -""" |
25 | | - |
26 | | -""" |
27 | | -additional_residue_feats: [*, 10]: |
28 | | -
|
29 | | -0: residue_index |
30 | | -1: token_index |
31 | | -2: asym_id |
32 | | -3: entity_id |
33 | | -4: sym_id |
34 | | -5: restype (must be one hot encoded to 32) |
35 | | -6: is_protein |
36 | | -7: is_rna |
37 | | -8: is_dna |
38 | | -9: is_ligand |
39 | | -""" |
40 | | - |
41 | 3 | from math import pi, sqrt |
42 | 4 | from pathlib import Path |
43 | 5 | from functools import partial, wraps |
|
85 | 47 |
|
86 | 48 | from importlib.metadata import version |
87 | 49 |
|
| 50 | +""" |
| 51 | +global ein notation: |
| 52 | +
|
| 53 | +b - batch |
| 54 | +ba - batch with augmentation |
| 55 | +h - heads |
| 56 | +n - residue sequence length |
| 57 | +i - residue sequence length (source) |
| 58 | +j - residue sequence length (target) |
| 59 | +m - atom sequence length |
| 60 | +nw - windowed sequence length |
| 61 | +d - feature dimension |
| 62 | +ds - feature dimension (single) |
| 63 | +dp - feature dimension (pairwise) |
| 64 | +dap - feature dimension (atompair) |
| 65 | +dapi - feature dimension (atompair input) |
| 66 | +da - feature dimension (atom) |
| 67 | +dai - feature dimension (atom input) |
| 68 | +t - templates |
| 69 | +s - msa |
| 70 | +r - registers |
| 71 | +""" |
| 72 | + |
| 73 | +""" |
| 74 | +additional_residue_feats: [*, 10]: |
| 75 | +
|
| 76 | +0: residue_index |
| 77 | +1: token_index |
| 78 | +2: asym_id |
| 79 | +3: entity_id |
| 80 | +4: sym_id |
| 81 | +5: restype (must be one hot encoded to 32) |
| 82 | +6: is_protein |
| 83 | +7: is_rna |
| 84 | +8: is_dna |
| 85 | +9: is_ligand |
| 86 | +""" |
| 87 | + |
88 | 88 | # constants |
89 | 89 |
|
90 | 90 | ADDITIONAL_RESIDUE_FEATS = 10 |
@@ -241,24 +241,6 @@ def repeat_consecutive_with_lens( |
241 | 241 |
|
242 | 242 | return output |
243 | 243 |
|
244 | | -def repeat_pairwise_consecutive_with_lens( |
245 | | - feats: Float['b n n dp'], |
246 | | - lens: Int['b n'] |
247 | | -) -> Float['b m m dp']: |
248 | | - |
249 | | - repeated_lens = repeat(lens, 'b ... -> (b repeat) ...', repeat = feats.shape[1]) |
250 | | - feats, ps = pack_one(feats, '* n dp') |
251 | | - feats = repeat_consecutive_with_lens(feats, repeated_lens) |
252 | | - feats = unpack_one(feats, ps, '* n dp') |
253 | | - |
254 | | - feats = rearrange(feats, 'b i j dp -> b j i dp') |
255 | | - repeated_lens = repeat(lens, 'b ... -> (b repeat) ...', repeat = feats.shape[1]) |
256 | | - feats, ps = pack_one(feats, '* n dp') |
257 | | - feats = repeat_consecutive_with_lens(feats, repeated_lens) |
258 | | - feats = unpack_one(feats, ps, '* n dp') |
259 | | - feats = rearrange(feats, 'b j i dp -> b i j dp') |
260 | | - return feats |
261 | | - |
262 | 244 | # linear and outer sum |
263 | 245 | # for single repr -> pairwise pattern throughout this architecture |
264 | 246 |
|
@@ -2105,8 +2087,6 @@ def forward( |
2105 | 2087 | align_weights = atom_pos_ground_truth.new_ones(atom_pos_ground_truth.shape[:2]) |
2106 | 2088 |
|
2107 | 2089 | if exists(additional_residue_feats): |
2108 | | - w = self.net.atoms_per_window |
2109 | | - |
2110 | 2090 | is_nucleotide_or_ligand_fields = (additional_residue_feats[..., 7:] != 0.).unbind(dim = -1) |
2111 | 2091 |
|
2112 | 2092 | is_nucleotide_or_ligand_fields = tuple(repeat_consecutive_with_lens(t, residue_atom_lens) for t in is_nucleotide_or_ligand_fields) |
|
0 commit comments