Skip to content

Commit 94cf2ca

Browse files
authored
Merge pull request #50 from DeepRank/issue49
Issue49
2 parents 83c05e1 + 1cbfa42 commit 94cf2ca

File tree

2 files changed

+96
-43
lines changed

2 files changed

+96
-43
lines changed

pdb2sql/StructureSimilarity.py

Lines changed: 86 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ def compute_lrmsd_fast(self, lzone=None, method='svd', check=True):
104104
if lzone is None:
105105
resData = self.compute_lzone(save_file=False)
106106
elif not os.path.isfile(lzone):
107-
resData = self.compute_lzone(save_file=True, filename=lzone)
107+
resData = self.compute_lzone(
108+
save_file=True, filename=lzone)
108109
else:
109110
resData = self.read_zone(lzone)
110111

@@ -140,7 +141,7 @@ def compute_lrmsd_fast(self, lzone=None, method='svd', check=True):
140141
self.ref, resData, return_not_in_zone=True)
141142

142143
xyz_decoy_short = superpose_selection(
143-
xyz_decoy_short, xyz_decoy_long, xyz_ref_long, method)
144+
xyz_decoy_short, xyz_decoy_long, xyz_ref_long, method)
144145

145146
# compute the RMSD
146147
return self.get_rmsd(xyz_decoy_short, xyz_ref_short)
@@ -162,19 +163,25 @@ def compute_lzone(self, save_file=True, filename=None):
162163
dict: definition of the zone.
163164
"""
164165
sql_ref = pdb2sql(self.ref)
165-
nA = len(sql_ref.get('x,y,z', chainID='A'))
166-
nB = len(sql_ref.get('x,y,z', chainID='B'))
166+
chains = list(sql_ref.get_chains())
167+
if len(chains) != 2:
168+
raise ValueError(
169+
'exactly two chains are needed for lrmsd calculation but we found %d' % len(chains), chains)
170+
171+
nA = len(sql_ref.get('x,y,z', chainID=chains[0]))
172+
nB = len(sql_ref.get('x,y,z', chainID=chains[1]))
167173

168174
# detect which chain is the longest
169-
long_chain = 'A'
175+
long_chain = chains[0]
170176
if nA < nB:
171-
long_chain = 'B'
177+
long_chain = chains[1]
172178

173179
# extract data about the residue
174180
data_test = [
175181
tuple(data) for data in sql_ref.get(
176182
'chainID,resSeq',
177183
chainID=long_chain)]
184+
178185
data_test = sorted(set(data_test))
179186

180187
# close the sql
@@ -199,6 +206,7 @@ def compute_lzone(self, save_file=True, filename=None):
199206
if chain not in resData.keys():
200207
resData[chain] = []
201208
resData[chain].append(num)
209+
202210
return resData
203211

204212
##########################################################################
@@ -246,8 +254,6 @@ def compute_irmsd_fast(
246254
# read the izone file
247255
if izone is None:
248256
resData = self.compute_izone(cutoff, save_file=False)
249-
# elif not os.path.isfile(izone):
250-
# resData = self.compute_izone(cutoff,save_file=True,filename=izone)
251257
else:
252258
resData = self.read_zone(izone)
253259

@@ -264,12 +270,14 @@ def compute_irmsd_fast(
264270

265271
# extract the xyz
266272
else:
267-
xyz_contact_decoy = self.get_xyz_zone_backbone(self.decoy, resData)
268-
xyz_contact_ref = self.get_xyz_zone_backbone(self.ref, resData)
273+
xyz_contact_decoy = self.get_xyz_zone_backbone(
274+
self.decoy, resData)
275+
xyz_contact_ref = self.get_xyz_zone_backbone(
276+
self.ref, resData)
269277

270278
# superpose the fragments
271279
xyz_contact_decoy = superpose_selection(xyz_contact_decoy,
272-
xyz_contact_decoy, xyz_contact_ref, method)
280+
xyz_contact_decoy, xyz_contact_ref, method)
273281

274282
# return the RMSD
275283
return self.get_rmsd(xyz_contact_decoy, xyz_contact_ref)
@@ -285,9 +293,15 @@ def compute_izone(self, cutoff=10.0, save_file=True, filename=None):
285293
Returns:
286294
dict: i-zone definition
287295
"""
296+
288297
sql_ref = interface(self.ref)
298+
chains = list(sql_ref.get_chains())
299+
if len(chains) != 2:
300+
raise ValueError(
301+
'exactly two chains are needed for irmsd calculation but we found %d' % len(chains), chains)
302+
289303
contact_ref = sql_ref.get_contact_atoms(
290-
cutoff=cutoff, extend_to_residue=True)
304+
cutoff=cutoff, extend_to_residue=True, chain1=chains[0], chain2=chains[1])
291305

