Skip to content

Commit c0a98f8

Browse files
committed
architecture is there
1 parent 23da34f commit c0a98f8

File tree

4 files changed

+22
-5
lines changed

4 files changed

+22
-5
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<img src="./alphafold3.png" width="500px"></img>
22

3-
## Alphafold 3 - Pytorch (wip)
3+
## Alphafold 3 - Pytorch
44

55
Implementation of <a href="https://www.nature.com/articles/s41586-024-07487-w">Alphafold 3</a> in Pytorch
66

alphafold3_pytorch/alphafold3.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3077,6 +3077,13 @@ def forward(
30773077

30783078
assert exists(residue_atom_lens) or exists(atom_mask)
30793079

3080+
# if atompair inputs are not windowed, window it
3081+
3082+
is_atompair_inputs_windowed = atompair_inputs.ndim == 5
3083+
3084+
if not is_atompair_inputs_windowed:
3085+
atompair_inputs = full_pairwise_repr_to_windowed(atompair_inputs, window_size = self.atoms_per_window)
3086+
30803087
# handle atom mask
30813088

30823089
total_atoms = residue_atom_lens.sum(dim = -1)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.69"
3+
version = "0.1.0"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626

2727
from alphafold3_pytorch.alphafold3 import (
2828
mean_pool_with_lens,
29-
repeat_consecutive_with_lens
29+
repeat_consecutive_with_lens,
30+
full_pairwise_repr_to_windowed
3031
)
3132

3233
def test_mean_pool_with_lens():
@@ -383,17 +384,25 @@ def test_distogram_head():
383384

384385
logits = distogram_head(pairwise_repr)
385386

386-
387-
def test_alphafold3():
387+
@pytest.mark.parametrize('window_atompair_inputs', (True, False))
388+
def test_alphafold3(
389+
window_atompair_inputs: bool
390+
):
388391
seq_len = 16
392+
atoms_per_window = 27
393+
389394
residue_atom_lens = torch.randint(1, 3, (2, seq_len))
390395
atom_seq_len = residue_atom_lens.sum(dim = -1).amax()
391396

392397
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
393398

394399
atom_inputs = torch.randn(2, atom_seq_len, 77)
400+
395401
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
396402

403+
if window_atompair_inputs:
404+
atompair_inputs = full_pairwise_repr_to_windowed(atompair_inputs, window_size = atoms_per_window)
405+
397406
additional_residue_feats = torch.randn(2, seq_len, 10)
398407

399408
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
@@ -412,6 +421,7 @@ def test_alphafold3():
412421

413422
alphafold3 = Alphafold3(
414423
dim_atom_inputs = 77,
424+
atoms_per_window = atoms_per_window,
415425
dim_template_feats = 44,
416426
num_dist_bins = 38,
417427
confidence_head_kwargs = dict(

0 commit comments

Comments
 (0)