@@ -63,12 +63,15 @@ def crop_msa(feat, max_msa_depth=16384):
6363 return msa .astype ('int32' ), msa_mask , delection_mat
6464
6565
66- def get_padding_restype (ccd_id , ccd_preprocessed_dict , extra_feats = None ):
66+ def get_padding_restype (ccd_id , ccd_preprocessed_dict , extra_feats = None , is_poly_point = False ):
6767 if ccd_id in ccd_preprocessed_dict :
6868 refs = ccd_preprocessed_dict [ccd_id ] # O(1)
6969 if ccd_id in residue_constants .STANDARD_LIST :
7070 _residue_is_standard = True
71- pdb_atom_ids_list = POLYMER_STANDARD_RESI_ATOMS [ccd_id ] # NOTE: now is only support standard residue.
71+ if not is_poly_point :
72+ pdb_atom_ids_list = POLYMER_STANDARD_RESI_ATOMS [ccd_id ] # NOTE: now is only support standard residue.
73+ else :
74+ pdb_atom_ids_list = refs ['atom_ids' ]
7275 else :
7376 # for ligand/ion. ccd_id.
7477 _residue_is_standard = False
@@ -190,8 +193,14 @@ def get_inference_restype_mask(all_chain_features, ccd_preprocessed_dict, extra_
190193 frame_indice_offset = 0
191194 for type_chain_id , ccd_list in all_chain_features .items ():
192195 dtype , chain_id = type_chain_id .rsplit ('_' , 1 )
193- for ccd_id in ccd_list :
194- pad_feats = get_padding_restype (ccd_id , ccd_preprocessed_dict , extra_feats = extra_feats )
196+ for idx , ccd_id in enumerate (ccd_list ):
197+ is_poly_point = False
198+ if idx == len (ccd_list ) - 1 and dtype == 'protein' :
199+ is_poly_point = True
200+ elif idx == 0 and dtype in ['rna' , 'dna' ]:
201+ is_poly_point = True
202+ pad_feats = get_padding_restype (ccd_id , ccd_preprocessed_dict , extra_feats = extra_feats ,
203+ is_poly_point = is_poly_point )
195204 pad_feats ['ai_indice' ] = pad_feats ['ai_indice' ] + frame_indice_offset
196205 pad_feats ['bi_indice' ] = pad_feats ['bi_indice' ] + frame_indice_offset
197206 pad_feats ['ci_indice' ] = pad_feats ['ci_indice' ] + frame_indice_offset
0 commit comments