292306
index_contact_ref = []
293307
for _, v in contact_ref.items():
@@ -488,30 +502,37 @@ def compute_lrmsd_pdb2sql(self, exportpath=None, method='svd'):
488502
sql_decoy = pdb2sql(self.decoy, sqlfile='decoy.db')
489503
sql_ref = pdb2sql(self.ref, sqlfile='ref.db')
490504

505+
# get the chains
506+
chains_decoy = sql_decoy.get_chains()
507+
chains_ref = sql_ref.get_chains()
508+
509+
if chains_decoy != chains_ref:
510+
raise ValueError(
511+
'Chains are different in decoy and reference structure')
512+
513+
chain1 = chains_decoy[0]
514+
chain2 = chains_decoy[1]
515+
491516
# extract the positions of the first chain (chain1)
492517
xyz_decoy_A = np.array(
493-
sql_decoy.get(
494-
'x,y,z',
495-
chainID='A',
496-
name=backbone))
497-
xyz_ref_A = np.array(sql_ref.get('x,y,z', chainID='A', name=backbone))
518+
sql_decoy.get('x,y,z', chainID=chain1, name=backbone))
519+
xyz_ref_A = np.array(sql_ref.get(
520+
'x,y,z', chainID=chain1, name=backbone))
498521

499522
# extract the positions of the second chain (chain2)
500523
xyz_decoy_B = np.array(
501-
sql_decoy.get(
502-
'x,y,z',
503-
chainID='B',
504-
name=backbone))
505-
xyz_ref_B = np.array(sql_ref.get('x,y,z', chainID='B', name=backbone))
524+
sql_decoy.get('x,y,z', chainID=chain2, name=backbone))
525+
xyz_ref_B = np.array(sql_ref.get(
526+
'x,y,z', chainID=chain2, name=backbone))
506527

507528
# check the lengths
508529
if len(xyz_decoy_A) != len(xyz_ref_A):
509530
xyz_decoy_A, xyz_ref_A = self.get_identical_atoms(
510-
sql_decoy, sql_ref, 'A')
531+
sql_decoy, sql_ref, chain1)
511532

512533
if len(xyz_decoy_B) != len(xyz_ref_B):
513534
xyz_decoy_B, xyz_ref_B = self.get_identical_atoms(
514-
sql_decoy, sql_ref, 'B')
535+
sql_decoy, sql_ref, chain2)
515536

516537
# detect which chain is the longest
517538
nA, nB = len(xyz_decoy_A), len(xyz_decoy_B)
@@ -565,7 +586,8 @@ def compute_lrmsd_pdb2sql(self, exportpath=None, method='svd'):
565586
xyz_decoy += tr_decoy
566587

567588
# rotate decoy
568-
xyz_decoy = transform.rotate(xyz_decoy, U, center=self.origin)
589+
xyz_decoy = transform.rotate(
590+
xyz_decoy, U, center=self.origin)
569591

