55from typing import Type , Literal , Callable , List , Any
66
77import torch
8+ from torch import tensor
89import torch .nn .functional as F
910import einx
1011
@@ -169,7 +170,7 @@ def molecule_to_atom_input(
169170 else :
170171 atom_lens .append (num_atoms )
171172
172- atom_lens = torch . tensor (atom_lens )
173+ atom_lens = tensor (atom_lens )
173174 total_atoms = atom_lens .sum ().item ()
174175
175176 # molecule_atom_lens
@@ -194,7 +195,7 @@ def molecule_to_atom_input(
194195
195196 atom_ids .append (atom_index [atom_symbol ])
196197
197- atom_ids = torch . tensor (atom_ids , dtype = torch .long )
198+ atom_ids = tensor (atom_ids , dtype = torch .long )
198199
199200 # handle maybe atompair embeds
200201
@@ -228,8 +229,8 @@ def molecule_to_atom_input(
228229
229230 updates .extend ([bond_id , bond_id ])
230231
231- coordinates = torch . tensor (coordinates ).long ()
232- updates = torch . tensor (updates ).long ()
232+ coordinates = tensor (coordinates ).long ()
233+ updates = tensor (updates ).long ()
233234
234235 mol_atompair_ids = einx .set_at ('[h w], c [2], c -> [h w]' , mol_atompair_ids , coordinates , updates )
235236
@@ -269,7 +270,7 @@ def molecule_to_atom_input(
269270 pos = mol .GetConformer ().GetAtomPosition (i )
270271 all_atom_pos .append ([pos .x , pos .y , pos .z ])
271272
272- all_atom_pos_tensor = torch . tensor (all_atom_pos )
273+ all_atom_pos_tensor = tensor (all_atom_pos )
273274
274275 dist_matrix = torch .cdist (all_atom_pos_tensor , all_atom_pos_tensor )
275276
@@ -281,9 +282,9 @@ def molecule_to_atom_input(
281282 offset += num_atoms
282283
283284 atom_input = AtomInput (
284- atom_inputs = torch . tensor (atom_inputs , dtype = torch .float ),
285+ atom_inputs = tensor (atom_inputs , dtype = torch .float ),
285286 atompair_inputs = atompair_inputs ,
286- molecule_atom_lens = torch . tensor (atom_lens , dtype = torch .long ),
287+ molecule_atom_lens = tensor (atom_lens , dtype = torch .long ),
287288 molecule_ids = mol_input .molecule_ids ,
288289 additional_token_feats = mol_input .additional_token_feats ,
289290 additional_molecule_feats = mol_input .additional_molecule_feats ,
@@ -380,25 +381,27 @@ def maybe_string_to_int(
380381
381382 index = {symbol : i for i , symbol in enumerate (entries .keys ())}
382383
383- return torch . tensor ([index [c ] for c in indices ]).long ()
384+ return tensor ([index [c ] for c in indices ]).long ()
384385
385386@typecheck
386387def alphafold3_input_to_molecule_input (
387388 alphafold3_input : Alphafold3Input
388389) -> MoleculeInput :
389390
390- ss_rnas = list (alphafold3_input .ss_rna )
391- ss_dnas = list (alphafold3_input .ss_dna )
391+ i = alphafold3_input
392+
393+ ss_rnas = list (i .ss_rna )
394+ ss_dnas = list (i .ss_dna )
392395
393396 # any double stranded nucleic acids is added to single stranded lists with its reverse complement
394397 # rc stands for reverse complement
395398
396- for seq in alphafold3_input .ds_rna :
399+ for seq in i .ds_rna :
397400 rc_fn = partial (reverse_complement , nucleic_acid_type = 'rna' ) if isinstance (seq , str ) else reverse_complement_tensor
398401 rc_seq = rc_fn (seq )
399402 ss_rnas .extend ([seq , rc_seq ])
400403
401- for seq in alphafold3_input .ds_dna :
404+ for seq in i .ds_dna :
402405 rc_fn = partial (reverse_complement , nucleic_acid_type = 'dna' ) if isinstance (seq , str ) else reverse_complement_tensor
403406 rc_seq = rc_fn (seq )
404407 ss_dnas .extend ([seq , rc_seq ])
@@ -414,7 +417,7 @@ def alphafold3_input_to_molecule_input(
414417
415418 # convert all proteins to a List[Mol] of each peptide
416419
417- proteins = alphafold3_input .proteins
420+ proteins = i .proteins
418421 mol_proteins = []
419422 protein_entries = []
420423 molecule_atom_indices = []
@@ -497,7 +500,7 @@ def alphafold3_input_to_molecule_input(
497500
498501 arange = torch .arange (num_tokens )[:, None ]
499502
500- molecule_types_lens_cumsum = torch . tensor ([0 , * molecule_type_token_lens ]).cumsum (dim = - 1 )
503+ molecule_types_lens_cumsum = tensor ([0 , * molecule_type_token_lens ]).cumsum (dim = - 1 )
501504 left , right = molecule_types_lens_cumsum [:- 1 ], molecule_types_lens_cumsum [1 :]
502505
503506 is_molecule_types = (arange >= left ) & (arange < right )
@@ -552,8 +555,8 @@ def alphafold3_input_to_molecule_input(
552555
553556 updates .extend ([True , True ])
554557
555- coordinates = torch . tensor (coordinates ).long ()
556- updates = torch . tensor (updates ).bool ()
558+ coordinates = tensor (coordinates ).long ()
559+ updates = tensor (updates ).bool ()
557560
558561 has_bond = einx .set_at ('[h w], c [2], c -> [h w]' , has_bond , coordinates , updates )
559562
@@ -590,7 +593,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
590593
591594 unflattened_atom_parent_ids = [([asym_id ] * num_tokens ) for asym_id , num_tokens in enumerate ([* num_protein_atoms , * num_ss_rna_atoms , * num_ss_dna_atoms , * num_ligand_atoms , num_metal_ions ])]
592595
593- atom_parent_ids = torch . tensor (flatten (unflattened_atom_parent_ids ))
596+ atom_parent_ids = tensor (flatten (unflattened_atom_parent_ids ))
594597
595598 # constructing the additional_molecule_feats
596599 # which is in turn used to derive relative positions
@@ -609,27 +612,57 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
609612
610613 unflattened_asym_ids = [([asym_id ] * num_tokens ) for asym_id , num_tokens in enumerate ([* num_protein_tokens , * num_ss_rna_tokens , * num_ss_dna_tokens , * num_ligand_tokens , num_metal_ions ])]
611614
612- asym_ids = torch .tensor (flatten (unflattened_asym_ids ))
615+ asym_ids = tensor (flatten (unflattened_asym_ids ))
616+
617+ # entity ids
618+
619+ entity_ids = []
620+ curr_id = 0
621+
622+ def add_entity_id (length ):
623+ nonlocal curr_id
624+ entity_ids .extend ([curr_id ] * length )
625+ curr_id += 1
626+
627+ add_entity_id (sum (num_protein_tokens ))
628+
629+ for ss_rna in i .ss_rna :
630+ add_entity_id (len (ss_rna ))
631+
632+ for ds_rna in i .ds_rna :
633+ add_entity_id (len (ds_rna ) * 2 )
634+
635+ for ss_dna in i .ss_dna :
636+ add_entity_id (len (ss_dna ))
637+ curr_id += 1
638+
639+ for ds_dna in i .ds_dna :
640+ add_entity_id (len (ds_dna ) * 2 )
641+
642+ for l in mol_ligands :
643+ add_entity_id (l .GetNumAtoms ())
644+
645+ add_entity_id (num_metal_ions )
646+
647+ entity_ids = tensor (entity_ids ).long ()
613648
614649 # concat for all of additional_molecule_feats
615650
616651 additional_molecule_feats = torch .stack ((
617652 molecule_ids ,
618653 torch .arange (num_tokens ),
619654 asym_ids ,
620- torch . zeros ( num_tokens ). long () ,
655+ entity_ids ,
621656 torch .zeros (num_tokens ).long (),
622657 ), dim = - 1 )
623658
624659 # molecule atom indices
625660
626- molecule_atom_indices = torch . tensor (molecule_atom_indices )
661+ molecule_atom_indices = tensor (molecule_atom_indices )
627662 molecule_atom_indices = pad_to_len (molecule_atom_indices , num_tokens , value = - 1 )
628663
629664 # create molecule input
630665
631- i = alphafold3_input
632-
633666 molecule_input = MoleculeInput (
634667 molecules = molecules ,
635668 molecule_token_pool_lens = token_pool_lens ,
0 commit comments