Skip to content

Commit 6570963

Browse files
committed
make sure lens passed into batch_repeat_interleave is never negative, addressing #158
1 parent 2994804 commit 6570963

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def inverse_fn(pooled: Float['b n d']) -> Float['b m d']:
327327
def batch_repeat_interleave(
328328
feats: Float['b n ...'] | Bool['b n ...'] | Bool['b n'] | Int['b n'],
329329
lens: Int['b n'],
330-
mask_value: float | int | bool | None = None,
330+
output_padding_value: float | int | bool | None = None, # this value determines what the output padding value will be
331331
) -> Float['b m ...'] | Bool['b m ...'] | Bool['b m'] | Int['b m']:
332332

333333
device, dtype = feats.device, feats.dtype
@@ -348,7 +348,7 @@ def batch_repeat_interleave(
348348

349349
# create output tensor + a sink position on the very right (index max_len)
350350

351-
total_lens = lens.sum(dim = -1)
351+
total_lens = lens.clamp(min = 0).sum(dim = -1)
352352
output_mask = lens_to_mask(total_lens)
353353

354354
max_len = total_lens.amax()
@@ -380,13 +380,13 @@ def batch_repeat_interleave(
380380
output = feats.gather(1, output_indices)
381381
output = unpack_one(output)
382382

383-
# final mask
383+
# set output padding value
384384

385-
mask_value = default(mask_value, False if dtype == torch.bool else 0)
385+
output_padding_value = default(output_padding_value, False if dtype == torch.bool else 0)
386386

387387
output = einx.where(
388388
'b n, b n ..., -> b n ...',
389-
output_mask, output, mask_value
389+
output_mask, output, output_padding_value
390390
)
391391

392392
return output
@@ -4289,7 +4289,7 @@ def compute_weighted_lddt(
42894289
batch_size = pred_coords.shape[0]
42904290

42914291
# broadcast asym_id and is_molecule_types to atom level
4292-
atom_asym_id = batch_repeat_interleave(asym_id, molecule_atom_lens, mask_value=-1)
4292+
atom_asym_id = batch_repeat_interleave(asym_id, molecule_atom_lens, output_padding_value=-1)
42934293
atom_is_molecule_types = batch_repeat_interleave(is_molecule_types, molecule_atom_lens)
42944294

42954295
weighted_lddt = torch.zeros(batch_size, device=device)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.123"
3+
version = "0.2.124"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)