@@ -634,7 +634,6 @@ def add_entity_id(length):
634634
635635 for ss_dna in i .ss_dna :
636636 add_entity_id (len (ss_dna ))
637- curr_id += 1
638637
639638 for ds_dna in i .ds_dna :
640639 add_entity_id (len (ds_dna ) * 2 )
@@ -646,14 +645,52 @@ def add_entity_id(length):
646645
647646 entity_ids = tensor (entity_ids ).long ()
648647
648+ # sym_id
649+
650+ sym_ids = []
651+ curr_id = 0
652+
653+ def add_sym_id (length , reset = False ):
654+ nonlocal curr_id
655+
656+ if reset :
657+ curr_id = 0
658+
659+ sym_ids .extend ([curr_id ] * length )
660+ curr_id += 1
661+
662+ for protein_chain_num_tokens in num_protein_tokens :
663+ add_sym_id (protein_chain_num_tokens )
664+
665+ for ss_rna in i .ss_rna :
666+ add_sym_id (len (ss_rna ), reset = True )
667+
668+ for ds_rna in i .ds_rna :
669+ add_sym_id (len (ds_rna ), reset = True )
670+ add_sym_id (len (ds_rna ))
671+
672+ for ss_dna in i .ss_dna :
673+ add_sym_id (len (ss_dna ), reset = True )
674+
675+ for ds_dna in i .ds_dna :
676+ add_sym_id (len (ds_dna ), reset = True )
677+ add_sym_id (len (ds_dna ))
678+
679+ for l in mol_ligands :
680+ add_sym_id (l .GetNumAtoms (), reset = True )
681+
682+ add_sym_id (num_metal_ions , reset = True )
683+
684+ sym_ids = tensor (sym_ids ).long ()
685+
649686 # concat for all of additional_molecule_feats
650687
651688 additional_molecule_feats = torch .stack ((
652689 molecule_ids ,
653690 torch .arange (num_tokens ),
654691 asym_ids ,
655692 entity_ids ,
656- torch . zeros ( num_tokens ). long (),
693+ sym_ids
657694 ), dim = - 1 )
658695
659696 # molecule atom indices
0 commit comments