Skip to content

Commit 2994804

Browse files
committed
able to return the inverse function from mean_pool_fixed_windows_with_mask for token -> atom feats
1 parent 1ce98e2 commit 2994804

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,8 @@ def mean_pool_fixed_windows_with_mask(
297297
feats: Float['b m d'],
298298
mask: Bool['b m'],
299299
window_size: int,
300-
return_pooled_mask: bool = False
301-
) -> Float['b n d'] | Tuple[Float['b n'], Bool['b n']]:
300+
return_mask_and_inverse: bool = False,
301+
) -> Float['b n d'] | Tuple[Float['b n d'], Bool['b n'], Callable[[Float['b m d']], Float['b n d']]]:
302302

303303
seq_len = feats.shape[-2]
304304
assert divisible_by(seq_len, window_size)
@@ -310,11 +310,18 @@ def mean_pool_fixed_windows_with_mask(
310310

311311
avg = num / den.clamp(min = 1.)
312312

313-
if not return_pooled_mask:
313+
if not return_mask_and_inverse:
314314
return avg
315315

316316
pooled_mask = reduce(mask, 'b (n w) -> b n', 'any', w = window_size)
317-
return avg, pooled_mask
317+
318+
@typecheck
319+
def inverse_fn(pooled: Float['b n d']) -> Float['b m d']:
320+
unpooled = repeat(pooled, 'b n d -> b (n w) d', w = window_size)
321+
unpooled = einx.where('b m, b m d, -> b m d', mask, unpooled, 0.)
322+
return unpooled
323+
324+
return avg, pooled_mask, inverse_fn
318325

319326
@typecheck
320327
def batch_repeat_interleave(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.122"
3+
version = "0.2.123"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ def test_mean_pool_with_mask():
8080
seq = torch.tensor([[[1.], [100.], [1.], [2.], [2.], [100.], [1.], [1.], [100.]]])
8181
mask = torch.tensor([[True, False, True, True, True, False, True, True, False]])
8282

83-
pooled = mean_pool_fixed_windows_with_mask(seq, mask, window_size = 3)
83+
pooled, _, inverse_function = mean_pool_fixed_windows_with_mask(seq, mask, window_size = 3, return_mask_and_inverse = True)
8484

85+
assert inverse_function(pooled).shape == seq.shape
8586
assert torch.allclose(pooled, torch.tensor([[[1.], [2.], [1.]]]))
8687

8788
def test_batch_repeat_interleave():

0 commit comments

Comments
 (0)