@@ -102,8 +102,10 @@ def flatten(arr):
def exclusive_cumsum(t):
    """Cumulative sum along the last dimension that excludes the current element.

    Position `i` of the result holds the sum of all elements strictly before `i`
    on the last dimension, so the first position is always zero.
    """
    # inclusive running total, then remove each element's own contribution
    inclusive = t.cumsum(dim = -1)
    return inclusive - t
104104
105- def pad_to_len (t , length , value = 0 ):
106- return F .pad (t , (0 , max (0 , length - t .shape [- 1 ])), value = value )
def pad_to_len(t, length, value = 0, dim = -1):
    """Right-pad `t` along `dim` with `value` until that dimension has size `length`.

    Args:
        t: tensor to pad.
        length: target size of dimension `dim`. If the dimension is already
            at least this long, `t` is returned unpadded (never truncated).
        value: fill value used for the padded entries.
        dim: dimension to pad, given as a negative index (-1 is the last dimension).

    Returns:
        The padded tensor; all dimensions other than `dim` are unchanged.

    Raises:
        AssertionError: if `dim` is not negative.
    """
    assert dim < 0, 'dim must be a negative index'

    # F.pad consumes (left, right) pairs starting from the last dimension,
    # so every dimension after `dim` gets an explicit (0, 0) no-op pair
    zeros = (0, 0) * (-dim - 1)

    # measure the shortfall on the dimension actually being padded, not always the last one
    pad_amount = max(0, length - t.shape[dim])

    return F.pad(t, (*zeros, 0, pad_amount), value = value)
