@@ -133,6 +133,7 @@ class MoleculeInput:
133133 additional_molecule_feats : Int [f'n { ADDITIONAL_MOLECULE_FEATS } ' ]
134134 is_molecule_types : Bool [f'n { IS_MOLECULE_TYPES } ' ]
135135 token_bonds : Bool ['n n' ]
136+ atom_parent_ids : Int [' m' ] | None = None
136137 additional_token_feats : Float [f'n dtf' ] | None = None
137138 templates : Float ['t n n dt' ] | None = None
138139 msa : Float ['s n dm' ] | None = None
@@ -287,6 +288,7 @@ def molecule_to_atom_input(
287288 additional_molecule_feats = mol_input .additional_molecule_feats ,
288289 is_molecule_types = mol_input .is_molecule_types ,
289290 token_bonds = mol_input .token_bonds ,
291+ atom_parent_ids = mol_input .atom_parent_ids ,
290292 atom_ids = atom_ids ,
291293 atompair_ids = atompair_ids
292294 )
@@ -306,6 +308,7 @@ class Alphafold3Input:
306308 ligands : List [Mol | str ] # can be given as smiles
307309 ds_dna : List [Int [' _' ] | str ]
308310 ds_rna : List [Int [' _' ] | str ]
311+ atom_parent_ids : Int ['m' ] | None = None
309312 additional_token_feats : Float [f'n dtf' ] | None = None
310313 templates : Float ['t n n dt' ] | None = None
311314 msa : Float ['s n dm' ] | None = None
@@ -536,6 +539,31 @@ def alphafold3_input_to_molecule_input(
536539 molecule_ids = torch .cat (molecule_ids )
537540 molecule_ids = pad_to_len (molecule_ids , num_tokens )
538541
542+ # handle atom_parent_ids
543+ # this governs in the atom encoder / decoder, which atom attends to which
544+ # a design choice is taken so metal ions attend to each other, in case there are more than one
545+
546+ @typecheck
547+ def get_num_atoms_per_chain (chains : List [List [Mol ]]) -> List [int ]:
548+ atoms_per_chain = []
549+
550+ for chain in chains :
551+ num_atoms = 0
552+ for mol in chain :
553+ num_atoms += mol .GetNumAtoms ()
554+ atoms_per_chain .append (num_atoms )
555+
556+ return atoms_per_chain
557+
558+ num_protein_atoms = get_num_atoms_per_chain (mol_proteins )
559+ num_ss_rna_atoms = get_num_atoms_per_chain (mol_ss_rnas )
560+ num_ss_dna_atoms = get_num_atoms_per_chain (mol_ss_dnas )
561+ num_ligand_atoms = [l .GetNumAtoms () for l in mol_ligands ]
562+
563+ unflattened_atom_parent_ids = [([asym_id ] * num_tokens ) for asym_id , num_tokens in enumerate ([* num_protein_atoms , * num_ss_rna_atoms , * num_ss_dna_atoms , * num_ligand_atoms , num_metal_ions ])]
564+
565+ atom_parent_ids = torch .tensor (flatten (unflattened_atom_parent_ids ))
566+
539567 # constructing the additional_molecule_feats
540568 # which is in turn used to derive relative positions
541569 # (todo) offer a way to precompute relative positions at data prep
@@ -546,14 +574,14 @@ def alphafold3_input_to_molecule_input(
546574 # entity_id - unique id for each biomolecule - multimeric protein, ds dna
547575 # sym_id - unique id for each chain within each biomolecule
548576
549- num_protein_tokens = [len (protein ) for protein in mol_proteins ]
577+ num_protein_tokens = [len (protein ) for protein in proteins ]
550578 num_ss_rna_tokens = [len (rna ) for rna in ss_rnas ]
551579 num_ss_dna_tokens = [len (dna ) for dna in ss_dnas ]
580+ num_ligand_tokens = [l .GetNumAtoms () for l in mol_ligands ]
552581
553- unflattened_asym_ids = [([asym_id ] * num_tokens ) for asym_id , num_tokens in enumerate ([* num_protein_tokens , * num_ss_rna_tokens , * num_ss_dna_tokens ])]
582+ unflattened_asym_ids = [([asym_id ] * num_tokens ) for asym_id , num_tokens in enumerate ([* num_protein_tokens , * num_ss_rna_tokens , * num_ss_dna_tokens , * num_ligand_tokens , num_metal_ions ])]
554583
555584 asym_ids = torch .tensor (flatten (unflattened_asym_ids ))
556- asym_ids = pad_to_len (asym_ids , num_tokens )
557585
558586 additional_molecule_feats = torch .stack ((
559587 molecule_ids ,
@@ -581,6 +609,7 @@ def alphafold3_input_to_molecule_input(
581609 msa = i .msa ,
582610 template_mask = i .template_mask ,
583611 msa_mask = i .msa_mask ,
612+ atom_parent_ids = atom_parent_ids ,
584613 add_atom_ids = alphafold3_input .add_atom_ids ,
585614 add_atompair_ids = alphafold3_input .add_atompair_ids ,
586615
0 commit comments