|
7 | 7 | i - residue sequence length (source) |
8 | 8 | j - residue sequence length (target) |
9 | 9 | m - atom sequence length |
10 | | -c - coordinates (3 for spatial) |
11 | 10 | d - feature dimension |
12 | 11 | ds - feature dimension (single) |
13 | 12 | dp - feature dimension (pairwise) |
@@ -1292,7 +1291,7 @@ def __init__( |
1292 | 1291 | @typecheck |
1293 | 1292 | def forward( |
1294 | 1293 | self, |
1295 | | - noised_atom_pos: Float['b m c'], |
| 1294 | + noised_atom_pos: Float['b m 3'], |
1296 | 1295 | *, |
1297 | 1296 | atom_feats: Float['b m da'], |
1298 | 1297 | atompair_feats: Float['b m m dap'], |
@@ -1465,7 +1464,7 @@ def c_noise(self, sigma): |
1465 | 1464 | @typecheck |
1466 | 1465 | def preconditioned_network_forward( |
1467 | 1466 | self, |
1468 | | - noised_atom_pos: Float['b m c'], |
| 1467 | + noised_atom_pos: Float['b m 3'], |
1469 | 1468 | sigma: Float[' b'] | Float[' '] | float, |
1470 | 1469 | network_condition_kwargs: dict, |
1471 | 1470 | clamp = False, |
@@ -1815,7 +1814,7 @@ def forward( |
1815 | 1814 | single_inputs_repr: Float['b n dsi'], |
1816 | 1815 | single_repr: Float['b n ds'], |
1817 | 1816 | pairwise_repr: Float['b n n dp'], |
1818 | | - pred_atom_pos: Float['b n c'], |
| 1817 | + pred_atom_pos: Float['b n 3'], |
1819 | 1818 | mask: Bool['b n'] | None = None, |
1820 | 1819 | return_pae_logits = True |
1821 | 1820 |
|
@@ -2034,7 +2033,7 @@ def forward( |
2034 | 2033 | pde_labels: Int['b n n'] | None = None, |
2035 | 2034 | plddt_labels: Int['b n'] | None = None, |
2036 | 2035 | resolved_labels: Int['b n'] | None = None, |
2037 | | - ) -> Float['b m c'] | Float['']: |
| 2036 | + ) -> Float['b m 3'] | Float['']: |
2038 | 2037 |
|
2039 | 2038 | # embed inputs |
2040 | 2039 |
|
|
0 commit comments