Skip to content

Commit cf43bf2

Browse files
committed
more efficient token level -> atom level for packed atom rep
1 parent 9eea924 commit cf43bf2

File tree

3 files changed

+60
-53
lines changed

3 files changed

+60
-53
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 56 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,15 @@ def inner(t, *args, **kwargs):
9191

9292
@typecheck
9393
def lens_to_mask(
94-
lens: Int['b n'] | Int[' b']
95-
) -> Bool['b m']:
94+
lens: Int['b ...'],
95+
max_len: int | None = None
96+
) -> Bool['... m']:
9697

9798
device = lens.device
98-
99-
if lens.ndim == 2:
100-
lens = reduce(lens, 'b m -> b', 'sum')
101-
102-
max_len = lens.amax()
99+
if not exists(max_len):
100+
max_len = lens.amax()
103101
arange = torch.arange(max_len, device = device)
104-
return einx.less('m, b -> b m', arange, lens)
102+
return einx.less('m, ... -> ... m', arange, lens)
105103

106104
@typecheck
107105
def mean_pool_with_lens(
@@ -133,50 +131,70 @@ def mean_pool_with_lens(
133131
def repeat_consecutive_with_lens(
134132
feats: Float['b n ...'] | Bool['b n'],
135133
lens: Int['b n'],
136-
max_length: int | None = None,
137-
return_mask = False
138-
) -> Float['b m d'] | Bool['b m'] | Tuple[Float['b m d'] | Bool['b m'], Bool['b m']]:
134+
) -> Float['b m ...'] | Bool['b m']:
139135

140136
is_bool = feats.dtype == torch.bool
137+
feats = feats.float()
138+
141139
device = feats.device
142140

143-
# derive arange from the max length
141+
batch, seq, *dims = feats.shape
142+
143+
# get mask from lens
144+
145+
mask = lens_to_mask(lens)
146+
147+
# derive arange
144148

145-
total_lens = reduce(lens, 'b n -> b', 'sum')
149+
window_size = mask.shape[-1]
150+
arange = torch.arange(window_size, device = device)
146151

147-
if not exists(max_length):
148-
max_length = total_lens.amax()
152+
cumsum_len = lens.cumsum(dim = -1)
153+
offsets = F.pad(cumsum_len, (1, -1), value = 0)
154+
indices = einx.add('w, b n -> b n w', arange, offsets)
149155

150-
arange = torch.arange(max_length, device = device)
156+
# create output tensor + a sink position on the very right (index max_len)
151157

152-
# get packed atom mask from the total lengths
158+
total_lens = lens.sum(dim = -1)
159+
max_len = total_lens.amax()
153160

154-
mask = lens_to_mask(total_lens)
161+
output = torch.zeros((batch, max_len + 1, *dims), device = device)
155162

156-
lens = F.pad(lens, (1, 0), value = 0)
157-
cumsum_lens = lens.cumsum(dim = -1)
158-
left_index, right_index = cumsum_lens[:, :-1], cumsum_lens[:, 1:]
163+
indices.masked_fill_(~mask, max_len) # scatter to sink position for padding
164+
indices = rearrange(indices, 'b n w -> b (n w)')
159165

160-
# derive the mask for consecutives per feat
166+
feats = repeat(feats, 'b n ... -> b (n w) ...', w = window_size)
161167

162-
left_mask = einx.greater_equal('m, b n -> b n m', arange, left_index)
163-
right_mask = einx.less('m, b n -> b n m', arange, right_index)
168+
# scatter
164169

165-
consecutive_mask = left_mask & right_mask
170+
output = einx.set_at('b [m] ..., b nw, b nw ... -> b [m] ...', output, indices, feats)
166171

167-
# now broadcast and sum for consecutive features
172+
# remove sink
168173

169-
feats = einx.multiply('b n ..., b n m -> b n m ...', feats, consecutive_mask.float())
170-
feats = reduce(feats, 'b n m ... -> b m ...', 'sum')
174+
output = output[:, :-1]
171175

172176
if is_bool:
173-
feats = feats.bool()
177+
output = output.bool()
174178

175-
if not return_mask:
176-
return feats
179+
return output
177180

178-
mask = mask[:, :max_length]
179-
return feats, mask
181+
def repeat_pairwise_consecutive_with_lens(
182+
feats: Float['b n n dp'],
183+
lens: Int['b n']
184+
) -> Float['b m m dp']:
185+
186+
repeated_lens = repeat(lens, 'b ... -> (b r) ...', r = feats.shape[1])
187+
feats, ps = pack_one(feats, '* n dp')
188+
feats = repeat_consecutive_with_lens(feats, repeated_lens)
189+
feats = unpack_one(feats, ps, '* n dp')
190+
191+
feats = rearrange(feats, 'b i j dp -> b j i dp')
192+
repeated_lens = repeat(lens, 'b ... -> (b r) ...', r = feats.shape[1])
193+
feats, ps = pack_one(feats, '* n dp')
194+
feats = repeat_consecutive_with_lens(feats, repeated_lens)
195+
feats = unpack_one(feats, ps, '* n dp')
196+
feats = rearrange(feats, 'b j i dp -> b i j dp')
197+
return feats
180198

181199
# linear and outer sum
182200
# for single repr -> pairwise pattern throughout this architecture
@@ -1607,19 +1625,7 @@ def forward(
16071625
if is_unpacked_repr:
16081626
pairwise_repr_cond = repeat(pairwise_repr_cond, 'b i j dp -> b (i w1) (j w2) dp', w1 = w, w2 = w)
16091627
else:
1610-
# todo - fix by doing a specialized fn for this
1611-
1612-
repeated_residue_atom_lens = repeat(residue_atom_lens, 'b ... -> (b r) ...', r = pairwise_repr_cond.shape[1])
1613-
pairwise_repr_cond, ps = pack_one(pairwise_repr_cond, '* n dp')
1614-
pairwise_repr_cond = repeat_consecutive_with_lens(pairwise_repr_cond, repeated_residue_atom_lens)
1615-
pairwise_repr_cond = unpack_one(pairwise_repr_cond, ps, '* n dp')
1616-
1617-
pairwise_repr_cond = rearrange(pairwise_repr_cond, 'b i j dp -> b j i dp')
1618-
repeated_residue_atom_lens = repeat(residue_atom_lens, 'b ... -> (b r) ...', r = pairwise_repr_cond.shape[1])
1619-
pairwise_repr_cond, ps = pack_one(pairwise_repr_cond, '* n dp')
1620-
pairwise_repr_cond = repeat_consecutive_with_lens(pairwise_repr_cond, repeated_residue_atom_lens)
1621-
pairwise_repr_cond = unpack_one(pairwise_repr_cond, ps, '* n dp')
1622-
pairwise_repr_cond = rearrange(pairwise_repr_cond, 'b j i dp -> b i j dp')
1628+
pairwise_repr_cond = repeat_pairwise_consecutive_with_lens(pairwise_repr_cond, residue_atom_lens)
16231629

16241630
atompair_feats = pairwise_repr_cond + atompair_feats
16251631

@@ -2834,7 +2840,8 @@ def forward(
28342840

28352841
# handle atom mask
28362842

2837-
atom_mask = lens_to_mask(residue_atom_lens)
2843+
total_atoms = residue_atom_lens.sum(dim = -1)
2844+
atom_mask = lens_to_mask(total_atoms)
28382845
atom_mask = atom_mask[:, :atom_seq_len]
28392846

28402847
# handle offsets for residue atom indices
@@ -2896,7 +2903,8 @@ def forward(
28962903
# pairwise mask
28972904

28982905
if self.packed_atom_repr:
2899-
mask = lens_to_mask(residue_atom_lens)
2906+
total_atoms = residue_atom_lens.sum(dim = -1)
2907+
mask = lens_to_mask(total_atoms)
29002908
mask = mask[:, :seq_len]
29012909
else:
29022910
mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.30"
3+
version = "0.0.31"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,10 @@ def test_mean_pool_with_lens():
3737
assert torch.allclose(pooled, torch.tensor([[[1.], [2.], [1.]]]))
3838

3939
def test_repeat_consecutive_with_lens():
40-
seq = torch.tensor([[[1.], [2.], [4.]]])
41-
lens = torch.tensor([[3, 4, 2]]).long()
40+
seq = torch.tensor([[[1.], [2.], [4.]], [[1.], [2.], [4.]]])
41+
lens = torch.tensor([[3, 4, 2], [2, 5, 1]]).long()
4242
repeated = repeat_consecutive_with_lens(seq, lens)
43-
44-
assert torch.allclose(repeated, torch.tensor([[[1.], [1.], [1.], [2.], [2.], [2.], [2.], [4.], [4.]]]))
43+
assert torch.allclose(repeated, torch.tensor([[[1.], [1.], [1.], [2.], [2.], [2.], [2.], [4.], [4.]], [[1.], [1.], [2.], [2.], [2.], [2.], [2.], [4.], [0.]]]))
4544

4645
def test_smooth_lddt_loss():
4746
pred_coords = torch.randn(2, 100, 3)

0 commit comments

Comments (0)