Skip to content

Commit 4706bcd

Browse files
committed
just default to serial true and save a bunch of people headache
1 parent 5849d25 commit 4706bcd

File tree

3 files changed

+16
-7
lines changed

3 files changed

+16
-7
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1768,7 +1768,7 @@ def __init__(
17681768
attn_num_memory_kv = False,
17691769
trans_expansion_factor = 2,
17701770
num_register_tokens = 0,
1771-
serial = False,
1771+
serial = True,
17721772
add_residual = True,
17731773
use_linear_attn = False,
17741774
checkpoint = False,
@@ -2112,7 +2112,7 @@ def __init__(
21122112
token_transformer_heads = 16,
21132113
atom_decoder_depth = 3,
21142114
atom_decoder_heads = 4,
2115-
serial = False,
2115+
serial = True,
21162116
atom_encoder_kwargs: dict = dict(),
21172117
atom_decoder_kwargs: dict = dict(),
21182118
token_transformer_kwargs: dict = dict(),
@@ -4276,7 +4276,6 @@ def __init__(
42764276
token_transformer_heads = 16,
42774277
atom_decoder_depth = 3,
42784278
atom_decoder_heads = 4,
4279-
serial = True # believe they have an error on Algorithm 23. lacking a residual - default to serial architecture until further news
42804279
),
42814280
edm_kwargs: dict = dict(
42824281
sigma_min = 0.002,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.67"
3+
version = "0.2.68"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -683,12 +683,22 @@ def test_alphafold3_without_msa_and_templates():
683683
depth = 1
684684
),
685685
pairformer_stack = dict(
686+
checkpoint = True,
686687
depth = 2
687688
),
688689
diffusion_module_kwargs = dict(
689-
atom_encoder_depth = 1,
690-
token_transformer_depth = 1,
691-
atom_decoder_depth = 1,
690+
atom_encoder_depth = 2,
691+
atom_encoder_kwargs = dict(
692+
checkpoint = True,
693+
),
694+
token_transformer_depth = 2,
695+
token_transformer_kwargs = dict(
696+
checkpoint = True,
697+
),
698+
atom_decoder_depth = 2,
699+
atom_decoder_kwargs = dict(
700+
checkpoint = True,
701+
),
692702
),
693703
)
694704

0 commit comments

Comments
 (0)