
Commit 154ac2a

have pack return the inverse fn with packed shape closured

1 parent 3598f6e

2 files changed: +17 -14 lines


alphafold3_pytorch/alphafold3.py

Lines changed: 16 additions & 13 deletions
@@ -147,10 +147,13 @@ def max_neg_value(t: Tensor):
     return -torch.finfo(t.dtype).max
 
 def pack_one(t, pattern):
-    return pack([t], pattern)
+    packed, ps = pack([t], pattern)
 
-def unpack_one(t, ps, pattern):
-    return unpack(t, ps, pattern)[0]
+    def unpack_one(to_unpack, unpack_pattern = None):
+        unpacked, = unpack(to_unpack, ps, default(unpack_pattern, pattern))
+        return unpacked
+
+    return packed, unpack_one
 
 def exclusive_cumsum(t, dim = -1):
     return t.cumsum(dim = dim) - t
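
The point of the change is that the packed shape `ps` now lives inside the returned closure, so call sites no longer have to thread it (and the pattern) back into a separate unpack call. Below is a minimal standalone sketch of the new helper and how a caller would use it; the local `default` fallback and the toy shapes are assumptions for illustration, not the repo's exact module contents.

import torch
from einops import pack, unpack

def default(v, d):
    # stand-in for the repo's existing fallback helper
    return v if v is not None else d

def pack_one(t, pattern):
    # pack a single tensor; the packed shape `ps` and the pattern stay closured
    packed, ps = pack([t], pattern)

    def unpack_one(to_unpack, unpack_pattern = None):
        # reuse the closured shape; the pattern can optionally be overridden
        unpacked, = unpack(to_unpack, ps, default(unpack_pattern, pattern))
        return unpacked

    return packed, unpack_one

# collapse any leading batch-like dims, operate, then restore them
x = torch.randn(2, 3, 5, 7)
flat, unpack_one = pack_one(x, '* n d')   # flat has shape (6, 5, 7)
flat = flat * 2.                          # any op that preserves the trailing dims
restored = unpack_one(flat)               # back to (2, 3, 5, 7)
assert restored.shape == x.shape

Returning the inverse function keeps each pack/unpack pair symmetric and removes the class of bugs where the wrong packed shape is fed to unpack.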
@@ -194,7 +197,7 @@ def atom_ref_pos_to_atompair_inputs(
     # Algorithm 5 - lines 2-6
     # allow for either batched or single
 
-    atom_ref_pos, batch_packed_shape = pack_one(atom_ref_pos, '* m c')
+    atom_ref_pos, unpack_one = pack_one(atom_ref_pos, '* m c')
     atom_ref_space_uid, _ = pack_one(atom_ref_space_uid, '* m')
 
     assert atom_ref_pos.shape[0] == atom_ref_space_uid.shape[0]
@@ -228,7 +231,7 @@ def atom_ref_pos_to_atompair_inputs(
 
     # reconstitute optional batch dimension
 
-    atompair_inputs = unpack_one(atompair_inputs, batch_packed_shape, '* i j dapi')
+    atompair_inputs = unpack_one(atompair_inputs, '* i j dapi')
 
     # return
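
Note that this call site packs with one pattern ('* m c') but unpacks with another ('* i j dapi'): only the closured leading shape is reused, and the optional unpack_pattern argument covers the case where the trailing axes change in between. A hedged toy reproduction of that behavior with plain einops calls (the shapes are arbitrary, not the model's real dimensions):

import torch
from einops import pack, unpack

# pack positions: '*' records the optional leading batch dims, here (2, 4)
coords = torch.randn(2, 4, 16, 3)
coords, ps = pack([coords], '* m c')            # -> shape (8, 16, 3)

# downstream, pairwise features are built with different trailing axes
pairwise = torch.randn(8, 16, 16, 5)            # stand-in for atompair_inputs
pairwise, = unpack(pairwise, ps, '* i j dapi')  # same ps, different pattern
assert pairwise.shape == (2, 4, 16, 16, 5)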

@@ -727,7 +730,7 @@ def forward(
         if exists(mask):
             mask = repeat(mask, 'b ... -> (b repeat) ...', repeat = batch_repeat)
 
-        pairwise_repr, packed_shape = pack_one(pairwise_repr, '* n d')
+        pairwise_repr, unpack_one = pack_one(pairwise_repr, '* n d')
 
         out = self.attn(
             pairwise_repr,
@@ -736,7 +739,7 @@ def forward(
             **kwargs
         )
 
-        out = unpack_one(out, packed_shape, '* n d')
+        out = unpack_one(out)
 
         if self.need_transpose:
             out = rearrange(out, 'b j i d -> b i j d')
@@ -1312,7 +1315,7 @@ def forward(
 
         v = self.template_feats_to_embed_input(templates) + pairwise_repr
 
-        v, merged_batch_ps = pack_one(v, '* i j d')
+        v, unpack_one = pack_one(v, '* i j d')
 
         has_templates = reduce(template_mask, 'b t -> b', 'any')
 
@@ -1327,7 +1330,7 @@ def forward(
 
         u = self.final_norm(v)
 
-        u = unpack_one(u, merged_batch_ps, '* i jk d')
+        u = unpack_one(u)
 
         # masked mean pool template repr
 
@@ -2995,9 +2998,9 @@ def forward(
             pairwise_repr = repeat_consecutive_with_lens(pairwise_repr, molecule_atom_lens)
 
             molecule_atom_lens = repeat(molecule_atom_lens, 'b ... -> (b r) ...', r = pairwise_repr.shape[1])
-            pairwise_repr, ps = pack_one(pairwise_repr, '* n d')
+            pairwise_repr, unpack_one = pack_one(pairwise_repr, '* n d')
             pairwise_repr = repeat_consecutive_with_lens(pairwise_repr, molecule_atom_lens)
-            pairwise_repr = unpack_one(pairwise_repr, ps, '* n d')
+            pairwise_repr = unpack_one(pairwise_repr)
 
             interatomic_dist = torch.cdist(pred_atom_pos, pred_atom_pos, p = 2)
 
@@ -3543,7 +3546,7 @@ def forward(
         assert not (exists(is_molecule_mod) ^ self.has_molecule_mod_embeds), 'you either set `num_molecule_mods` and did not pass in `is_molecule_mod` or vice versa'
 
         if self.has_molecule_mod_embeds:
-            single_init, seq_packed_shape = pack_one(single_init, '* ds')
+            single_init, seq_unpack_one = pack_one(single_init, '* ds')
 
             is_molecule_mod, _ = pack_one(is_molecule_mod, '* mods')
 
@@ -3556,7 +3559,7 @@ def forward(
             seq_indices = repeat(seq_indices, 'n -> n ds', ds = single_init.shape[-1])
             single_init = single_init.scatter_add(0, seq_indices, scatter_values)
 
-            single_init = unpack_one(single_init, seq_packed_shape, '* ds')
+            single_init = seq_unpack_one(single_init)
 
         # relative positional encoding
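
When more than one tensor is packed in the same scope, each pack_one call now hands back its own inverse, which is why this site names the closure seq_unpack_one rather than juggling a shared packed-shape variable. A small self-contained sketch of two coexisting closures (the toy shapes and the in-place doubling are placeholders for the real mod-embedding scatter_add, not the model's code):

import torch
from einops import pack, unpack

def pack_one(t, pattern):
    # same closure idea as in the commit, repeated here so the sketch runs standalone
    packed, ps = pack([t], pattern)

    def unpack_one(to_unpack, unpack_pattern = None):
        unpacked, = unpack(to_unpack, ps, unpack_pattern if unpack_pattern is not None else pattern)
        return unpacked

    return packed, unpack_one

single = torch.randn(3, 7, 8)
mods = torch.randint(0, 2, (3, 7, 4)).bool()       # stand-in for is_molecule_mod flags

single, seq_unpack_one = pack_one(single, '* ds')  # (21, 8), with its own inverse
mods, _ = pack_one(mods, '* mods')                 # second pack; its inverse is simply discarded

single = single * 2.                               # placeholder for adding the mod embeddings
single = seq_unpack_one(single)                    # restored to (3, 7, 8)
assert single.shape == (3, 7, 8)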

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.2.17"
+version = "0.2.18"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
