
Commit 40cc79b

remove parallel, likely a paper error (#242)
1 parent 5591994 commit 40cc79b

2 files changed: +4 −58 lines

alphafold3_pytorch/alphafold3.py

Lines changed: 3 additions & 55 deletions
```diff
@@ -1888,7 +1888,6 @@ def __init__(
         attn_num_memory_kv = False,
         trans_expansion_factor = 2,
         num_register_tokens = 0,
-        serial = True,
         add_residual = True,
         use_linear_attn = False,
         checkpoint = False,
@@ -1967,13 +1966,10 @@ def __init__(
                 conditionable_transition
             ]))
 
-        assert not (not serial and checkpoint), 'checkpointing can only be used for serial version of diffusion transformer'
-
         self.checkpoint = checkpoint
 
         self.layers = layers
 
-        self.serial = serial
         self.add_residual = add_residual
 
         self.has_registers = num_register_tokens > 0
@@ -2074,48 +2070,6 @@ def to_serial_layers(
 
         return noised_repr
 
-    @typecheck
-    def to_parallel_layers(
-        self,
-        noised_repr: Float['b n d'],
-        *,
-        single_repr: Float['b n ds'],
-        pairwise_repr: Float['b n n dp'] | Float['b nw w (w*2) dp'],
-        mask: Bool['b n'] | None = None,
-        windowed_mask: Bool['b nw w (w*2)'] | None = None
-    ):
-
-        for linear_attn, colt5_attn, attn, transition in self.layers:
-
-            if exists(linear_attn):
-                noised_repr = linear_attn(noised_repr, mask = mask) + noised_repr
-
-            if exists(colt5_attn):
-                noised_repr = colt5_attn(noised_repr, mask = mask) + noised_repr
-
-            attn_out = attn(
-                noised_repr,
-                cond = single_repr,
-                pairwise_repr = pairwise_repr,
-                mask = mask,
-                windowed_mask = windowed_mask
-            )
-
-            ff_out = transition(
-                noised_repr,
-                cond = single_repr
-            )
-
-            # in the algorithm, they omitted the residual, but it could be an error
-            # attn + ff + residual was used in GPT-J and PaLM, but later found to be unstable configuration, so it seems unlikely attn + ff would work
-            # but in the case they figured out something we have not, you can use their exact formulation by setting `serial = False` and `add_residual = False`
-
-            residual = noised_repr if self.add_residual else 0.
-
-            noised_repr = ff_out + attn_out + residual
-
-        return noised_repr
-
     @typecheck
     def forward(
         self,
@@ -2126,7 +2080,7 @@ def forward(
         mask: Bool['b n'] | None = None,
         windowed_mask: Bool['b nw w (w*2)'] | None = None
     ):
-        w, serial = self.attn_window_size, self.serial
+        w = self.attn_window_size
         has_windows = exists(w)
 
         # handle windowing
@@ -2151,12 +2105,10 @@ def forward(
 
         # main transformer
 
-        if serial and should_checkpoint(self, (noised_repr, single_repr, pairwise_repr)):
+        if should_checkpoint(self, (noised_repr, single_repr, pairwise_repr)):
             to_layers_fn = self.to_checkpointed_serial_layers
-        elif serial:
-            to_layers_fn = self.to_serial_layers
         else:
-            to_layers_fn = self.to_parallel_layers
+            to_layers_fn = self.to_serial_layers
 
         noised_repr = to_layers_fn(
             noised_repr,
@@ -2230,7 +2182,6 @@ def __init__(
         token_transformer_heads = 16,
         atom_decoder_depth = 3,
         atom_decoder_heads = 4,
-        serial = True,
         atom_encoder_kwargs: dict = dict(),
         atom_decoder_kwargs: dict = dict(),
         token_transformer_kwargs: dict = dict(),
@@ -2298,7 +2249,6 @@ def __init__(
             attn_window_size = atoms_per_window,
             depth = atom_encoder_depth,
             heads = atom_encoder_heads,
-            serial = serial,
             use_linear_attn = use_linear_attn,
             linear_attn_kwargs = linear_attn_kwargs,
             checkpoint = checkpoint,
@@ -2323,7 +2273,6 @@ def __init__(
             dim_pairwise = dim_pairwise,
             depth = token_transformer_depth,
             heads = token_transformer_heads,
-            serial = serial,
             checkpoint = checkpoint,
             **token_transformer_kwargs
         )
@@ -2341,7 +2290,6 @@ def __init__(
             attn_window_size = atoms_per_window,
             depth = atom_decoder_depth,
             heads = atom_decoder_heads,
-            serial = serial,
             use_linear_attn = use_linear_attn,
             linear_attn_kwargs = linear_attn_kwargs,
             checkpoint = checkpoint,
```
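
For context on what was dropped: the serial path (kept) applies attention and the transition one after the other, each with its own residual, while the deleted parallel path fed the same input to both sublayers and summed their outputs, optionally without any residual, to match the paper's formulation literally. A minimal sketch of the two update rules, using plain `nn.MultiheadAttention` and a small feedforward as hypothetical stand-ins for the repo's pairwise-biased attention and conditioned transition blocks:

```python
# minimal sketch contrasting the two update rules; `attn` and `ff` are plain torch
# stand-ins, not the repo's conditioned attention / transition modules

import torch
from torch import nn

dim = 64
attn = nn.MultiheadAttention(dim, num_heads = 4, batch_first = True)
ff = nn.Sequential(nn.Linear(dim, dim * 2), nn.GELU(), nn.Linear(dim * 2, dim))

def serial_block(x):
    # kept formulation: attention then feedforward, each followed by a residual add
    x = attn(x, x, x)[0] + x
    x = ff(x) + x
    return x

def parallel_block(x, add_residual = True):
    # removed formulation: both sublayers read the same input and their outputs are summed
    attn_out = attn(x, x, x)[0]
    ff_out = ff(x)
    residual = x if add_residual else 0.
    return ff_out + attn_out + residual

x = torch.randn(2, 16, dim)
assert serial_block(x).shape == parallel_block(x).shape == (2, 16, dim)
```

As the removed in-code comment notes, summing attention and feedforward in parallel, especially without the residual, resembles the GPT-J/PaLM-style parallel block that later proved unstable, which is why the commit treats the paper's parallel reading as a likely error and keeps only the serial path.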

tests/test_af3.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -365,12 +365,11 @@ def test_msa_module(
     loss = pairwise_out.sum()
     loss.backward()
 
-@pytest.mark.parametrize('serial,checkpoint', ((False, False), (True, False), (True, True)))
+@pytest.mark.parametrize('checkpoint', (False, True))
 @pytest.mark.parametrize('use_linear_attn', (False, True))
 @pytest.mark.parametrize('use_colt5_attn', (False, True))
 def test_diffusion_transformer(
     checkpoint,
-    serial,
     use_linear_attn,
     use_colt5_attn
 ):
@@ -382,7 +381,6 @@ def test_diffusion_transformer(
     diffusion_transformer = DiffusionTransformer(
         depth = 2,
         heads = 16,
-        serial = serial,
         checkpoint = checkpoint,
         use_linear_attn = use_linear_attn,
         use_colt5_attn = use_colt5_attn
```
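
For callers, the update mirrors the test change above: `serial` is no longer a constructor argument, so passing it to `DiffusionTransformer` (or to the wrapping diffusion module that previously forwarded it) now raises a `TypeError`. A hedged usage sketch, with keyword values and import path assumed from the test hunk and file layout above rather than from documented defaults:

```python
# usage after this commit: the transformer is always serial, so just drop the keyword.
# import path assumed from alphafold3_pytorch/alphafold3.py shown above; any remaining
# constructor arguments (e.g. dims) are unchanged by this commit.
from alphafold3_pytorch.alphafold3 import DiffusionTransformer

# before: DiffusionTransformer(depth = 2, heads = 16, serial = True, checkpoint = True)
# after:
diffusion_transformer = DiffusionTransformer(
    depth = 2,
    heads = 16,
    checkpoint = True,        # gradient checkpointing now always wraps the serial layers
    use_linear_attn = False,
    use_colt5_attn = False
)
```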