570592
# update the sql database
571593
sql_decoy.update_column('x', xyz_decoy[:, 0])
@@ -602,8 +624,10 @@ def get_identical_atoms(db1, db2, chain):
602624
"""
603625
backbone = ['CA', 'C', 'N', 'O']
604626
# get data
605-
data1 = db1.get('chainID,resSeq,name', chainID=chain, name=backbone)
606-
data2 = db2.get('chainID,resSeq,name', chainID=chain, name=backbone)
627+
data1 = db1.get('chainID,resSeq,name',
628+
chainID=chain, name=backbone)
629+
data2 = db2.get('chainID,resSeq,name',
630+
chainID=chain, name=backbone)
607631

608632
# tuplify
609633
data1 = [tuple(d1) for d1 in data1]
@@ -668,11 +692,23 @@ def compute_irmsd_pdb2sql(
668692
sql_decoy = interface(self.decoy)
669693
sql_ref = interface(self.ref)
670694

695+
# get the chains
696+
chains_decoy = sql_decoy.get_chains()
697+
chains_ref = sql_ref.get_chains()
698+
699+
if chains_decoy != chains_ref:
700+
raise ValueError(
701+
'Chains are different in decoy and reference structure')
702+
671703
# get the contact atoms
672704
if izone is None:
705+
673706
contact_ref = sql_ref.get_contact_atoms(
674707
cutoff=cutoff,
675-
extend_to_residue=True)
708+
extend_to_residue=True,
709+
chain1=chains_ref[0],
710+
chain2=chains_ref[1])
711+
676712
index_contact_ref = []
677713
for v in contact_ref.values():
678714
index_contact_ref += v
@@ -683,7 +719,8 @@ def compute_irmsd_pdb2sql(
683719
sql_ref, izone, return_only_backbone_atoms=True)
684720

685721
# get the xyz and atom identifier of the reference contact atoms
686-
xyz_contact_ref = sql_ref.get('x,y,z', rowID=index_contact_ref)
722+
xyz_contact_ref = sql_ref.get(
723+
'x,y,z', rowID=index_contact_ref)
687724
data_contact_ref = sql_ref.get(
688725
'chainID,resSeq,resName,name',
689726
rowID=index_contact_ref)
@@ -720,7 +757,8 @@ def compute_irmsd_pdb2sql(
720757
# check that we still have atoms in both chains
721758
chain_decoy = list(
722759
set(sql_decoy.get('chainID', rowID=index_contact_decoy)))
723-
chain_ref = list(set(sql_ref.get('chainID', rowID=index_contact_ref)))
760+
chain_ref = list(
761+
set(sql_ref.get('chainID', rowID=index_contact_ref)))
724762

725763
if len(chain_decoy) < 1 or len(chain_ref) < 1:
726764
raise ValueError(
@@ -737,9 +775,9 @@ def compute_irmsd_pdb2sql(
737775
# get the ideal rotation matrix
738776
# to superimpose the A chains
739777
rot_mat = get_rotation_matrix(
740-
xyz_contact_decoy,
741-
xyz_contact_ref,
742-
method=method)
778+
xyz_contact_decoy,
779+
xyz_contact_ref,
780+
method=method)
743781

744782
# rotate the entire fragment
745783
xyz_contact_decoy = transform.rotate(
@@ -752,8 +790,10 @@ def compute_irmsd_pdb2sql(
752790
if exportpath is not None:
753791

754792
# update the sql database
755-
sql_decoy.update_xyz(xyz_contact_decoy, rowID=index_contact_decoy)
756-
sql_ref.update_xyz(xyz_contact_ref, rowID=index_contact_ref)
793+
sql_decoy.update_xyz(
794+
xyz_contact_decoy, rowID=index_contact_decoy)
795+
sql_ref.update_xyz(
796+
xyz_contact_ref, rowID=index_contact_ref)
757797

758798
sql_decoy.exportpdb(
759799
exportpath + '/irmsd_decoy.pdb',
@@ -845,8 +885,8 @@ def compute_fnat_pdb2sql(self, cutoff=5.0):
845885
"""
846886

