Commit c2c1eb7 (1 parent: d1d7445)

simplify local atom attention to one lookback window, as one can just increase the window size, and it makes atompair input construction much simpler
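For intuition, the shape arithmetic behind the change (a sketch with hypothetical numbers, not part of the commit):

    w = 16

    keys_before = 3 * w   # previous + current + next window -> 48 keys per query window
    keys_after  = 2 * w   # previous + current window        -> 32 keys per query window

    # under the new one-lookback scheme, simply increasing the window
    # size widens the receptive field again, e.g. w = 32 -> 64 keys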

File tree: 4 files changed, +24 / -25 lines

alphafold3_pytorch/alphafold3.py

Lines changed: 8 additions & 8 deletions
@@ -32,7 +32,7 @@
     slice_at_dim,
     pad_or_slice_to,
     pad_to_multiple,
-    concat_neighboring_windows,
+    concat_previous_window,
     full_attn_bias_to_windowed,
     full_pairwise_repr_to_windowed
 )
@@ -547,8 +547,8 @@ def forward(
         self,
         single_repr: Float['b n ds'],
         *,
-        pairwise_repr: Float['b n n dp'] | Float['b nw w (w*3) dp'],
-        attn_bias: Float['b n n'] | Float['b nw w (w*3)'] | None = None,
+        pairwise_repr: Float['b n n dp'] | Float['b nw w (w*2) dp'],
+        attn_bias: Float['b n n'] | Float['b nw w (w*2)'] | None = None,
         **kwargs
     ) -> Float['b n ds']:

@@ -1450,7 +1450,7 @@ def forward(
         noised_repr: Float['b n d'],
         *,
         single_repr: Float['b n ds'],
-        pairwise_repr: Float['b n n dp'] | Float['b nw w (w*3) dp'],
+        pairwise_repr: Float['b n n dp'] | Float['b nw w (w*2) dp'],
         mask: Bool['b n'] | None = None
     ):
         w = self.attn_window_size
@@ -1691,7 +1691,7 @@ def forward(
         noised_atom_pos: Float['b m 3'],
         *,
         atom_feats: Float['b m da'],
-        atompair_feats: Float['b m m dap'] | Float['b nw w (w*3) dap'],
+        atompair_feats: Float['b m m dap'] | Float['b nw w (w*2) dap'],
         atom_mask: Bool['b m'],
         times: Float[' b'],
         mask: Bool['b n'],
@@ -1757,7 +1757,7 @@ def forward(
         row_indices = rearrange(row_indices, 'b n w -> b n w 1', w = w)
         col_indices = rearrange(col_indices, 'b n w -> b n 1 w', w = w)

-        col_indices = concat_neighboring_windows(col_indices, dim_seq = 1, dim_window = -1)
+        col_indices = concat_previous_window(col_indices, dim_seq = 1, dim_window = -1)
         row_indices, col_indices = torch.broadcast_tensors(row_indices, col_indices)

         pairwise_repr_cond = einx.get_at('b [i j] dap, b nw w1 w2, b nw w1 w2 -> b nw w1 w2 dap', pairwise_repr_cond, row_indices, col_indices)
@@ -1771,7 +1771,7 @@ def forward(

         atom_repr_cond_row, atom_repr_cond_col = atom_repr_cond.chunk(2, dim = -1)

-        atom_repr_cond_col = concat_neighboring_windows(atom_repr_cond_col, dim_seq = 1, dim_window = 2)
+        atom_repr_cond_col = concat_previous_window(atom_repr_cond_col, dim_seq = 1, dim_window = 2)

         atompair_feats = einx.add('b nw w1 w2 dap, b nw w1 dap -> b nw w1 w2 dap', atompair_feats, atom_repr_cond_row)
         atompair_feats = einx.add('b nw w1 w2 dap, b nw w2 dap -> b nw w1 w2 dap', atompair_feats, atom_repr_cond_col)
@@ -2556,7 +2556,7 @@ def forward(
         atom_feats_cond = pad_and_window(atom_feats_cond, w)

         atom_feats_cond_row, atom_feats_cond_col = atom_feats_cond.chunk(2, dim = -1)
-        atom_feats_cond_col = concat_neighboring_windows(atom_feats_cond_col, dim_seq = 1, dim_window = -2)
+        atom_feats_cond_col = concat_previous_window(atom_feats_cond_col, dim_seq = 1, dim_window = -2)

         atompair_feats = einx.add('b nw w1 w2 dap, b nw w1 dap', atompair_feats, atom_feats_cond_row)
         atompair_feats = einx.add('b nw w1 w2 dap, b nw w2 dap', atompair_feats, atom_feats_cond_col)
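As a sanity check on the two einx.add lines above, an equivalent hedged sketch in plain torch broadcasting (dimension names and sizes are illustrative): row conditioning broadcasts over the key axis w2, column conditioning (which was concatenated with its previous window, hence 2 * w) broadcasts over the query axis w1.

    import torch

    b, nw, w, dap = 1, 4, 16, 8

    atompair_feats = torch.randn(b, nw, w, 2 * w, dap)   # 'b nw w1 w2 dap', w2 = 2 * w1
    row_cond = torch.randn(b, nw, w, dap)                # 'b nw w1 dap'
    col_cond = torch.randn(b, nw, 2 * w, dap)            # 'b nw w2 dap', after concat_previous_window

    atompair_feats = atompair_feats + row_cond[:, :, :, None, :]   # add along w2
    atompair_feats = atompair_feats + col_cond[:, :, None, :, :]   # add along w1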

alphafold3_pytorch/attention.py

Lines changed: 14 additions & 15 deletions
@@ -99,18 +99,17 @@ def pad_to_multiple(
     return pad_at_dim(t, (0, padding_needed), dim = dim, value = value)

 @typecheck
-def concat_neighboring_windows(
+def concat_previous_window(
     t: Tensor,
     *,
     dim_seq: int,
     dim_window: int
 ):
-    t = pad_at_dim(t, (1, 1), dim = dim_seq, value = 0.)
+    t = pad_at_dim(t, (1, 0), dim = dim_seq, value = 0.)

     t = torch.cat((
-        slice_at_dim(t, slice(None, -2), dim = dim_seq),
-        slice_at_dim(t, slice(1, -1), dim = dim_seq),
-        slice_at_dim(t, slice(2, None), dim = dim_seq)
+        slice_at_dim(t, slice(None, -1), dim = dim_seq),
+        slice_at_dim(t, slice(1, None), dim = dim_seq),
     ), dim = dim_window)

     return t
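To make the new helper's behavior concrete, a self-contained sketch with the repo's pad_at_dim / slice_at_dim helpers replaced by plain torch ops and dims fixed for illustration (the real function takes arbitrary dim_seq / dim_window):

    import torch

    def concat_previous_window_sketch(t):
        # t: (nw, w1, w2, d), windowed pairwise features
        # prepend one zero window along the window-sequence dim, then pair each
        # window with the one before it along the key-window dim
        t = torch.cat((torch.zeros_like(t[:1]), t), dim = 0)   # (nw + 1, w1, w2, d)
        prev, curr = t[:-1], t[1:]                             # (nw, w1, w2, d) each
        return torch.cat((prev, curr), dim = -2)               # (nw, w1, 2 * w2, d)

    t = torch.randn(4, 16, 16, 8)
    print(concat_previous_window_sketch(t).shape)              # torch.Size([4, 16, 32, 8])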
@@ -121,14 +120,14 @@ def concat_neighboring_windows(
 def full_pairwise_repr_to_windowed(
     pairwise_repr: Float['... m m dp'],
     window_size: int
-) -> Float['... n w (w*3) dp']:
+) -> Float['... n w (w*2) dp']:

     seq_len, device = pairwise_repr.shape[-2], pairwise_repr.device

     padding_needed = (window_size - (seq_len % window_size)) % window_size
     pairwise_repr = F.pad(pairwise_repr, (0, 0, 0, padding_needed, 0, padding_needed), value = 0.)
     pairwise_repr = rearrange(pairwise_repr, '... (i w1) (j w2) d -> ... i j w1 w2 d', w1 = window_size, w2 = window_size)
-    pairwise_repr = concat_neighboring_windows(pairwise_repr, dim_seq = -4, dim_window = -2)
+    pairwise_repr = concat_previous_window(pairwise_repr, dim_seq = -4, dim_window = -2)

     # get the diagonal
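Putting the hunk above together end to end, a hedged sketch (einops and torch only, repo helpers inlined, m divisible by window_size assumed) of how a full pairwise repr becomes the windowed '... n w (w*2) dp' layout:

    import torch
    from einops import rearrange

    def full_pairwise_to_windowed_sketch(x, w):
        # x: (m, m, d) full pairwise repr, m divisible by w for simplicity
        x = rearrange(x, '(i w1) (j w2) d -> i j w1 w2 d', w1 = w, w2 = w)

        # prepend a zero key-window, then concat each key-window with its predecessor
        x = torch.cat((torch.zeros_like(x[:, :1]), x), dim = 1)
        x = torch.cat((x[:, :-1], x[:, 1:]), dim = -2)          # (i, j, w1, 2 * w2, d)

        # keep the diagonal i == j: each query window paired with (previous + own) keys
        idx = torch.arange(x.shape[0])
        return x[idx, idx]                                       # (n, w1, 2 * w2, d)

    x = torch.randn(64, 64, 4)
    print(full_pairwise_to_windowed_sketch(x, 16).shape)         # torch.Size([4, 16, 32, 4])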
@@ -145,7 +144,7 @@ def full_pairwise_repr_to_windowed(
 def full_attn_bias_to_windowed(
     attn_bias: Float['... m m'],
     window_size: int
-) -> Float['... n w (w*3)']:
+) -> Float['... n w (w*2)']:

     attn_bias = rearrange(attn_bias, '... -> ... 1')
     attn_bias = full_pairwise_repr_to_windowed(attn_bias, window_size = window_size)
@@ -215,7 +214,7 @@ def forward(
         seq: Float['b i d'],
         mask: Bool['b n'] | None = None,
         context: Float['b j d'] | None = None,
-        attn_bias: Float['... i j'] | Float['... nw w (w*3)'] | None = None
+        attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None

     ) -> Float['b i d']:
@@ -316,7 +315,7 @@ def local_attn(
     k: Float['b h n d'],
     v: Float['b h n d'],
     mask: Bool['b n'] | None = None,
-    attn_bias: Float['... n n'] | Float['... nw w (w*3)'] | None = None
+    attn_bias: Float['... n n'] | Float['... nw w (w*2)'] | None = None
 ) -> Float['b h n d']:
     """
     simple local attention with a radius of 1 window size
@@ -345,11 +344,11 @@ def local_attn(
     # just do radius of 1 for now
     # perhaps not even necessary, and could try shifted windows (a la Swin)

-    k, v = tuple(pad_at_dim(t, (1, 1), dim = -2) for t in (k, v))
-    mask = F.pad(mask, (1, 1), value = False)
+    k, v = tuple(pad_at_dim(t, (1, 0), dim = -2) for t in (k, v))
+    mask = F.pad(mask, (1, 0), value = False)

-    k, v = tuple(torch.cat((t[..., :-2, :], t[..., 1:-1, :], t[..., 2:, :]), dim = -2) for t in (k, v))
-    mask = torch.cat((mask[..., :-2], mask[..., 1:-1], mask[..., 2:]), dim = -1)
+    k, v = tuple(torch.cat((t[..., :-1, :], t[..., 1:, :]), dim = -2) for t in (k, v))
+    mask = torch.cat((mask[..., :-1], mask[..., 1:]), dim = -1)

     # handle attention bias (inefficiently)
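End to end, a hedged sketch of the new local attention layout (shapes and names are illustrative, not the repo's exact internals): queries stay in their own window while keys/values become the previous-plus-current window, matching the (1, 0) padding above:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    b, h, n, d, w = 1, 2, 64, 8, 16
    q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

    # window the sequence: (b, h, nw, w, d)
    q, k, v = (rearrange(t, 'b h (nw w) d -> b h nw w d', w = w) for t in (q, k, v))

    def with_previous_window(t):
        # prepend one zero window along nw, then give every window the
        # keys/values of the window before it as well as its own
        t = F.pad(t, (0, 0, 0, 0, 1, 0))
        return torch.cat((t[..., :-1, :, :], t[..., 1:, :, :]), dim = -2)

    k, v = map(with_previous_window, (k, v))                 # (b, h, nw, 2 * w, d)

    sim = torch.einsum('... i d, ... j d -> ... i j', q, k) * d ** -0.5
    out = torch.einsum('... i j, ... j d -> ... i d', sim.softmax(dim = -1), v)
    out = rearrange(out, 'b h nw w d -> b h (nw w) d')       # back to (b, h, n, d)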
@@ -399,7 +398,7 @@ def forward(
         k: Float['b h j d'],
         v: Float['b h j d'],
         mask: Bool['b j'] | None = None,
-        attn_bias: Float['... i j'] | Float['... nw w (w*3)'] | None = None,
+        attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None,
     ) -> Float['b h i d']:

         is_windowed_attn_bias = None

alphafold3_pytorch/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
 class Alphafold3Input(TypedDict):
     atom_inputs: Float['m dai']
     residue_atom_lens: Int['n 2']
-    atompair_inputs: Float['m m dapi'] | Float['nw w (w*3) dapi']
+    atompair_inputs: Float['m m dapi'] | Float['nw w (w*2) dapi']
     additional_residue_feats: Float['n 10']
     templates: Float['t n n dt']
     msa: Float['s n dm']

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.1.2"
+version = "0.1.4"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
