Skip to content

Commit 04fe0e0

Browse files
authored
rename fn (#140)
1 parent 212d7c7 commit 04fe0e0

File tree

4 files changed

+21
-20
lines changed

4 files changed

+21
-20
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def mean_pool_with_lens(
294294
return avg
295295

296296
@typecheck
297-
def repeat_consecutive_with_lens(
297+
def batch_repeat_interleave(
298298
feats: Float['b n ...'] | Bool['b n ...'] | Bool['b n'] | Int['b n'],
299299
lens: Int['b n'],
300300
mask_value: float | int | bool | None = None,
@@ -2280,7 +2280,7 @@ def forward(
22802280

22812281
single_repr_cond = self.single_repr_to_atom_feat_cond(conditioned_single_repr)
22822282

2283-
single_repr_cond = repeat_consecutive_with_lens(single_repr_cond, molecule_atom_lens)
2283+
single_repr_cond = batch_repeat_interleave(single_repr_cond, molecule_atom_lens)
22842284
single_repr_cond = pad_or_slice_to(single_repr_cond, length = atom_feats_cond.shape[1], dim = 1)
22852285

22862286
atom_feats_cond = single_repr_cond + atom_feats_cond
@@ -2299,7 +2299,7 @@ def forward(
22992299
indices = torch.arange(seq_len, device = device)
23002300
indices = repeat(indices, 'n -> b n', b = batch_size)
23012301

2302-
indices = repeat_consecutive_with_lens(indices, molecule_atom_lens)
2302+
indices = batch_repeat_interleave(indices, molecule_atom_lens)
23032303
indices = pad_or_slice_to(indices, atom_seq_len, dim = -1)
23042304
indices = pad_and_window(indices, w)
23052305

@@ -2392,7 +2392,7 @@ def forward(
23922392

23932393
atom_decoder_input = self.tokens_to_atom_decoder_input_cond(tokens)
23942394

2395-
atom_decoder_input = repeat_consecutive_with_lens(atom_decoder_input, molecule_atom_lens)
2395+
atom_decoder_input = batch_repeat_interleave(atom_decoder_input, molecule_atom_lens)
23962396
atom_decoder_input = pad_or_slice_to(atom_decoder_input, length = atom_feats_skip.shape[1], dim = 1)
23972397

23982398
atom_decoder_input = atom_decoder_input + atom_feats_skip
@@ -2707,7 +2707,7 @@ def forward(
27072707
if exists(is_molecule_types):
27082708
is_nucleotide_or_ligand_fields = is_molecule_types.unbind(dim = -1)
27092709

2710-
is_nucleotide_or_ligand_fields = tuple(repeat_consecutive_with_lens(t, molecule_atom_lens) for t in is_nucleotide_or_ligand_fields)
2710+
is_nucleotide_or_ligand_fields = tuple(batch_repeat_interleave(t, molecule_atom_lens) for t in is_nucleotide_or_ligand_fields)
27112711
is_nucleotide_or_ligand_fields = tuple(pad_or_slice_to(t, length = align_weights.shape[-1], dim = -1) for t in is_nucleotide_or_ligand_fields)
27122712

27132713
_, atom_is_dna, atom_is_rna, atom_is_ligand, _ = is_nucleotide_or_ligand_fields
@@ -3429,13 +3429,13 @@ def forward(
34293429
# handle maybe atom level resolution
34303430

34313431
if self.atom_resolution:
3432-
single_repr = repeat_consecutive_with_lens(single_repr, molecule_atom_lens)
3432+
single_repr = batch_repeat_interleave(single_repr, molecule_atom_lens)
34333433

3434-
pairwise_repr = repeat_consecutive_with_lens(pairwise_repr, molecule_atom_lens)
3434+
pairwise_repr = batch_repeat_interleave(pairwise_repr, molecule_atom_lens)
34353435

34363436
molecule_atom_lens = repeat(molecule_atom_lens, 'b ... -> (b r) ...', r = pairwise_repr.shape[1])
34373437
pairwise_repr, unpack_one = pack_one(pairwise_repr, '* n d')
3438-
pairwise_repr = repeat_consecutive_with_lens(pairwise_repr, molecule_atom_lens)
3438+
pairwise_repr = batch_repeat_interleave(pairwise_repr, molecule_atom_lens)
34393439
pairwise_repr = unpack_one(pairwise_repr)
34403440

34413441
interatomic_dist = torch.cdist(pred_atom_pos, pred_atom_pos, p = 2)
@@ -3744,8 +3744,8 @@ def forward(
37443744
valid_indices = torch.ones_like(indices).bool()
37453745

37463746
# valid_indices at padding position has value False
3747-
indices = repeat_consecutive_with_lens(indices, molecule_atom_lens)
3748-
valid_indices = repeat_consecutive_with_lens(valid_indices, molecule_atom_lens)
3747+
indices = batch_repeat_interleave(indices, molecule_atom_lens)
3748+
valid_indices = batch_repeat_interleave(valid_indices, molecule_atom_lens)
37493749

37503750
if exists(atom_mask):
37513751
valid_indices = valid_indices * atom_mask
@@ -3811,8 +3811,8 @@ def compute_full_complex_metric(
38113811
valid_indices = torch.ones_like(indices).bool()
38123812

38133813
# valid_indices at padding position has value False
3814-
indices = repeat_consecutive_with_lens(indices, molecule_atom_lens)
3815-
valid_indices = repeat_consecutive_with_lens(valid_indices, molecule_atom_lens)
3814+
indices = batch_repeat_interleave(indices, molecule_atom_lens)
3815+
valid_indices = batch_repeat_interleave(valid_indices, molecule_atom_lens)
38163816

38173817
# broadcast is_molecule_types to atom
38183818

@@ -4265,8 +4265,8 @@ def compute_weighted_lddt(
42654265
batch_size = pred_coords.shape[0]
42664266

42674267
# broadcast asym_id and is_molecule_types to atom level
4268-
atom_asym_id = repeat_consecutive_with_lens(asym_id, molecule_atom_lens, mask_value=-1)
4269-
atom_is_molecule_types = repeat_consecutive_with_lens(is_molecule_types, molecule_atom_lens)
4268+
atom_asym_id = batch_repeat_interleave(asym_id, molecule_atom_lens, mask_value=-1)
4269+
atom_is_molecule_types = batch_repeat_interleave(is_molecule_types, molecule_atom_lens)
42704270

42714271
weighted_lddt = torch.zeros(batch_size, device=device)
42724272

alphafold3_pytorch/utils/model_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def mean_pool_with_lens(
374374

375375

376376
@typecheck
377-
def repeat_consecutive_with_lens(
377+
def batch_repeat_interleave(
378378
feats: Float["b n ..."] | Bool["b n"] | Int["b n"], # type: ignore
379379
lens: Int["b n"], # type: ignore
380380
) -> Float["b m ..."] | Bool["b m"] | Int["b m"]: # type: ignore

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.101"
3+
version = "0.2.102"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
6-
{ name = "Phil Wang", email = "[email protected]" }
6+
{ name = "Phil Wang", email = "[email protected]" },
7+
{ name = "Alex Morehead", email = "[email protected]"}
78
]
89
readme = "README.md"
910
requires-python = ">= 3.8"

tests/test_af3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
from alphafold3_pytorch.alphafold3 import (
4545
mean_pool_with_lens,
46-
repeat_consecutive_with_lens,
46+
batch_repeat_interleave,
4747
full_pairwise_repr_to_windowed,
4848
get_cid_molecule_type,
4949
)
@@ -75,10 +75,10 @@ def test_mean_pool_with_lens():
7575

7676
assert torch.allclose(pooled, torch.tensor([[[1.], [2.], [1.]]]))
7777

78-
def test_repeat_consecutive_with_lens():
78+
def test_batch_repeat_interleave():
7979
seq = torch.tensor([[[1.], [2.], [4.]], [[1.], [2.], [4.]]])
8080
lens = torch.tensor([[3, 4, 2], [2, 5, 1]]).long()
81-
repeated = repeat_consecutive_with_lens(seq, lens)
81+
repeated = batch_repeat_interleave(seq, lens)
8282
assert torch.allclose(repeated, torch.tensor([[[1.], [1.], [1.], [2.], [2.], [2.], [2.], [4.], [4.]], [[1.], [1.], [2.], [2.], [2.], [2.], [2.], [4.], [0.]]]))
8383

8484
def test_smooth_lddt_loss():

0 commit comments

Comments
 (0)