@@ -591,9 +591,12 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
591591 num_ss_dna_atoms = get_num_atoms_per_chain (mol_ss_dnas )
592592 num_ligand_atoms = [l .GetNumAtoms () for l in mol_ligands ]
593593
594- unflattened_atom_parent_ids = [([ asym_id ] * num_tokens ) for asym_id , num_tokens in enumerate ([ * num_protein_atoms , * num_ss_rna_atoms , * num_ss_dna_atoms , * num_ligand_atoms , num_metal_ions ]) ]
594+ atom_counts = [* num_protein_atoms , * num_ss_rna_atoms , * num_ss_dna_atoms , * num_ligand_atoms , num_metal_ions ]
595595
596- atom_parent_ids = tensor (flatten (unflattened_atom_parent_ids ))
596+ atom_parent_ids = torch .repeat_interleave (
597+ torch .arange (len (atom_counts )),
598+ tensor (atom_counts )
599+ )
597600
598601 # constructing the additional_molecule_feats
599602 # which is in turn used to derive relative positions
@@ -610,78 +613,60 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
610613 num_ss_dna_tokens = [len (dna ) for dna in ss_dnas ]
611614 num_ligand_tokens = [l .GetNumAtoms () for l in mol_ligands ]
612615
613- unflattened_asym_ids = [([ asym_id ] * num_tokens ) for asym_id , num_tokens in enumerate ([ * num_protein_tokens , * num_ss_rna_tokens , * num_ss_dna_tokens , * num_ligand_tokens , num_metal_ions ])]
616+ token_repeats = tensor ([ * num_protein_tokens , * num_ss_rna_tokens , * num_ss_dna_tokens , * num_ligand_tokens , num_metal_ions ])
614617
615- asym_ids = tensor (flatten (unflattened_asym_ids ))
618+ asym_ids = torch .repeat_interleave (
619+ torch .arange (len (token_repeats )),
620+ token_repeats
621+ )
616622
617623 # entity ids
618624
619- entity_ids = []
620- curr_id = 0
621-
622- def add_entity_id (length ):
623- nonlocal curr_id
624- entity_ids .extend ([curr_id ] * length )
625- curr_id += 1
626-
627- add_entity_id (sum (num_protein_tokens ))
628-
629- for ss_rna in i .ss_rna :
630- add_entity_id (len (ss_rna ))
631-
632- for ds_rna in i .ds_rna :
633- add_entity_id (len (ds_rna ) * 2 )
634-
635- for ss_dna in i .ss_dna :
636- add_entity_id (len (ss_dna ))
637-
638- for ds_dna in i .ds_dna :
639- add_entity_id (len (ds_dna ) * 2 )
640-
641- for l in mol_ligands :
642- add_entity_id (l .GetNumAtoms ())
643-
644- add_entity_id (num_metal_ions )
625+ unrepeated_entity_ids = tensor ([
626+ 0 ,
627+ * [* range (len (i .ss_rna ))],
628+ * [* range (len (i .ds_rna ))],
629+ * [* range (len (i .ss_dna ))],
630+ * [* range (len (i .ds_dna ))],
631+ * ([1 ] * len (mol_ligands )),
632+ 1
633+ ]).cumsum (dim = - 1 )
634+
635+ entity_id_counts = [
636+ sum (num_protein_tokens ),
637+ * [len (rna ) for rna in i .ss_rna ],
638+ * [len (rna ) * 2 for rna in i .ds_rna ],
639+ * [len (dna ) for dna in i .ss_dna ],
640+ * [len (dna ) * 2 for dna in i .ds_dna ],
641+ * num_ligand_tokens ,
642+ num_metal_ions
643+ ]
645644
646- entity_ids = tensor (entity_ids ). long ( )
645+ entity_ids = torch . repeat_interleave ( unrepeated_entity_ids , tensor (entity_id_counts ) )
647646
648647 # sym_id
649648
650- sym_ids = []
651- curr_id = 0
652-
653- def add_sym_id (length , reset = False ):
654- nonlocal curr_id
655-
656- if reset :
657- curr_id = 0
658-
659- sym_ids .extend ([curr_id ] * length )
660- curr_id += 1
661-
662- for protein_chain_num_tokens in num_protein_tokens :
663- add_sym_id (protein_chain_num_tokens )
664-
665- for ss_rna in i .ss_rna :
666- add_sym_id (len (ss_rna ), reset = True )
667-
668- for ds_rna in i .ds_rna :
669- add_sym_id (len (ds_rna ), reset = True )
670- add_sym_id (len (ds_rna ))
671-
672- for ss_dna in i .ss_dna :
673- add_sym_id (len (ss_dna ), reset = True )
674-
675- for ds_dna in i .ds_dna :
676- add_sym_id (len (ds_dna ), reset = True )
677- add_sym_id (len (ds_dna ))
678-
679- for l in mol_ligands :
680- add_sym_id (l .GetNumAtoms (), reset = True )
649+ unrepeated_sym_ids = [
650+ * [* range (len (i .proteins ))],
651+ * [* range (len (i .ss_rna ))],
652+ * [i for rna in i .ds_rna for i in range (2 )],
653+ * [* range (len (i .ss_dna ))],
654+ * [i for dna in i .ds_dna for i in range (2 )],
655+ * ([0 ] * len (mol_ligands )),
656+ 0
657+ ]
681658
682- add_sym_id (num_metal_ions , reset = True )
659+ sym_id_counts = [
660+ * num_protein_tokens ,
661+ * [len (rna ) for rna in i .ss_rna ],
662+ * flatten ([((len (rna ),) * 2 ) for rna in i .ds_rna ]),
663+ * [len (dna ) for dna in i .ss_dna ],
664+ * flatten ([((len (dna ),) * 2 ) for dna in i .ds_dna ]),
665+ * num_ligand_tokens ,
666+ num_metal_ions
667+ ]
683668
684- sym_ids = tensor (sym_ids ). long ( )
669+ sym_ids = torch . repeat_interleave ( tensor (unrepeated_sym_ids ), tensor ( sym_id_counts ) )
685670
686671 # concat for all of additional_molecule_feats
687672
0 commit comments