@@ -327,7 +327,7 @@ def inverse_fn(pooled: Float['b n d']) -> Float['b m d']:
327327def batch_repeat_interleave (
328328 feats : Float ['b n ...' ] | Bool ['b n ...' ] | Bool ['b n' ] | Int ['b n' ],
329329 lens : Int ['b n' ],
330- mask_value : float | int | bool | None = None ,
330+ output_padding_value : float | int | bool | None = None , # this value determines what the output padding value will be
331331) -> Float ['b m ...' ] | Bool ['b m ...' ] | Bool ['b m' ] | Int ['b m' ]:
332332
333333 device , dtype = feats .device , feats .dtype
@@ -348,7 +348,7 @@ def batch_repeat_interleave(
348348
349349 # create output tensor + a sink position on the very right (index max_len)
350350
351- total_lens = lens .sum (dim = - 1 )
351+ total_lens = lens .clamp ( min = 0 ). sum (dim = - 1 )
352352 output_mask = lens_to_mask (total_lens )
353353
354354 max_len = total_lens .amax ()
@@ -380,13 +380,13 @@ def batch_repeat_interleave(
380380 output = feats .gather (1 , output_indices )
381381 output = unpack_one (output )
382382
383- # final mask
383+ # set output padding value
384384
385- mask_value = default (mask_value , False if dtype == torch .bool else 0 )
385+ output_padding_value = default (output_padding_value , False if dtype == torch .bool else 0 )
386386
387387 output = einx .where (
388388 'b n, b n ..., -> b n ...' ,
389- output_mask , output , mask_value
389+ output_mask , output , output_padding_value
390390 )
391391
392392 return output
@@ -4289,7 +4289,7 @@ def compute_weighted_lddt(
42894289 batch_size = pred_coords .shape [0 ]
42904290
42914291 # broadcast asym_id and is_molecule_types to atom level
4292- atom_asym_id = batch_repeat_interleave (asym_id , molecule_atom_lens , mask_value = - 1 )
4292+ atom_asym_id = batch_repeat_interleave (asym_id , molecule_atom_lens , output_padding_value = - 1 )
42934293 atom_is_molecule_types = batch_repeat_interleave (is_molecule_types , molecule_atom_lens )
42944294
42954295 weighted_lddt = torch .zeros (batch_size , device = device )
0 commit comments