@@ -104,7 +104,8 @@ def compute_lrmsd_fast(self, lzone=None, method='svd', check=True):
104104 if lzone is None :
105105 resData = self .compute_lzone (save_file = False )
106106 elif not os .path .isfile (lzone ):
107- resData = self .compute_lzone (save_file = True , filename = lzone )
107+ resData = self .compute_lzone (
108+ save_file = True , filename = lzone )
108109 else :
109110 resData = self .read_zone (lzone )
110111
@@ -140,7 +141,7 @@ def compute_lrmsd_fast(self, lzone=None, method='svd', check=True):
140141 self .ref , resData , return_not_in_zone = True )
141142
142143 xyz_decoy_short = superpose_selection (
143- xyz_decoy_short , xyz_decoy_long , xyz_ref_long , method )
144+ xyz_decoy_short , xyz_decoy_long , xyz_ref_long , method )
144145
145146 # compute the RMSD
146147 return self .get_rmsd (xyz_decoy_short , xyz_ref_short )
@@ -162,19 +163,25 @@ def compute_lzone(self, save_file=True, filename=None):
162163 dict: definition of the zone.
163164 """
164165 sql_ref = pdb2sql (self .ref )
165- nA = len (sql_ref .get ('x,y,z' , chainID = 'A' ))
166- nB = len (sql_ref .get ('x,y,z' , chainID = 'B' ))
166+ chains = list (sql_ref .get_chains ())
167+ if len (chains ) != 2 :
168+ raise ValueError (
169+ 'exactly two chains are needed for lrmsd calculation but we found %d' % len (chains ), chains )
170+
171+ nA = len (sql_ref .get ('x,y,z' , chainID = chains [0 ]))
172+ nB = len (sql_ref .get ('x,y,z' , chainID = chains [1 ]))
167173
168174 # detect which chain is the longest
169- long_chain = 'A'
175+ long_chain = chains [ 0 ]
170176 if nA < nB :
171- long_chain = 'B'
177+ long_chain = chains [ 1 ]
172178
173179 # extract data about the residue
174180 data_test = [
175181 tuple (data ) for data in sql_ref .get (
176182 'chainID,resSeq' ,
177183 chainID = long_chain )]
184+
178185 data_test = sorted (set (data_test ))
179186
180187 # close the sql
@@ -199,6 +206,7 @@ def compute_lzone(self, save_file=True, filename=None):
199206 if chain not in resData .keys ():
200207 resData [chain ] = []
201208 resData [chain ].append (num )
209+
202210 return resData
203211
204212 ##########################################################################
@@ -246,8 +254,6 @@ def compute_irmsd_fast(
246254 # read the izone file
247255 if izone is None :
248256 resData = self .compute_izone (cutoff , save_file = False )
249- # elif not os.path.isfile(izone):
250- # resData = self.compute_izone(cutoff,save_file=True,filename=izone)
251257 else :
252258 resData = self .read_zone (izone )
253259
@@ -264,12 +270,14 @@ def compute_irmsd_fast(
264270
265271 # extract the xyz
266272 else :
267- xyz_contact_decoy = self .get_xyz_zone_backbone (self .decoy , resData )
268- xyz_contact_ref = self .get_xyz_zone_backbone (self .ref , resData )
273+ xyz_contact_decoy = self .get_xyz_zone_backbone (
274+ self .decoy , resData )
275+ xyz_contact_ref = self .get_xyz_zone_backbone (
276+ self .ref , resData )
269277
270278 # superpose the fragments
271279 xyz_contact_decoy = superpose_selection (xyz_contact_decoy ,
272- xyz_contact_decoy , xyz_contact_ref , method )
280+ xyz_contact_decoy , xyz_contact_ref , method )
273281
274282 # return the RMSD
275283 return self .get_rmsd (xyz_contact_decoy , xyz_contact_ref )
@@ -285,9 +293,15 @@ def compute_izone(self, cutoff=10.0, save_file=True, filename=None):
285293 Returns:
286294 dict: i-zone definition
287295 """
296+
288297 sql_ref = interface (self .ref )
298+ chains = list (sql_ref .get_chains ())
299+ if len (chains ) != 2 :
300+ raise ValueError (
301+ 'exactly two chains are needed for irmsd calculation but we found %d' % len (chains ), chains )
302+
289303 contact_ref = sql_ref .get_contact_atoms (
290- cutoff = cutoff , extend_to_residue = True )
304+ cutoff = cutoff , extend_to_residue = True , chain1 = chains [ 0 ], chain2 = chains [ 1 ] )
291305
292306 index_contact_ref = []
293307 for _ , v in contact_ref .items ():
@@ -488,30 +502,37 @@ def compute_lrmsd_pdb2sql(self, exportpath=None, method='svd'):
488502 sql_decoy = pdb2sql (self .decoy , sqlfile = 'decoy.db' )
489503 sql_ref = pdb2sql (self .ref , sqlfile = 'ref.db' )
490504
505+ # get the chains
506+ chains_decoy = sql_decoy .get_chains ()
507+ chains_ref = sql_ref .get_chains ()
508+
509+ if chains_decoy != chains_ref :
510+ raise ValueError (
511+ 'Chains are different in decoy and reference structure' )
512+
513+ chain1 = chains_decoy [0 ]
514+ chain2 = chains_decoy [1 ]
515+
491516 # extract the pos of chains A
492517 xyz_decoy_A = np .array (
493- sql_decoy .get (
494- 'x,y,z' ,
495- chainID = 'A' ,
496- name = backbone ))
497- xyz_ref_A = np .array (sql_ref .get ('x,y,z' , chainID = 'A' , name = backbone ))
518+ sql_decoy .get ('x,y,z' , chainID = chain1 , name = backbone ))
519+ xyz_ref_A = np .array (sql_ref .get (
520+ 'x,y,z' , chainID = chain1 , name = backbone ))
498521
499522 # extract the pos of chains B
500523 xyz_decoy_B = np .array (
501- sql_decoy .get (
502- 'x,y,z' ,
503- chainID = 'B' ,
504- name = backbone ))
505- xyz_ref_B = np .array (sql_ref .get ('x,y,z' , chainID = 'B' , name = backbone ))
524+ sql_decoy .get ('x,y,z' , chainID = chain2 , name = backbone ))
525+ xyz_ref_B = np .array (sql_ref .get (
526+ 'x,y,z' , chainID = chain2 , name = backbone ))
506527
507528 # check the lengthes
508529 if len (xyz_decoy_A ) != len (xyz_ref_A ):
509530 xyz_decoy_A , xyz_ref_A = self .get_identical_atoms (
510- sql_decoy , sql_ref , 'A' )
531+ sql_decoy , sql_ref , chain1 )
511532
512533 if len (xyz_decoy_B ) != len (xyz_ref_B ):
513534 xyz_decoy_B , xyz_ref_B = self .get_identical_atoms (
514- sql_decoy , sql_ref , 'B' )
535+ sql_decoy , sql_ref , chain2 )
515536
516537 # detect which chain is the longest
517538 nA , nB = len (xyz_decoy_A ), len (xyz_decoy_B )
@@ -565,7 +586,8 @@ def compute_lrmsd_pdb2sql(self, exportpath=None, method='svd'):
565586 xyz_decoy += tr_decoy
566587
567588 # rotate decoy
568- xyz_decoy = transform .rotate (xyz_decoy , U , center = self .origin )
589+ xyz_decoy = transform .rotate (
590+ xyz_decoy , U , center = self .origin )
569591
570592 # update the sql database
571593 sql_decoy .update_column ('x' , xyz_decoy [:, 0 ])
@@ -602,8 +624,10 @@ def get_identical_atoms(db1, db2, chain):
602624 """
603625 backbone = ['CA' , 'C' , 'N' , 'O' ]
604626 # get data
605- data1 = db1 .get ('chainID,resSeq,name' , chainID = chain , name = backbone )
606- data2 = db2 .get ('chainID,resSeq,name' , chainID = chain , name = backbone )
627+ data1 = db1 .get ('chainID,resSeq,name' ,
628+ chainID = chain , name = backbone )
629+ data2 = db2 .get ('chainID,resSeq,name' ,
630+ chainID = chain , name = backbone )
607631
608632 # tuplify
609633 data1 = [tuple (d1 ) for d1 in data1 ]
@@ -668,11 +692,23 @@ def compute_irmsd_pdb2sql(
668692 sql_decoy = interface (self .decoy )
669693 sql_ref = interface (self .ref )
670694
695+ # get the chains
696+ chains_decoy = sql_decoy .get_chains ()
697+ chains_ref = sql_ref .get_chains ()
698+
699+ if chains_decoy != chains_ref :
700+ raise ValueError (
701+ 'Chains are different in decoy and reference structure' )
702+
671703 # get the contact atoms
672704 if izone is None :
705+
673706 contact_ref = sql_ref .get_contact_atoms (
674707 cutoff = cutoff ,
675- extend_to_residue = True )
708+ extend_to_residue = True ,
709+ chain1 = chains_ref [0 ],
710+ chain2 = chains_ref [1 ])
711+
676712 index_contact_ref = []
677713 for v in contact_ref .values ():
678714 index_contact_ref += v
@@ -683,7 +719,8 @@ def compute_irmsd_pdb2sql(
683719 sql_ref , izone , return_only_backbone_atoms = True )
684720
685721 # get the xyz and atom identifier of the decoy contact atoms
686- xyz_contact_ref = sql_ref .get ('x,y,z' , rowID = index_contact_ref )
722+ xyz_contact_ref = sql_ref .get (
723+ 'x,y,z' , rowID = index_contact_ref )
687724 data_contact_ref = sql_ref .get (
688725 'chainID,resSeq,resName,name' ,
689726 rowID = index_contact_ref )
@@ -720,7 +757,8 @@ def compute_irmsd_pdb2sql(
720757 # check that we still have atoms in both chains
721758 chain_decoy = list (
722759 set (sql_decoy .get ('chainID' , rowID = index_contact_decoy )))
723- chain_ref = list (set (sql_ref .get ('chainID' , rowID = index_contact_ref )))
760+ chain_ref = list (
761+ set (sql_ref .get ('chainID' , rowID = index_contact_ref )))
724762
725763 if len (chain_decoy ) < 1 or len (chain_ref ) < 1 :
726764 raise ValueError (
@@ -737,9 +775,9 @@ def compute_irmsd_pdb2sql(
737775 # get the ideql rotation matrix
738776 # to superimpose the A chains
739777 rot_mat = get_rotation_matrix (
740- xyz_contact_decoy ,
741- xyz_contact_ref ,
742- method = method )
778+ xyz_contact_decoy ,
779+ xyz_contact_ref ,
780+ method = method )
743781
744782 # rotate the entire fragment
745783 xyz_contact_decoy = transform .rotate (
@@ -752,8 +790,10 @@ def compute_irmsd_pdb2sql(
752790 if exportpath is not None :
753791
754792 # update the sql database
755- sql_decoy .update_xyz (xyz_contact_decoy , rowID = index_contact_decoy )
756- sql_ref .update_xyz (xyz_contact_ref , rowID = index_contact_ref )
793+ sql_decoy .update_xyz (
794+ xyz_contact_decoy , rowID = index_contact_decoy )
795+ sql_ref .update_xyz (
796+ xyz_contact_ref , rowID = index_contact_ref )
757797
758798 sql_decoy .exportpdb (
759799 exportpath + '/irmsd_decoy.pdb' ,
@@ -845,8 +885,8 @@ def compute_fnat_pdb2sql(self, cutoff=5.0):
845885 """
846886
847887 # create the sql
848- sql_decoy = interface (self .decoy )
849- sql_ref = interface (self .ref )
888+ sql_decoy = interface (self .decoy , fix_chainID = True )
889+ sql_ref = interface (self .ref , fix_chainID = True )
850890
851891 # get the contact atoms
852892 residue_pairs_decoy = sql_decoy .get_contact_residues (
@@ -865,7 +905,8 @@ def compute_fnat_pdb2sql(self, cutoff=5.0):
865905 data_pair_ref += [(resA , resB ) for resB in resB_list ]
866906
867907 # find the umber of residue that ref and decoys hace in common
868- nCommon = len (set (data_pair_ref ).intersection (data_pair_decoy ))
908+ nCommon = len (
909+ set (data_pair_ref ).intersection (data_pair_decoy ))
869910
870911 # normalize
871912 fnat = nCommon / len (data_pair_ref )
@@ -962,17 +1003,20 @@ def get_data_zone_backbone(pdb_file, resData, return_not_in_zone=False):
9621003 name = line [12 :16 ].strip ()
9631004
9641005 backbone = ['C' , 'CA' , 'N' , 'O' ]
1006+
9651007 if chainID in resData .keys ():
9661008
9671009 if resSeq in resData [chainID ] and name in backbone :
9681010 data_in_zone .append ((chainID , resSeq , name ))
9691011
9701012 elif resSeq not in resData [chainID ] and name in backbone :
971- data_not_in_zone .append ((chainID , resSeq , name ))
1013+ data_not_in_zone .append (
1014+ (chainID , resSeq , name ))
9721015
9731016 else :
9741017 if name in backbone :
975- data_not_in_zone .append ((chainID , resSeq , name ))
1018+ data_not_in_zone .append (
1019+ (chainID , resSeq , name ))
9761020
9771021 if return_not_in_zone :
9781022 return set (data_in_zone ), set (data_not_in_zone )
@@ -1119,7 +1163,8 @@ def compute_DockQScore(fnat, lrmsd, irmsd, d1=8.5, d2=1.5):
11191163 def scale_rms (rms , d ):
11201164 return (1. / (1 + (rms / d )** 2 ))
11211165
1122- dockq = 1. / 3 * (fnat + scale_rms (lrmsd , d1 ) + scale_rms (irmsd , d2 ))
1166+ dockq = 1. / 3 * \
1167+ (fnat + scale_rms (lrmsd , d1 ) + scale_rms (irmsd , d2 ))
11231168 return round (dockq , 6 )
11241169
11251170 ##########################################################################
0 commit comments