Skip to content

Commit 8399c29

Browse files
committed
complete the packed atom representation and add an end-to-end test
1 parent 291f715 commit 8399c29

File tree

3 files changed

+207
-27
lines changed

3 files changed

+207
-27
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 118 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,13 @@ def mean_pool_with_lens(
131131

132132
@typecheck
133133
def repeat_consecutive_with_lens(
134-
feats: Float['b n d'],
134+
feats: Float['b n ...'] | Bool['b n'],
135135
lens: Int['b n'],
136136
max_length: int | None = None,
137137
return_mask = False
138-
) -> Float['b m d'] | Tuple[Float['b m d'], Bool['b m']]:
138+
) -> Float['b m d'] | Bool['b m'] | Tuple[Float['b m d'] | Bool['b m'], Bool['b m']]:
139139

140+
is_bool = feats.dtype == torch.bool
140141
device = feats.device
141142

142143
# derive arange from the max length
@@ -165,8 +166,11 @@ def repeat_consecutive_with_lens(
165166

166167
# now broadcast and sum for consecutive features
167168

168-
feats = einx.multiply('b n d, b n m -> b n m d', feats, consecutive_mask.float())
169-
feats = reduce(feats, 'b n m d -> b m d', 'sum')
169+
feats = einx.multiply('b n ..., b n m -> b n m ...', feats, consecutive_mask.float())
170+
feats = reduce(feats, 'b n m ... -> b m ...', 'sum')
171+
172+
if is_bool:
173+
feats = feats.bool()
170174

171175
if not return_mask:
172176
return feats
@@ -1373,12 +1377,24 @@ def forward(
13731377
self,
13741378
*,
13751379
atom_feats: Float['b m da'],
1376-
atom_mask: Bool['b m']
1380+
atom_mask: Bool['b m'],
1381+
residue_atom_lens: Int['b n'] | None = None
13771382
) -> Float['b n ds']:
1383+
13781384
w = self.atoms_per_window
1385+
is_unpacked_repr = exists(w)
1386+
1387+
assert is_unpacked_repr ^ exists(residue_atom_lens), '`residue_atom_lens` must be passed in if using packed_atom_repr (atoms_per_window is None)'
13791388

13801389
atom_feats = self.proj(atom_feats)
13811390

1391+
# packed atom representation
1392+
1393+
if exists(residue_atom_lens):
1394+
tokens = mean_pool_with_lens(atom_feats, residue_atom_lens)
1395+
return tokens
1396+
1397+
# unpacked atom representation
13821398
# masked mean pool the atom feats for each residue, for the token transformer
13831399
# this is basically a simple 2-level hierarchical transformer
13841400

@@ -1541,13 +1557,18 @@ def forward(
15411557
single_inputs_repr: Float['b n dsi'],
15421558
pairwise_trunk: Float['b n n dpt'],
15431559
pairwise_rel_pos_feats: Float['b n n dpr'],
1560+
residue_atom_lens: Int['b n'] | None = None
15441561
):
15451562
w = self.atoms_per_window
1563+
is_unpacked_repr = exists(w)
1564+
1565+
assert is_unpacked_repr ^ exists(residue_atom_lens)
15461566

15471567
# in the paper, it seems they pack the atom feats
15481568
# but in this impl, will just use windows for simplicity when communicating between atom and residue resolutions. bit less efficient
15491569

1550-
assert divisible_by(noised_atom_pos.shape[-2], w)
1570+
if is_unpacked_repr:
1571+
assert divisible_by(noised_atom_pos.shape[-2], w)
15511572

15521573
conditioned_single_repr = self.single_conditioner(
15531574
times = times,
@@ -1572,14 +1593,20 @@ def forward(
15721593

15731594
single_repr_cond = self.single_repr_to_atom_feat_cond(conditioned_single_repr)
15741595

1575-
single_repr_cond = repeat(single_repr_cond, 'b n ds -> b (n w) ds', w = w)
1596+
if is_unpacked_repr:
1597+
single_repr_cond = repeat(single_repr_cond, 'b n ds -> b (n w) ds', w = w)
1598+
else:
1599+
single_repr_cond = repeat_consecutive_with_lens(single_repr_cond, residue_atom_lens)
1600+
15761601
atom_feats_cond = single_repr_cond + atom_feats_cond
15771602

15781603
# condition atompair feats with pairwise repr
15791604

15801605
pairwise_repr_cond = self.pairwise_repr_to_atompair_feat_cond(conditioned_pairwise_repr)
1581-
pairwise_repr_cond = repeat(pairwise_repr_cond, 'b i j dp -> b (i w1) (j w2) dp', w1 = w, w2 = w)
1582-
atompair_feats = pairwise_repr_cond + atompair_feats
1606+
1607+
if is_unpacked_repr:
1608+
pairwise_repr_cond = repeat(pairwise_repr_cond, 'b i j dp -> b (i w1) (j w2) dp', w1 = w, w2 = w)
1609+
atompair_feats = pairwise_repr_cond + atompair_feats
15831610

15841611
# condition atompair feats further with single atom repr
15851612

@@ -1603,8 +1630,10 @@ def forward(
16031630

16041631
tokens = self.atom_feats_to_pooled_token(
16051632
atom_feats = atom_feats,
1606-
atom_mask = atom_mask
1633+
atom_mask = atom_mask,
1634+
residue_atom_lens = residue_atom_lens
16071635
)
1636+
16081637
# token transformer
16091638

16101639
tokens = self.cond_tokens_with_cond_single(conditioned_single_repr) + tokens
@@ -1621,7 +1650,11 @@ def forward(
16211650
# atom decoder
16221651

16231652
atom_decoder_input = self.tokens_to_atom_decoder_input_cond(tokens)
1624-
atom_decoder_input = repeat(atom_decoder_input, 'b n da -> b (n w) da', w = w)
1653+
1654+
if is_unpacked_repr:
1655+
atom_decoder_input = repeat(atom_decoder_input, 'b n da -> b (n w) da', w = w)
1656+
else:
1657+
atom_decoder_input = repeat_consecutive_with_lens(atom_decoder_input, residue_atom_lens)
16251658

16261659
atom_decoder_input = atom_decoder_input + atom_feats_skip
16271660

@@ -1771,7 +1804,7 @@ def sample_schedule(self, num_sample_steps = None):
17711804
@torch.no_grad()
17721805
def sample(
17731806
self,
1774-
atom_mask: Bool['b m'],
1807+
atom_mask: Bool['b m'] | None = None,
17751808
num_sample_steps = None,
17761809
clamp = True,
17771810
**network_condition_kwargs
@@ -1847,6 +1880,7 @@ def forward(
18471880
pairwise_trunk: Float['b n n dpt'],
18481881
pairwise_rel_pos_feats: Float['b n n dpr'],
18491882
return_denoised_pos = False,
1883+
residue_atom_lens: Int['b n'] | None = None,
18501884
additional_residue_feats: Float['b n 10'] | None = None,
18511885
add_smooth_lddt_loss = False,
18521886
add_bond_loss = False,
@@ -1878,6 +1912,7 @@ def forward(
18781912
single_inputs_repr = single_inputs_repr,
18791913
pairwise_trunk = pairwise_trunk,
18801914
pairwise_rel_pos_feats = pairwise_rel_pos_feats,
1915+
residue_atom_lens = residue_atom_lens
18811916
)
18821917
)
18831918

@@ -1890,9 +1925,14 @@ def forward(
18901925

18911926
if exists(additional_residue_feats):
18921927
w = self.net.atoms_per_window
1928+
is_unpacked_repr = exists(w)
18931929

1894-
is_nucleotide_or_ligand_fields = additional_residue_feats[..., 7:] != 0.
1895-
atom_is_dna, atom_is_rna, atom_is_ligand = tuple(repeat(t != 0., 'b n -> b (n w)', w = w) for t in is_nucleotide_or_ligand_fields.unbind(dim = -1))
1930+
is_nucleotide_or_ligand_fields = (additional_residue_feats[..., 7:] != 0.).unbind(dim = -1)
1931+
1932+
if is_unpacked_repr:
1933+
atom_is_dna, atom_is_rna, atom_is_ligand = tuple(repeat(t != 0., 'b n -> b (n w)', w = w) for t in is_nucleotide_or_ligand_fields)
1934+
else:
1935+
atom_is_dna, atom_is_rna, atom_is_ligand = tuple(repeat_consecutive_with_lens(t, residue_atom_lens) for t in is_nucleotide_or_ligand_fields)
18961936

18971937
# section 3.7.1 equation 4
18981938

@@ -2315,6 +2355,8 @@ def forward(
23152355
atom_mask: Bool['b m'],
23162356
atompair_feats: Float['b m m dap'],
23172357
additional_residue_feats: Float['b n rf'],
2358+
residue_atom_lens: Int['b n'] | None = None,
2359+
23182360
) -> EmbeddedInputs:
23192361

23202362
assert additional_residue_feats.shape[-1] == self.dim_additional_residue_feats
@@ -2336,7 +2378,8 @@ def forward(
23362378

23372379
single_inputs = self.atom_feats_to_pooled_token(
23382380
atom_feats = atom_feats,
2339-
atom_mask = atom_mask
2381+
atom_mask = atom_mask,
2382+
residue_atom_lens = residue_atom_lens
23402383
)
23412384

23422385
single_inputs = torch.cat((single_inputs, additional_residue_feats), dim = -1)
@@ -2527,6 +2570,7 @@ def __init__(
25272570
num_pae_bins = 64,
25282571
sigma_data = 16,
25292572
diffusion_num_augmentations = 4,
2573+
packed_atom_repr = False,
25302574
loss_confidence_weight = 1e-4,
25312575
loss_distogram_weight = 1e-2,
25322576
loss_diffusion_weight = 4.,
@@ -2593,6 +2637,15 @@ def __init__(
25932637
):
25942638
super().__init__()
25952639

2640+
# whether a packed atom representation is being used
2641+
2642+
self.packed_atom_repr = packed_atom_repr
2643+
2644+
# atoms per window if using unpacked representation
2645+
2646+
if packed_atom_repr:
2647+
atoms_per_window = None
2648+
25962649
self.atoms_per_window = atoms_per_window
25972650

25982651
# augmentation
@@ -2732,9 +2785,10 @@ def forward(
27322785
self,
27332786
*,
27342787
atom_inputs: Float['b m dai'],
2735-
atom_mask: Bool['b m'],
27362788
atompair_feats: Float['b m m dap'],
27372789
additional_residue_feats: Float['b n 10'],
2790+
residue_atom_lens: Int['b n'] | None = None,
2791+
atom_mask: Bool['b m'] | None = None,
27382792
token_bond: Bool['b n n'] | None = None,
27392793
msa: Float['b s n d'] | None = None,
27402794
msa_mask: Bool['b s'] | None = None,
@@ -2754,13 +2808,33 @@ def forward(
27542808
return_loss_breakdown = False
27552809
) -> Float['b m 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
27562810

2757-
# get atom sequence length and residue sequence length
2758-
2759-
w = self.atoms_per_window
27602811
atom_seq_len = atom_inputs.shape[-2]
27612812

2762-
assert divisible_by(atom_seq_len, w)
2763-
seq_len = atom_inputs.shape[-2] // w
2813+
# determine whether using packed or unpacked atom rep
2814+
2815+
assert exists(residue_atom_lens) ^ exists(atom_mask), 'either atom_lens or atom_mask must be given depending on whether packed_atom_repr kwarg is True or False'
2816+
2817+
if exists(residue_atom_lens):
2818+
assert self.packed_atom_repr, '`packed_atom_repr` kwarg on Alphafold3 must be True when passing in `atom_lens`'
2819+
2820+
# handle atom mask
2821+
2822+
atom_mask = lens_to_mask(residue_atom_lens)
2823+
atom_mask = atom_mask[:, :atom_seq_len]
2824+
2825+
# handle offsets for residue atom indices
2826+
2827+
if exists(residue_atom_indices):
2828+
residue_atom_indices += F.pad(residue_atom_lens, (-1, 1), value = 0)
2829+
2830+
# get atom sequence length and residue sequence length depending on whether using packed atomic seq
2831+
2832+
if self.packed_atom_repr:
2833+
seq_len = residue_atom_lens.shape[-1]
2834+
else:
2835+
w = self.atoms_per_window
2836+
assert divisible_by(atom_seq_len, w)
2837+
seq_len = atom_inputs.shape[-2] // w
27642838

27652839
# embed inputs
27662840

@@ -2774,7 +2848,8 @@ def forward(
27742848
atom_inputs = atom_inputs,
27752849
atom_mask = atom_mask,
27762850
atompair_feats = atompair_feats,
2777-
additional_residue_feats = additional_residue_feats
2851+
additional_residue_feats = additional_residue_feats,
2852+
residue_atom_lens = residue_atom_lens
27782853
)
27792854

27802855
# relative positional encoding
@@ -2805,7 +2880,12 @@ def forward(
28052880

28062881
# pairwise mask
28072882

2808-
mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')
2883+
if self.packed_atom_repr:
2884+
mask = lens_to_mask(residue_atom_lens)
2885+
mask = mask[:, :seq_len]
2886+
else:
2887+
mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')
2888+
28092889
pairwise_mask = einx.logical_and('b i, b j -> b i j', mask, mask)
28102890

28112891
# init recycled single and pairwise
@@ -2889,7 +2969,8 @@ def forward(
28892969
single_trunk_repr = single,
28902970
single_inputs_repr = single_inputs,
28912971
pairwise_trunk = pairwise,
2892-
pairwise_rel_pos_feats = relative_position_encoding
2972+
pairwise_rel_pos_feats = relative_position_encoding,
2973+
residue_atom_lens = residue_atom_lens
28932974
)
28942975

28952976
# losses default to 0
@@ -2903,7 +2984,12 @@ def forward(
29032984
# distogram head
29042985

29052986
if not exists(distance_labels) and atom_pos_given and exists(residue_atom_indices):
2906-
residue_pos = einx.get_at('b (n [w]) c, b n -> b n c', atom_pos, residue_atom_indices)
2987+
2988+
if self.packed_atom_repr:
2989+
residue_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, residue_atom_indices)
2990+
else:
2991+
residue_pos = einx.get_at('b (n [w]) c, b n -> b n c', atom_pos, residue_atom_indices)
2992+
29072993
residue_dist = torch.cdist(residue_pos, residue_pos, p = 2)
29082994
dist_from_dist_bins = einx.subtract('b m dist, dist_bins -> b m dist dist_bins', residue_dist, self.distance_bins).abs()
29092995
distance_labels = dist_from_dist_bins.argmin(dim = -1)
@@ -2938,6 +3024,7 @@ def forward(
29383024
relative_position_encoding,
29393025
additional_residue_feats,
29403026
residue_atom_indices,
3027+
residue_atom_lens,
29413028
pae_labels,
29423029
pde_labels,
29433030
plddt_labels,
@@ -2958,6 +3045,7 @@ def forward(
29583045
relative_position_encoding,
29593046
additional_residue_feats,
29603047
residue_atom_indices,
3048+
residue_atom_lens,
29613049
pae_labels,
29623050
pde_labels,
29633051
plddt_labels,
@@ -2980,6 +3068,7 @@ def forward(
29803068
single_inputs_repr = single_inputs,
29813069
pairwise_trunk = pairwise,
29823070
pairwise_rel_pos_feats = relative_position_encoding,
3071+
residue_atom_lens = residue_atom_lens,
29833072
return_denoised_pos = True,
29843073
)
29853074

@@ -2990,7 +3079,10 @@ def forward(
29903079

29913080
if calc_diffusion_loss and should_call_confidence_head:
29923081

2993-
pred_atom_pos = einx.get_at('b (n [w]) c, b n -> b n c', denoised_atom_pos, residue_atom_indices)
3082+
if self.packed_atom_repr:
3083+
pred_atom_pos = einx.get_at('b [m] c, b n -> b n c', denoised_atom_pos, residue_atom_indices)
3084+
else:
3085+
pred_atom_pos = einx.get_at('b (n [w]) c, b n -> b n c', denoised_atom_pos, residue_atom_indices)
29943086

29953087
logits = self.confidence_head(
29963088
single_repr = single,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.25"
3+
version = "0.0.26"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)