847887
# create the sql
848-
sql_decoy = interface(self.decoy)
849-
sql_ref = interface(self.ref)
888+
sql_decoy = interface(self.decoy, fix_chainID=True)
889+
sql_ref = interface(self.ref, fix_chainID=True)
850890

851891
# get the contact atoms
852892
residue_pairs_decoy = sql_decoy.get_contact_residues(
@@ -865,7 +905,8 @@ def compute_fnat_pdb2sql(self, cutoff=5.0):
865905
data_pair_ref += [(resA, resB) for resB in resB_list]
866906

867907
# find the number of residue pairs that ref and decoy have in common
868-
nCommon = len(set(data_pair_ref).intersection(data_pair_decoy))
908+
nCommon = len(
909+
set(data_pair_ref).intersection(data_pair_decoy))
869910

870911
# normalize
871912
fnat = nCommon / len(data_pair_ref)
@@ -962,17 +1003,20 @@ def get_data_zone_backbone(pdb_file, resData, return_not_in_zone=False):
9621003
name = line[12:16].strip()
9631004

9641005
backbone = ['C', 'CA', 'N', 'O']
1006+
9651007
if chainID in resData.keys():
9661008

9671009
if resSeq in resData[chainID] and name in backbone:
9681010
data_in_zone.append((chainID, resSeq, name))
9691011

9701012
elif resSeq not in resData[chainID] and name in backbone:
971-
data_not_in_zone.append((chainID, resSeq, name))
1013+
data_not_in_zone.append(
1014+
(chainID, resSeq, name))
9721015

9731016
else:
9741017
if name in backbone:
975-
data_not_in_zone.append((chainID, resSeq, name))
1018+
data_not_in_zone.append(
1019+
(chainID, resSeq, name))
9761020

9771021
if return_not_in_zone:
9781022
return set(data_in_zone), set(data_not_in_zone)
@@ -1119,7 +1163,8 @@ def compute_DockQScore(fnat, lrmsd, irmsd, d1=8.5, d2=1.5):
11191163
def scale_rms(rms, d):
11201164
return(1. / (1 + (rms / d)**2))
11211165

1122-
dockq = 1. / 3 * (fnat + scale_rms(lrmsd, d1) + scale_rms(irmsd, d2))
1166+
dockq = 1. / 3 * \
1167+
(fnat + scale_rms(lrmsd, d1) + scale_rms(irmsd, d2))
11231168
return round(dockq, 6)
11241169

11251170
##########################################################################

pdb2sql/interface.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import warnings
44
from .pdb2sqlcore import pdb2sql
55

6+
67
class interface(pdb2sql):
78

89
def __init__(self, pdb, **kwargs):
@@ -73,6 +74,12 @@ def get_contact_atoms(
7374
else:
7475
chainIDs = [chain1, chain2]
7576

77+
chains = self.get_chains()
78+
for c in chainIDs:
79+
if c not in chains:
80+
raise ValueError(
81+
'chain %s not found in the structure' % c)
82+
7683
xyz = dict()
7784
index = dict()
7885
resName = dict()
@@ -265,7 +272,8 @@ def get_contact_residues(
265272
residue_contact_pairs[data1] = set()
266273

267274
# get the res info of the atom in the other chain
268-
data2 = self.get('chainID,resSeq,resName', rowID=atoms2)
275+
data2 = self.get(
276+
'chainID,resSeq,resName', rowID=atoms2)
269277

270278
# store that in the dict without double
271279
for resData in data2:
@@ -300,4 +308,4 @@ def get_contact_residues(
300308
residue_contact[chain] = sorted(
301309
set([tuple(resData) for resData in data[chain]]))
302310

303-
return residue_contact
311+
return residue_contact

0 commit comments

Comments
 (0)