Skip to content

Commit 343f6fa

Browse files
committed
directly contract for spatial coordinates
1 parent e2e8e46 commit 343f6fa

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
i - residue sequence length (source)
88
j - residue sequence length (target)
99
m - atom sequence length
10-
c - coordinates (3 for spatial)
1110
d - feature dimension
1211
ds - feature dimension (single)
1312
dp - feature dimension (pairwise)
@@ -1292,7 +1291,7 @@ def __init__(
12921291
@typecheck
12931292
def forward(
12941293
self,
1295-
noised_atom_pos: Float['b m c'],
1294+
noised_atom_pos: Float['b m 3'],
12961295
*,
12971296
atom_feats: Float['b m da'],
12981297
atompair_feats: Float['b m m dap'],
@@ -1465,7 +1464,7 @@ def c_noise(self, sigma):
14651464
@typecheck
14661465
def preconditioned_network_forward(
14671466
self,
1468-
noised_atom_pos: Float['b m c'],
1467+
noised_atom_pos: Float['b m 3'],
14691468
sigma: Float[' b'] | Float[' '] | float,
14701469
network_condition_kwargs: dict,
14711470
clamp = False,
@@ -1815,7 +1814,7 @@ def forward(
18151814
single_inputs_repr: Float['b n dsi'],
18161815
single_repr: Float['b n ds'],
18171816
pairwise_repr: Float['b n n dp'],
1818-
pred_atom_pos: Float['b n c'],
1817+
pred_atom_pos: Float['b n 3'],
18191818
mask: Bool['b n'] | None = None,
18201819
return_pae_logits = True
18211820

@@ -2034,7 +2033,7 @@ def forward(
20342033
pde_labels: Int['b n n'] | None = None,
20352034
plddt_labels: Int['b n'] | None = None,
20362035
resolved_labels: Int['b n'] | None = None,
2037-
) -> Float['b m c'] | Float['']:
2036+
) -> Float['b m 3'] | Float['']:
20382037

20392038
# embed inputs
20402039

0 commit comments

Comments
 (0)