@@ -147,10 +147,13 @@ def max_neg_value(t: Tensor):
147147 return - torch .finfo (t .dtype ).max
148148
149149def pack_one (t , pattern ):
150- return pack ([t ], pattern )
150+ packed , ps = pack ([t ], pattern )
151151
152- def unpack_one (t , ps , pattern ):
153- return unpack (t , ps , pattern )[0 ]
152+ def unpack_one (to_unpack , unpack_pattern = None ):
153+ unpacked , = unpack (to_unpack , ps , default (unpack_pattern , pattern ))
154+ return unpacked
155+
156+ return packed , unpack_one
154157
155158def exclusive_cumsum (t , dim = - 1 ):
156159 return t .cumsum (dim = dim ) - t
@@ -194,7 +197,7 @@ def atom_ref_pos_to_atompair_inputs(
194197 # Algorithm 5 - lines 2-6
195198 # allow for either batched or single
196199
197- atom_ref_pos , batch_packed_shape = pack_one (atom_ref_pos , '* m c' )
200+ atom_ref_pos , unpack_one = pack_one (atom_ref_pos , '* m c' )
198201 atom_ref_space_uid , _ = pack_one (atom_ref_space_uid , '* m' )
199202
200203 assert atom_ref_pos .shape [0 ] == atom_ref_space_uid .shape [0 ]
@@ -228,7 +231,7 @@ def atom_ref_pos_to_atompair_inputs(
228231
229232 # reconstitute optional batch dimension
230233
231- atompair_inputs = unpack_one (atompair_inputs , batch_packed_shape , '* i j dapi' )
234+ atompair_inputs = unpack_one (atompair_inputs , '* i j dapi' )
232235
233236 # return
234237
@@ -727,7 +730,7 @@ def forward(
727730 if exists (mask ):
728731 mask = repeat (mask , 'b ... -> (b repeat) ...' , repeat = batch_repeat )
729732
730- pairwise_repr , packed_shape = pack_one (pairwise_repr , '* n d' )
733+ pairwise_repr , unpack_one = pack_one (pairwise_repr , '* n d' )
731734
732735 out = self .attn (
733736 pairwise_repr ,
@@ -736,7 +739,7 @@ def forward(
736739 ** kwargs
737740 )
738741
739- out = unpack_one (out , packed_shape , '* n d' )
742+ out = unpack_one (out )
740743
741744 if self .need_transpose :
742745 out = rearrange (out , 'b j i d -> b i j d' )
@@ -1312,7 +1315,7 @@ def forward(
13121315
13131316 v = self .template_feats_to_embed_input (templates ) + pairwise_repr
13141317
1315- v , merged_batch_ps = pack_one (v , '* i j d' )
1318+ v , unpack_one = pack_one (v , '* i j d' )
13161319
13171320 has_templates = reduce (template_mask , 'b t -> b' , 'any' )
13181321
@@ -1327,7 +1330,7 @@ def forward(
13271330
13281331 u = self .final_norm (v )
13291332
1330- u = unpack_one (u , merged_batch_ps , '* i jk d' )
1333+ u = unpack_one (u )
13311334
13321335 # masked mean pool template repr
13331336
@@ -2995,9 +2998,9 @@ def forward(
29952998 pairwise_repr = repeat_consecutive_with_lens (pairwise_repr , molecule_atom_lens )
29962999
29973000 molecule_atom_lens = repeat (molecule_atom_lens , 'b ... -> (b r) ...' , r = pairwise_repr .shape [1 ])
2998- pairwise_repr , ps = pack_one (pairwise_repr , '* n d' )
3001+ pairwise_repr , unpack_one = pack_one (pairwise_repr , '* n d' )
29993002 pairwise_repr = repeat_consecutive_with_lens (pairwise_repr , molecule_atom_lens )
3000- pairwise_repr = unpack_one (pairwise_repr , ps , '* n d' )
3003+ pairwise_repr = unpack_one (pairwise_repr )
30013004
30023005 interatomic_dist = torch .cdist (pred_atom_pos , pred_atom_pos , p = 2 )
30033006
@@ -3543,7 +3546,7 @@ def forward(
35433546 assert not (exists (is_molecule_mod ) ^ self .has_molecule_mod_embeds ), 'you either set `num_molecule_mods` and did not pass in `is_molecule_mod` or vice versa'
35443547
35453548 if self .has_molecule_mod_embeds :
3546- single_init , seq_packed_shape = pack_one (single_init , '* ds' )
3549+ single_init , seq_unpack_one = pack_one (single_init , '* ds' )
35473550
35483551 is_molecule_mod , _ = pack_one (is_molecule_mod , '* mods' )
35493552
@@ -3556,7 +3559,7 @@ def forward(
35563559 seq_indices = repeat (seq_indices , 'n -> n ds' , ds = single_init .shape [- 1 ])
35573560 single_init = single_init .scatter_add (0 , seq_indices , scatter_values )
35583561
3559- single_init = unpack_one (single_init , seq_packed_shape , '* ds' )
3562+ single_init = seq_unpack_one (single_init )
35603563
35613564 # relative positional encoding
35623565
0 commit comments