@@ -294,7 +294,7 @@ def mean_pool_with_lens(
294294 return avg
295295
296296@typecheck
297- def repeat_consecutive_with_lens (
297+ def batch_repeat_interleave (
298298 feats : Float ['b n ...' ] | Bool ['b n ...' ] | Bool ['b n' ] | Int ['b n' ],
299299 lens : Int ['b n' ],
300300 mask_value : float | int | bool | None = None ,
@@ -2280,7 +2280,7 @@ def forward(
22802280
22812281 single_repr_cond = self .single_repr_to_atom_feat_cond (conditioned_single_repr )
22822282
2283- single_repr_cond = repeat_consecutive_with_lens (single_repr_cond , molecule_atom_lens )
2283+ single_repr_cond = batch_repeat_interleave (single_repr_cond , molecule_atom_lens )
22842284 single_repr_cond = pad_or_slice_to (single_repr_cond , length = atom_feats_cond .shape [1 ], dim = 1 )
22852285
22862286 atom_feats_cond = single_repr_cond + atom_feats_cond
@@ -2299,7 +2299,7 @@ def forward(
22992299 indices = torch .arange (seq_len , device = device )
23002300 indices = repeat (indices , 'n -> b n' , b = batch_size )
23012301
2302- indices = repeat_consecutive_with_lens (indices , molecule_atom_lens )
2302+ indices = batch_repeat_interleave (indices , molecule_atom_lens )
23032303 indices = pad_or_slice_to (indices , atom_seq_len , dim = - 1 )
23042304 indices = pad_and_window (indices , w )
23052305
@@ -2392,7 +2392,7 @@ def forward(
23922392
23932393 atom_decoder_input = self .tokens_to_atom_decoder_input_cond (tokens )
23942394
2395- atom_decoder_input = repeat_consecutive_with_lens (atom_decoder_input , molecule_atom_lens )
2395+ atom_decoder_input = batch_repeat_interleave (atom_decoder_input , molecule_atom_lens )
23962396 atom_decoder_input = pad_or_slice_to (atom_decoder_input , length = atom_feats_skip .shape [1 ], dim = 1 )
23972397
23982398 atom_decoder_input = atom_decoder_input + atom_feats_skip
@@ -2707,7 +2707,7 @@ def forward(
27072707 if exists (is_molecule_types ):
27082708 is_nucleotide_or_ligand_fields = is_molecule_types .unbind (dim = - 1 )
27092709
2710- is_nucleotide_or_ligand_fields = tuple (repeat_consecutive_with_lens (t , molecule_atom_lens ) for t in is_nucleotide_or_ligand_fields )
2710+ is_nucleotide_or_ligand_fields = tuple (batch_repeat_interleave (t , molecule_atom_lens ) for t in is_nucleotide_or_ligand_fields )
27112711 is_nucleotide_or_ligand_fields = tuple (pad_or_slice_to (t , length = align_weights .shape [- 1 ], dim = - 1 ) for t in is_nucleotide_or_ligand_fields )
27122712
27132713 _ , atom_is_dna , atom_is_rna , atom_is_ligand , _ = is_nucleotide_or_ligand_fields
@@ -3429,13 +3429,13 @@ def forward(
34293429 # handle maybe atom level resolution
34303430
34313431 if self .atom_resolution :
3432- single_repr = repeat_consecutive_with_lens (single_repr , molecule_atom_lens )
3432+ single_repr = batch_repeat_interleave (single_repr , molecule_atom_lens )
34333433
3434- pairwise_repr = repeat_consecutive_with_lens (pairwise_repr , molecule_atom_lens )
3434+ pairwise_repr = batch_repeat_interleave (pairwise_repr , molecule_atom_lens )
34353435
34363436 molecule_atom_lens = repeat (molecule_atom_lens , 'b ... -> (b r) ...' , r = pairwise_repr .shape [1 ])
34373437 pairwise_repr , unpack_one = pack_one (pairwise_repr , '* n d' )
3438- pairwise_repr = repeat_consecutive_with_lens (pairwise_repr , molecule_atom_lens )
3438+ pairwise_repr = batch_repeat_interleave (pairwise_repr , molecule_atom_lens )
34393439 pairwise_repr = unpack_one (pairwise_repr )
34403440
34413441 interatomic_dist = torch .cdist (pred_atom_pos , pred_atom_pos , p = 2 )
@@ -3744,8 +3744,8 @@ def forward(
37443744 valid_indices = torch .ones_like (indices ).bool ()
37453745
37463746 # valid_indices at padding position has value False
3747- indices = repeat_consecutive_with_lens (indices , molecule_atom_lens )
3748- valid_indices = repeat_consecutive_with_lens (valid_indices , molecule_atom_lens )
3747+ indices = batch_repeat_interleave (indices , molecule_atom_lens )
3748+ valid_indices = batch_repeat_interleave (valid_indices , molecule_atom_lens )
37493749
37503750 if exists (atom_mask ):
37513751 valid_indices = valid_indices * atom_mask
@@ -3811,8 +3811,8 @@ def compute_full_complex_metric(
38113811 valid_indices = torch .ones_like (indices ).bool ()
38123812
38133813 # valid_indices at padding position has value False
3814- indices = repeat_consecutive_with_lens (indices , molecule_atom_lens )
3815- valid_indices = repeat_consecutive_with_lens (valid_indices , molecule_atom_lens )
3814+ indices = batch_repeat_interleave (indices , molecule_atom_lens )
3815+ valid_indices = batch_repeat_interleave (valid_indices , molecule_atom_lens )
38163816
38173817 # broadcast is_molecule_types to atom
38183818
@@ -4265,8 +4265,8 @@ def compute_weighted_lddt(
42654265 batch_size = pred_coords .shape [0 ]
42664266
42674267 # broadcast asym_id and is_molecule_types to atom level
4268- atom_asym_id = repeat_consecutive_with_lens (asym_id , molecule_atom_lens , mask_value = - 1 )
4269- atom_is_molecule_types = repeat_consecutive_with_lens (is_molecule_types , molecule_atom_lens )
4268+ atom_asym_id = batch_repeat_interleave (asym_id , molecule_atom_lens , mask_value = - 1 )
4269+ atom_is_molecule_types = batch_repeat_interleave (is_molecule_types , molecule_atom_lens )
42704270
42714271 weighted_lddt = torch .zeros (batch_size , device = device )
42724272
0 commit comments