Skip to content

Commit 23da34f

Browse files
committed
complete ability to pass in atompair inputs windowed, for efficient attention biasing in the atom transformer.
1 parent 3a456b2 commit 23da34f

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2503,7 +2503,7 @@ def __init__(
25032503

25042504
self.atom_repr_to_atompair_feat_cond = nn.Sequential(
25052505
nn.LayerNorm(dim_atom),
2506-
LinearNoBiasThenOuterSum(dim_atom, dim_atompair),
2506+
LinearNoBias(dim_atom, dim_atompair * 2),
25072507
nn.ReLU()
25082508
)
25092509

@@ -2540,7 +2540,7 @@ def forward(
25402540
self,
25412541
*,
25422542
atom_inputs: Float['b m dai'],
2543-
atompair_inputs: Float['b m m dapi'],
2543+
atompair_inputs: Float['b m m dapi'] | Float['b nw w1 w2 dapi'],
25442544
atom_mask: Bool['b m'],
25452545
additional_residue_feats: Float[f'b n {ADDITIONAL_RESIDUE_FEATS}'],
25462546
residue_atom_lens: Int['b n'],
@@ -2554,19 +2554,32 @@ def forward(
25542554
atom_feats = self.to_atom_feats(atom_inputs)
25552555
atompair_feats = self.to_atompair_feats(atompair_inputs)
25562556

2557+
# window the atom pair features before passing to atom encoder and decoder
2558+
2559+
is_windowed = atompair_inputs.ndim == 5
2560+
2561+
if not is_windowed:
2562+
atompair_feats = full_pairwise_repr_to_windowed(atompair_feats, window_size = w)
2563+
2564+
# condition atompair with atom repr
2565+
25572566
atom_feats_cond = self.atom_repr_to_atompair_feat_cond(atom_feats)
2558-
atompair_feats = atom_feats_cond + atompair_feats
25592567

2560-
# window the atom pair features before passing to atom encoder and decoder
2568+
atom_feats_cond = pad_to_multiple(atom_feats_cond, w, dim = 1)
2569+
atom_feats_cond = rearrange(atom_feats_cond, 'b (nw w) dap -> b nw w dap', w = w)
25612570

2562-
windowed_atompair_feats = full_pairwise_repr_to_windowed(atompair_feats, window_size = w)
2571+
atom_feats_cond_row, atom_feats_cond_col = atom_feats_cond.chunk(2, dim = -1)
2572+
atom_feats_cond_col = concat_neighboring_windows(atom_feats_cond_col, dim_seq = 1, dim_window = -2)
2573+
2574+
atompair_feats = einx.add('b nw w1 w2 dap, b nw w1 dap',atompair_feats, atom_feats_cond_row)
2575+
atompair_feats = einx.add('b nw w1 w2 dap, b nw w2 dap',atompair_feats, atom_feats_cond_col)
25632576

25642577
# initial atom transformer
25652578

25662579
atom_feats = self.atom_transformer(
25672580
atom_feats,
25682581
single_repr = atom_feats,
2569-
pairwise_repr = windowed_atompair_feats
2582+
pairwise_repr = atompair_feats
25702583
)
25712584

25722585
atompair_feats = self.atompair_feats_mlp(atompair_feats) + atompair_feats
@@ -3035,7 +3048,7 @@ def forward(
30353048
self,
30363049
*,
30373050
atom_inputs: Float['b m dai'],
3038-
atompair_inputs: Float['b m m dapi'],
3051+
atompair_inputs: Float['b m m dapi'] | Float['b nw w1 w2 dapi'],
30393052
additional_residue_feats: Float[f'b n {ADDITIONAL_RESIDUE_FEATS}'],
30403053
residue_atom_lens: Int['b n'],
30413054
atom_mask: Bool['b m'] | None = None,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.68"
3+
version = "0.0.69"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)