Skip to content

Commit 69ceb94

Browse files
committed
handle the atom parent ids
1 parent aabe7ff commit 69ceb94

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ class MoleculeInput:
133133
additional_molecule_feats: Int[f'n {ADDITIONAL_MOLECULE_FEATS}']
134134
is_molecule_types: Bool[f'n {IS_MOLECULE_TYPES}']
135135
token_bonds: Bool['n n']
136+
atom_parent_ids: Int[' m'] | None = None
136137
additional_token_feats: Float[f'n dtf'] | None = None
137138
templates: Float['t n n dt'] | None = None
138139
msa: Float['s n dm'] | None = None
@@ -287,6 +288,7 @@ def molecule_to_atom_input(
287288
additional_molecule_feats = mol_input.additional_molecule_feats,
288289
is_molecule_types = mol_input.is_molecule_types,
289290
token_bonds = mol_input.token_bonds,
291+
atom_parent_ids = mol_input.atom_parent_ids,
290292
atom_ids = atom_ids,
291293
atompair_ids = atompair_ids
292294
)
@@ -306,6 +308,7 @@ class Alphafold3Input:
306308
ligands: List[Mol | str] # can be given as smiles
307309
ds_dna: List[Int[' _'] | str]
308310
ds_rna: List[Int[' _'] | str]
311+
atom_parent_ids: Int['m'] | None = None
309312
additional_token_feats: Float[f'n dtf'] | None = None
310313
templates: Float['t n n dt'] | None = None
311314
msa: Float['s n dm'] | None = None
@@ -536,6 +539,31 @@ def alphafold3_input_to_molecule_input(
536539
molecule_ids = torch.cat(molecule_ids)
537540
molecule_ids = pad_to_len(molecule_ids, num_tokens)
538541

542+
# handle atom_parent_ids
543+
# this governs in the atom encoder / decoder, which atom attends to which
544+
# a design choice is taken so metal ions attend to each other, in case there are more than one
545+
546+
@typecheck
547+
def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
548+
atoms_per_chain = []
549+
550+
for chain in chains:
551+
num_atoms = 0
552+
for mol in chain:
553+
num_atoms += mol.GetNumAtoms()
554+
atoms_per_chain.append(num_atoms)
555+
556+
return atoms_per_chain
557+
558+
num_protein_atoms = get_num_atoms_per_chain(mol_proteins)
559+
num_ss_rna_atoms = get_num_atoms_per_chain(mol_ss_rnas)
560+
num_ss_dna_atoms = get_num_atoms_per_chain(mol_ss_dnas)
561+
num_ligand_atoms = [l.GetNumAtoms() for l in mol_ligands]
562+
563+
unflattened_atom_parent_ids = [([asym_id] * num_tokens) for asym_id, num_tokens in enumerate([*num_protein_atoms, *num_ss_rna_atoms, *num_ss_dna_atoms, *num_ligand_atoms, num_metal_ions])]
564+
565+
atom_parent_ids = torch.tensor(flatten(unflattened_atom_parent_ids))
566+
539567
# constructing the additional_molecule_feats
540568
# which is in turn used to derive relative positions
541569
# (todo) offer a way to precompute relative positions at data prep
@@ -546,14 +574,14 @@ def alphafold3_input_to_molecule_input(
546574
# entity_id - unique id for each biomolecule - multimeric protein, ds dna
547575
# sym_id - unique id for each chain within each biomolecule
548576

549-
num_protein_tokens = [len(protein) for protein in mol_proteins]
577+
num_protein_tokens = [len(protein) for protein in proteins]
550578
num_ss_rna_tokens = [len(rna) for rna in ss_rnas]
551579
num_ss_dna_tokens = [len(dna) for dna in ss_dnas]
580+
num_ligand_tokens = [l.GetNumAtoms() for l in mol_ligands]
552581

553-
unflattened_asym_ids = [([asym_id] * num_tokens) for asym_id, num_tokens in enumerate([*num_protein_tokens, *num_ss_rna_tokens, *num_ss_dna_tokens])]
582+
unflattened_asym_ids = [([asym_id] * num_tokens) for asym_id, num_tokens in enumerate([*num_protein_tokens, *num_ss_rna_tokens, *num_ss_dna_tokens, *num_ligand_tokens, num_metal_ions])]
554583

555584
asym_ids = torch.tensor(flatten(unflattened_asym_ids))
556-
asym_ids = pad_to_len(asym_ids, num_tokens)
557585

558586
additional_molecule_feats = torch.stack((
559587
molecule_ids,
@@ -581,6 +609,7 @@ def alphafold3_input_to_molecule_input(
581609
msa = i.msa,
582610
template_mask = i.template_mask,
583611
msa_mask = i.msa_mask,
612+
atom_parent_ids = atom_parent_ids,
584613
add_atom_ids = alphafold3_input.add_atom_ids,
585614
add_atompair_ids = alphafold3_input.add_atompair_ids,
586615

0 commit comments

Comments
 (0)