107109
108110def compose (* fns : Callable ):
109111 # for chaining from Alphafold3Input -> MoleculeInput -> AtomInput
@@ -205,6 +207,7 @@ class MoleculeInput:
205207 molecule_ids : Int [' n' ]
206208 additional_molecule_feats : Int [f'n { ADDITIONAL_MOLECULE_FEATS } ' ]
207209 is_molecule_types : Bool [f'n { IS_MOLECULE_TYPES } ' ]
210+ src_tgt_atom_indices : Int ['n 2' ]
208211 token_bonds : Bool ['n n' ]
209212 molecule_atom_indices : List [int | None ] | None = None
210213 distogram_atom_indices : List [int | None ] | None = None
@@ -308,7 +311,10 @@ def molecule_to_atom_input(
308311
309312 # for every molecule, build the bonds id matrix and add to `atompair_ids`
310313
311- for idx , (mol , is_first_mol_in_chain , is_chainable_biomolecule , offset ) in enumerate (zip (molecules , is_first_mol_in_chains , is_chainable_biomolecules , offsets )):
314+ prev_mol = None
315+ prev_src_tgt_atom_indices = None
316+
317+ for idx , (mol , is_first_mol_in_chain , is_chainable_biomolecule , src_tgt_atom_indices , offset ) in enumerate (zip (molecules , is_first_mol_in_chains , is_chainable_biomolecules , i .src_tgt_atom_indices , offsets )):
312318
313319 coordinates = []
314320 updates = []
@@ -349,15 +355,21 @@ def molecule_to_atom_input(
349355 atompair_ids [row_col_slice , row_col_slice ] = mol_atompair_ids
350356
351357 # if is chainable biomolecule
352- # and not the first biomolecule in the chain, add a single covalent bond between first atom of incoming biomolecule and the last atom of the last biomolecule
358+ # and not the first biomolecule in the chain, add a single covalent bond between first atom of incoming biomolecule and the last atom of the last biomolecule
353359
354360 if is_chainable_biomolecule and not is_first_mol_in_chain :
355361
362+ _ , last_atom_index = prev_src_tgt_atom_indices
363+ first_atom_index , _ = src_tgt_atom_indices
364+
365+ src_atom_offset = offset + first_atom_index
366+ tgt_atom_offset = offset - last_atom_index
356367
357- atompair_ids [offset , offset - 1 ] = 1
358- atompair_ids [offset - 1 , offset ] = 1
368+ atompair_ids [src_atom_offset , tgt_atom_offset ] = 1
369+ atompair_ids [tgt_atom_offset , src_atom_offset ] = 1
359370
360- last_mol = mol
371+ prev_mol = mol
372+ prev_src_tgt_atom_indices = src_tgt_atom_indices
361373
362374 # atom_inputs
363375
@@ -538,6 +550,7 @@ def alphafold3_input_to_molecule_input(
538550
539551 distogram_atom_indices = []
540552 molecule_atom_indices = []
553+ src_tgt_atom_indices = []
541554
542555 for protein in proteins :
543556 mol_peptides , protein_entries = map_int_or_string_indices_to_mol (HUMAN_AMINO_ACIDS , protein , chain = True , return_entries = True )
@@ -546,6 +559,8 @@ def alphafold3_input_to_molecule_input(
546559 distogram_atom_indices .extend ([entry ['token_center_atom_idx' ] for entry in protein_entries ])
547560 molecule_atom_indices .extend ([entry ['distogram_atom_idx' ] for entry in protein_entries ])
548561
562+ src_tgt_atom_indices .extend ([[entry ['first_atom_idx' ], entry ['last_atom_idx' ]] for entry in protein_entries ])
563+
549564 protein_ids = maybe_string_to_int (HUMAN_AMINO_ACIDS , protein ) + protein_offset
550565 molecule_ids .append (protein_ids )
551566
@@ -560,6 +575,11 @@ def alphafold3_input_to_molecule_input(
560575 mol_seq , ss_rna_entries = map_int_or_string_indices_to_mol (RNA_NUCLEOTIDES , seq , chain = True , return_entries = True )
561576 mol_ss_rnas .append (mol_seq )
562577
578+ distogram_atom_indices .extend ([entry ['token_center_atom_idx' ] for entry in ss_rna_entries ])
579+ molecule_atom_indices .extend ([entry ['distogram_atom_idx' ] for entry in ss_rna_entries ])
580+
581+ src_tgt_atom_indices .extend ([[entry ['first_atom_idx' ], entry ['last_atom_idx' ]] for entry in ss_rna_entries ])
582+
563583 rna_ids = maybe_string_to_int (RNA_NUCLEOTIDES , seq ) + rna_offset
564584 molecule_ids .append (rna_ids )
565585
@@ -569,6 +589,11 @@ def alphafold3_input_to_molecule_input(
569589 mol_seq , ss_dna_entries = map_int_or_string_indices_to_mol (DNA_NUCLEOTIDES , seq , chain = True , return_entries = True )
570590 mol_ss_dnas .append (mol_seq )
571591
592+ distogram_atom_indices .extend ([entry ['token_center_atom_idx' ] for entry in ss_dna_entries ])
593+ molecule_atom_indices .extend ([entry ['distogram_atom_idx' ] for entry in ss_dna_entries ])
594+
595+ src_tgt_atom_indices .extend ([[entry ['first_atom_idx' ], entry ['last_atom_idx' ]] for entry in ss_dna_entries ])
596+
572597 dna_ids = maybe_string_to_int (DNA_NUCLEOTIDES , seq ) + dna_offset
573598 molecule_ids .append (dna_ids )
574599
@@ -814,6 +839,9 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
814839 molecule_atom_indices = tensor (molecule_atom_indices )
815840 molecule_atom_indices = pad_to_len (molecule_atom_indices , num_tokens , value = - 1 )
816841
842+ src_tgt_atom_indices = tensor (src_tgt_atom_indices )
843+ src_tgt_atom_indices = pad_to_len (src_tgt_atom_indices , num_tokens , value = - 1 , dim = - 2 )
844+
817845 # todo - handle atom positions for variable lengthed atoms (eventual missing atoms from mmCIF)
818846
819847 atom_pos = i .atom_pos
@@ -830,6 +858,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
830858 additional_molecule_feats = additional_molecule_feats ,
831859 additional_token_feats = default (i .additional_token_feats , torch .zeros (num_tokens , 2 )),
832860 is_molecule_types = is_molecule_types ,
861+ src_tgt_atom_indices = src_tgt_atom_indices ,
833862 atom_pos = atom_pos ,
834863 templates = i .templates ,
835864 msa = i .msa ,
0 commit comments