Skip to content

Commit cfa6c83

Browse files
committed
use repeat_interleave for building sym_id, entity_id, atom_parent_ids
1 parent 083ad92 commit cfa6c83

File tree

3 files changed

+51
-66
lines changed

3 files changed

+51
-66
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 49 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -591,9 +591,12 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
591591
num_ss_dna_atoms = get_num_atoms_per_chain(mol_ss_dnas)
592592
num_ligand_atoms = [l.GetNumAtoms() for l in mol_ligands]
593593

594-
unflattened_atom_parent_ids = [([asym_id] * num_tokens) for asym_id, num_tokens in enumerate([*num_protein_atoms, *num_ss_rna_atoms, *num_ss_dna_atoms, *num_ligand_atoms, num_metal_ions])]
594+
atom_counts = [*num_protein_atoms, *num_ss_rna_atoms, *num_ss_dna_atoms, *num_ligand_atoms, num_metal_ions]
595595

596-
atom_parent_ids = tensor(flatten(unflattened_atom_parent_ids))
596+
atom_parent_ids = torch.repeat_interleave(
597+
torch.arange(len(atom_counts)),
598+
tensor(atom_counts)
599+
)
597600

598601
# constructing the additional_molecule_feats
599602
# which is in turn used to derive relative positions
@@ -610,78 +613,60 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
610613
num_ss_dna_tokens = [len(dna) for dna in ss_dnas]
611614
num_ligand_tokens = [l.GetNumAtoms() for l in mol_ligands]
612615

613-
unflattened_asym_ids = [([asym_id] * num_tokens) for asym_id, num_tokens in enumerate([*num_protein_tokens, *num_ss_rna_tokens, *num_ss_dna_tokens, *num_ligand_tokens, num_metal_ions])]
616+
token_repeats = tensor([*num_protein_tokens, *num_ss_rna_tokens, *num_ss_dna_tokens, *num_ligand_tokens, num_metal_ions])
614617

615-
asym_ids = tensor(flatten(unflattened_asym_ids))
618+
asym_ids = torch.repeat_interleave(
619+
torch.arange(len(token_repeats)),
620+
token_repeats
621+
)
616622

617623
# entity ids
618624

619-
entity_ids = []
620-
curr_id = 0
621-
622-
def add_entity_id(length):
623-
nonlocal curr_id
624-
entity_ids.extend([curr_id] * length)
625-
curr_id += 1
626-
627-
add_entity_id(sum(num_protein_tokens))
628-
629-
for ss_rna in i.ss_rna:
630-
add_entity_id(len(ss_rna))
631-
632-
for ds_rna in i.ds_rna:
633-
add_entity_id(len(ds_rna) * 2)
634-
635-
for ss_dna in i.ss_dna:
636-
add_entity_id(len(ss_dna))
637-
638-
for ds_dna in i.ds_dna:
639-
add_entity_id(len(ds_dna) * 2)
640-
641-
for l in mol_ligands:
642-
add_entity_id(l.GetNumAtoms())
643-
644-
add_entity_id(num_metal_ions)
625+
unrepeated_entity_ids = tensor([
626+
0,
627+
*[*range(len(i.ss_rna))],
628+
*[*range(len(i.ds_rna))],
629+
*[*range(len(i.ss_dna))],
630+
*[*range(len(i.ds_dna))],
631+
*([1] * len(mol_ligands)),
632+
1
633+
]).cumsum(dim = -1)
634+
635+
entity_id_counts = [
636+
sum(num_protein_tokens),
637+
*[len(rna) for rna in i.ss_rna],
638+
*[len(rna) * 2 for rna in i.ds_rna],
639+
*[len(dna) for dna in i.ss_dna],
640+
*[len(dna) * 2 for dna in i.ds_dna],
641+
*num_ligand_tokens,
642+
num_metal_ions
643+
]
645644

646-
entity_ids = tensor(entity_ids).long()
645+
entity_ids = torch.repeat_interleave(unrepeated_entity_ids, tensor(entity_id_counts))
647646

648647
# sym_id
649648

650-
sym_ids = []
651-
curr_id = 0
652-
653-
def add_sym_id(length, reset = False):
654-
nonlocal curr_id
655-
656-
if reset:
657-
curr_id = 0
658-
659-
sym_ids.extend([curr_id] * length)
660-
curr_id += 1
661-
662-
for protein_chain_num_tokens in num_protein_tokens:
663-
add_sym_id(protein_chain_num_tokens)
664-
665-
for ss_rna in i.ss_rna:
666-
add_sym_id(len(ss_rna), reset = True)
667-
668-
for ds_rna in i.ds_rna:
669-
add_sym_id(len(ds_rna), reset = True)
670-
add_sym_id(len(ds_rna))
671-
672-
for ss_dna in i.ss_dna:
673-
add_sym_id(len(ss_dna), reset = True)
674-
675-
for ds_dna in i.ds_dna:
676-
add_sym_id(len(ds_dna), reset = True)
677-
add_sym_id(len(ds_dna))
678-
679-
for l in mol_ligands:
680-
add_sym_id(l.GetNumAtoms(), reset = True)
649+
unrepeated_sym_ids = [
650+
*[*range(len(i.proteins))],
651+
*[*range(len(i.ss_rna))],
652+
*[i for rna in i.ds_rna for i in range(2)],
653+
*[*range(len(i.ss_dna))],
654+
*[i for dna in i.ds_dna for i in range(2)],
655+
*([0] * len(mol_ligands)),
656+
0
657+
]
681658

682-
add_sym_id(num_metal_ions, reset = True)
659+
sym_id_counts = [
660+
*num_protein_tokens,
661+
*[len(rna) for rna in i.ss_rna],
662+
*flatten([((len(rna),) * 2) for rna in i.ds_rna]),
663+
*[len(dna) for dna in i.ss_dna],
664+
*flatten([((len(dna),) * 2) for dna in i.ds_dna]),
665+
*num_ligand_tokens,
666+
num_metal_ions
667+
]
683668

684-
sym_ids = tensor(sym_ids).long()
669+
sym_ids = torch.repeat_interleave(tensor(unrepeated_sym_ids), tensor(sym_id_counts))
685670

686671
# concat for all of additional_molecule_feats
687672

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.85"
3+
version = "0.1.86"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_input.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ def test_alphafold3_input():
2727
alphafold3_input = Alphafold3Input(
2828
proteins = ['MLEICLKLVGCKSKKGLSSSSSCYLEEALQRPVASDF', 'MGKCRGLRTARKLRSHRRDQKWHDKQYKKAHLGTALKANPFGGASHAKGIVLEKVGVEAKQPNSAIRKCVRVQLIKNGKKITAFVPNDGCLNFIEENDEVLVAGFGRKGHAVGDIPGVRFKVVKVANVSLLALYKGKKERPRS'],
2929
ds_dna = ['ACGTT'],
30-
ds_rna = ['GCCAU'],
30+
ds_rna = ['GCCAU', 'CCAGU'],
3131
ss_dna = ['GCCTA'],
3232
ss_rna = ['CGCAUA'],
3333
metal_ions = ['Na', 'Na', 'Fe'],

0 commit comments

Comments
 (0)