
Commit 8f9e41d

new strategy for repeat consecutive, to get ready for more efficient local attention atom transformer biasing
1 parent 74f062a commit 8f9e41d

File tree

2 files changed: +22 / -12 lines changed


alphafold3_pytorch/alphafold3.py

Lines changed: 21 additions & 11 deletions
```diff
@@ -151,10 +151,7 @@ def repeat_consecutive_with_lens(
     lens: Int['b n'],
 ) -> Float['b m ...'] | Bool['b m']:
 
-    is_bool = feats.dtype == torch.bool
-    feats = feats.float()
-
-    device = feats.device
+    device, dtype = feats.device, feats.dtype
 
     batch, seq, *dims = feats.shape
 
@@ -174,25 +171,38 @@ def repeat_consecutive_with_lens(
     # create output tensor + a sink position on the very right (index max_len)
 
     total_lens = lens.sum(dim = -1)
+    output_mask = lens_to_mask(total_lens)
+
     max_len = total_lens.amax()
 
-    output = torch.zeros((batch, max_len + 1, *dims), device = device)
+    output_indices = torch.zeros((batch, max_len + 1), device = device, dtype = torch.long)
 
     indices.masked_fill_(~mask, max_len) # scatter to sink position for padding
     indices = rearrange(indices, 'b n w -> b (n w)')
 
-    feats = repeat(feats, 'b n ... -> b (n w) ...', w = window_size)
-
     # scatter
 
-    output = einx.set_at('b [m] ..., b nw, b nw ... -> b [m] ...', output, indices, feats)
+    seq_arange = torch.arange(seq, device = device)
+    seq_arange = repeat(seq_arange, 'n -> (n w)', w = window_size)
+
+    output_indices = einx.set_at('b [m], b nw, nw -> b [m]', output_indices, indices, seq_arange)
 
     # remove sink
 
-    output = output[:, :-1]
+    output_indices = output_indices[:, :-1]
+
+    # gather
+
+    output = einx.get_at('b [n] ..., b m -> b m ...', feats, output_indices)
+
+    # final mask
+
+    mask_value = False if dtype == torch.bool else 0
 
-    if is_bool:
-        output = output.bool()
+    output = einx.where(
+        'b n, b n ..., -> b n ...',
+        output_mask, output, mask_value
+    )
 
     return output
```

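The change swaps the old approach (cast `feats` to float, repeat it along the window axis, and scatter the repeated features directly into a padded output buffer) for an index-based one: scatter compact `long` source indices into the buffer, with a sink column absorbing the padding writes, then gather the features once at the end in their original dtype and mask out positions past each row's total length. The sketch below reimplements that flow in plain PyTorch as an illustration only: the name `repeat_consecutive_with_lens_sketch` is made up, and the setup of `window_size`, `indices`, `mask`, and `lens_to_mask` happens before the hunk shown above, so their reconstruction here is an assumption.

```python
# Minimal sketch of the index-scatter-then-gather strategy, in plain PyTorch
# (no einx/einops). The windowing setup (window_size, indices, mask) is not part
# of the hunk above and is reconstructed here as an assumption.
import torch

def repeat_consecutive_with_lens_sketch(feats, lens):
    # feats: (b, n, *dims) float or bool, lens: (b, n) non-negative ints
    # returns (b, m, *dims) with feats[b, i] repeated lens[b, i] times consecutively,
    # where m = lens.sum(-1).max(); the tail is filled with 0 / False
    device, dtype = feats.device, feats.dtype
    batch, seq = lens.shape

    window_size = int(lens.amax())              # most repeats any single position needs
    total_lens = lens.sum(dim = -1)
    max_len = int(total_lens.amax())

    # destination of every (position, repeat) pair; invalid repeats go to a sink column
    offsets = lens.cumsum(dim = -1) - lens                      # start of each run
    arange_w = torch.arange(window_size, device = device)
    indices = offsets[..., None] + arange_w                     # (b, n, w)
    mask = arange_w < lens[..., None]                           # (b, n, w)
    indices = indices.masked_fill(~mask, max_len).reshape(batch, -1)

    # scatter *source indices*, not features, into the (b, max_len + 1) buffer
    output_indices = torch.zeros((batch, max_len + 1), device = device, dtype = torch.long)
    seq_arange = torch.arange(seq, device = device).repeat_interleave(window_size)
    output_indices.scatter_(1, indices, seq_arange[None].expand(batch, -1))
    output_indices = output_indices[:, :-1]                     # drop the sink column

    # single gather at the end, in the original dtype (no float round trip)
    batch_arange = torch.arange(batch, device = device)[:, None]
    output = feats[batch_arange, output_indices]                # (b, max_len, *dims)

    # mask out everything past each row's total length
    output_mask = torch.arange(max_len, device = device) < total_lens[:, None]
    fill = False if dtype == torch.bool else 0
    output_mask = output_mask.reshape(batch, max_len, *([1] * (feats.ndim - 2)))
    return output.masked_fill(~output_mask, fill)

if __name__ == '__main__':
    feats = torch.tensor([[10., 20., 30.]])     # (b = 1, n = 3)
    lens = torch.tensor([[2, 0, 3]])            # repeat counts per position
    print(repeat_consecutive_with_lens_sketch(feats, lens))
    # tensor([[10., 10., 30., 30., 30.]])
```

Compared to the old path, only long indices are scattered and the features are read exactly once, avoiding both the materialized `(b, n*w, *dims)` repeated copy of `feats` and the bool-to-float-and-back round trip; this appears to be what the commit message means by getting ready for more efficient local attention atom transformer biasing.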
pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.48"
+version = "0.0.49"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
```

0 commit comments