Skip to content

Commit dbb5608

Browse files
authored
Merge pull request #40 from Waztom/XCOS2
Added MCS filter to XCOS and removed Scores 2-3
2 parents 93ebb1b + ceabf2d commit dbb5608

File tree

1 file changed

+39
-203
lines changed

1 file changed

+39
-203
lines changed

src/python/pipelines/xchem/xcos.py

Lines changed: 39 additions & 203 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525
from rdkit.Chem.FeatMaps import FeatMaps
2626
from rdkit.Chem import AllChem, rdShapeHelpers
2727
from rdkit import RDConfig
28+
from rdkit.Chem import rdFMCS
2829

2930
import os, argparse
3031

3132
import numpy as np
3233
import pandas as pd
33-
from sklearn.neighbors import NearestNeighbors
3434

3535
from datetime import datetime
3636

@@ -40,8 +40,6 @@
4040
field_XCosRefMols = "XCos_RefMols"
4141
field_XCosNumHits = "XCos_NumHits"
4242
field_XCosScore1 = "XCos_Score1"
43-
field_XCosScore2 = "XCos_Score2"
44-
field_XCosScore3 = "XCos_Score3"
4543

4644

4745
date = datetime.today().strftime('%Y-%m-%d')
@@ -101,168 +99,15 @@ def getFeatureMapScore(small_m, large_m, score_mode=FeatMaps.FeatMapScoreMode.Al
10199
except ZeroDivisionError:
102100
return 0
103101

104-
def getNumberfeats(mol):
105-
106-
featLists = []
107-
rawFeats = fdef.GetFeaturesForMol(mol)
108-
# filter that list down to only include the ones we're interested in
109-
featLists.append([f for f in rawFeats if f.GetFamily() in keep])
110-
111-
return len(featLists)
112-
113-
114-
def getFeatureMapXCOS(mol_list):
115-
allFeats = []
116-
for m in mol_list:
117-
118-
rawFeats = fdef.GetFeaturesForMol(m)
119-
featDeats = [(f.GetType(),
120-
f.GetPos().x,
121-
f.GetPos().y,
122-
f.GetPos().z) for f in rawFeats if f.GetFamily() in keep]
123-
124-
allFeats.append(featDeats)
125-
126-
127-
feature_map_df = pd.DataFrame([t for lst in allFeats for t in lst],
128-
columns =['featType', 'x', 'y', 'z'])
129-
130-
return feature_map_df
131-
132-
133-
def getFeatureAgg(feature_map_df, rad_thresh):
134-
135-
# Group data into unique feature types
136-
grouped_df = feature_map_df.groupby('featType')
137-
138-
data_to_add = []
139-
140-
for group_name, df_group in grouped_df:
141-
142-
# Reset index df
143-
df_group = df_group.reset_index()
144-
145-
if len(df_group) == 1:
146-
147-
data_to_add.append(df_group)
148-
149-
if len(df_group) > 1:
150-
151-
# Get feature name
152-
feat_name = df_group.featType.unique()[0]
153-
154-
# Use radius neighbours to find features within
155-
# sphere with radius thresh
156-
neigh = NearestNeighbors(radius=rad_thresh)
157-
158-
while len(df_group) > 0:
159-
160-
neigh.fit(df_group[['x','y','z']])
161-
162-
# Get distances and indices of neighbours within radius threshold
163-
rng = neigh.radius_neighbors()
164-
neigh_dist = rng[0][0]
165-
neigh_indices = rng[1][0]
166-
167-
# Append the first index - NB clustering done relative to index 0
168-
neigh_indices = list(np.append(0, neigh_indices))
169-
170-
# Calculate average x,y,z coords for features in similar loc
171-
x_avg = np.mean(df_group.iloc[neigh_indices].x)
172-
y_avg = np.mean(df_group.iloc[neigh_indices].y)
173-
z_avg = np.mean(df_group.iloc[neigh_indices].z)
174-
175-
# Add feature with average x, y and z values
176-
new_row = [(feat_name, x_avg, y_avg, z_avg)]
177-
178-
cluster_df = pd.DataFrame(data=new_row, columns = ['featType', 'x', 'y', 'z'])
179-
180-
data_to_add.append(cluster_df)
181-
182-
# Remove indices of clustered neighbours
183-
df_group = df_group.drop(df_group.index[neigh_indices])
184-
185-
# Create single DF from list of dfs
186-
clustered_df = pd.concat(data_to_add)
187-
188-
return clustered_df
189-
190102

191103
# This is the main XCOS function
192-
def getReverseScores(clustered_df, mols, frags, no_clustered_feats, rad_threshold, COS_threshold, writer):
104+
def getReverseScores(mols, frags, COS_threshold, writer):
193105

194106
for mol in mols:
195107

196108
# Get the bits
197109
compound_bits = getBits(mol)
198110

199-
# We are going to include a feature mapping score, where the
200-
# number of features of the compound matching the clustered feats
201-
# within a threshold are found
202-
203-
# Get feature map of compound bits as df
204-
feature_map_bits = getFeatureMapXCOS(compound_bits)
205-
206-
# Group data into unique feature types
207-
grouped_df = feature_map_bits.groupby('featType')
208-
209-
no_feats_matched = []
210-
dist_feats_matched = []
211-
212-
# Use radius neighbours to find features within
213-
# sphere with radius thresh
214-
neigh = NearestNeighbors(radius=rad_threshold)
215-
216-
# Loop through grouped features
217-
for group_name, df_group in grouped_df:
218-
219-
# Get feat name
220-
feat_name = df_group.featType.unique()[0]
221-
222-
# Get similar feats from cluster df
223-
cluster_test = clustered_df[clustered_df.featType == feat_name]
224-
225-
# Reset index df
226-
df_group = df_group.reset_index()
227-
228-
if len(cluster_test) == 1:
229-
230-
# Calculate distances
231-
x1_sub_x2 = (cluster_test.iloc[0].x - df_group.iloc[0].x) ** 2
232-
y1_sub_y2 = (cluster_test.iloc[0].y - df_group.iloc[0].y) ** 2
233-
z1_sub_z2 = (cluster_test.iloc[0].z - df_group.iloc[0].z) ** 2
234-
235-
diff_sum = x1_sub_x2 + y1_sub_y2 + z1_sub_z2
236-
237-
dist = diff_sum ** 0.5
238-
239-
if dist < rad_threshold:
240-
# Let's get the number of feats matched
241-
no_feats_matched.append(1)
242-
243-
# Let's get the distance of the feats matched
244-
dist_feats_matched.append([dist])
245-
246-
if len(cluster_test) > 1:
247-
neigh.fit(cluster_test[['x', 'y', 'z']])
248-
249-
while len(df_group) > 0:
250-
# Get distances and indices of neighbours within radius threshold
251-
feat_coords = [[df_group.iloc[0].x, df_group.iloc[0].y, df_group.iloc[0].z]]
252-
rng = neigh.radius_neighbors(feat_coords)
253-
254-
neigh_dist = rng[0][0]
255-
neigh_indices = rng[1][0]
256-
257-
# Let's get the number of feats matched
258-
no_feats_matched.append(len(neigh_indices))
259-
260-
# Remove index 0 of df_group
261-
df_group = df_group.drop(df_group.index[0])
262-
263-
# Get total number of feat matches
264-
no_feats = np.sum(no_feats_matched)
265-
266111
all_scores = []
267112

268113
for bit in compound_bits:
@@ -271,53 +116,62 @@ def getReverseScores(clustered_df, mols, frags, no_clustered_feats, rad_threshol
271116
no_bit_atoms = bit.GetNumAtoms()
272117

273118
scores = []
119+
274120
for frag_mol in frags:
121+
275122
# NB reverse SuCOS scoring
276123
fm_score = getFeatureMapScore(bit, frag_mol)
277124
fm_score = np.clip(fm_score, 0, 1)
278125
# Change van der Waals radius scale for stricter overlay
279126
protrude_dist = rdShapeHelpers.ShapeProtrudeDist(bit, frag_mol, allowReordering=False, vdwScale=0.2)
280127
protrude_dist = np.clip(protrude_dist, 0, 1)
281128

282-
reverse_SuCOS_score = 0.5 * fm_score + 0.5 * (1 - protrude_dist)
129+
# Get frag name for linking to score
130+
frag_name = frag_mol.GetProp('_Name').strip('Mpro-')
131+
132+
# Check if MCS yields > 0 atoms
133+
mcs_match = rdFMCS.FindMCS([bit,frag_mol],ringMatchesRingOnly=True,matchValences=True)
134+
135+
# Get number of atoms in MCS match found
136+
no_mcs_atoms = Chem.MolFromSmarts(mcs_match.smartsString).GetNumAtoms()
283137

284-
# Get number of feats from bit for scaling score
285-
no_bit_feats = getNumberfeats(bit)
138+
if no_mcs_atoms == 0:
286139

287-
# Get some info and append to list
288-
frag_name = frag_mol.GetProp('_Name').strip('Mpro-')
140+
scores.append((frag_name, 0, no_bit_atoms))
141+
142+
if no_mcs_atoms > 0:
143+
144+
# NB reverse SuCOS scoring
145+
fm_score = getFeatureMapScore(bit, frag_mol)
146+
fm_score = np.clip(fm_score, 0, 1)
147+
148+
# Change van der Waals radius scale for stricter overlay
149+
protrude_dist = rdShapeHelpers.ShapeProtrudeDist(bit, frag_mol,
150+
allowReordering=False,
151+
vdwScale=0.2)
152+
protrude_dist = np.clip(protrude_dist, 0, 1)
153+
154+
reverse_SuCOS_score = 0.5 * fm_score + 0.5 * (1 - protrude_dist)
289155

290-
scores.append((frag_name, reverse_SuCOS_score, no_bit_atoms, no_bit_feats))
156+
scores.append((frag_name, reverse_SuCOS_score, no_bit_atoms))
291157

292158
all_scores.append(scores)
293159

294160
list_dfs = []
161+
295162
for score in all_scores:
296-
df = pd.DataFrame(data=score, columns=['Fragment', 'Score', 'No_bit_atoms', 'No_bit_feats'])
163+
164+
df = pd.DataFrame(data=score, columns=['Fragment', 'Score', 'No_bit_atoms'])
165+
297166
# Get maximum scoring fragment for bit match
298167
df = df[df['Score'] == df['Score'].max()]
299168
list_dfs.append(df)
300169

301170
final_df = pd.concat(list_dfs)
302171

303-
# Get total bit score and some denominator terms
304-
bits_score = (final_df.No_bit_atoms * final_df.Score).sum()
305-
total_atoms = final_df.No_bit_atoms.sum()
306-
feat_match_fraction = no_feats / no_clustered_feats
307-
308172
# Score 1: the score is scaled by the number of bit atoms
309-
score_1 = bits_score
310-
311-
# Score 2: the score is scaled by the number of bit atoms
312-
# penalised by the fraction of feats matched
313-
# to the total number of feats clustered
314-
score_2 = score_1 * feat_match_fraction
315-
316-
# Score 3: the score is determined by the fraction of matching
317-
# features to the clustered features within a threshold. This
318-
# should yield similar values to Tim's Featurestein method?
319-
score_3 = feat_match_fraction
320-
173+
score_1 = (final_df.No_bit_atoms * final_df.Score).sum()
174+
321175
# Let's only get frags above a threshold
322176
final_df = final_df[final_df.Score > COS_threshold]
323177

@@ -331,8 +185,6 @@ def getReverseScores(clustered_df, mols, frags, no_clustered_feats, rad_threshol
331185
mol.SetProp(field_XCosRefMols, ','.join(all_frags))
332186
mol.SetIntProp(field_XCosNumHits, len(all_frags))
333187
mol.SetProp(field_XCosScore1, "{:.4f}".format(score_1))
334-
mol.SetProp(field_XCosScore2, "{:.4f}".format(score_2))
335-
mol.SetProp(field_XCosScore3, "{:.4f}".format(score_3))
336188

337189
# Write to file
338190
writer.write(mol)
@@ -352,19 +204,8 @@ def process(molecules, fragments, writer):
352204
else:
353205
utils.log('Using', len(frag_mol_list), 'fragments. No errors')
354206

355-
feature_map_df = getFeatureMapXCOS(frag_mol_list)
356-
utils.log('Feature map dataframe shape:', feature_map_df.shape)
357-
358-
# Set radius threshold
359-
rad_thresh = 1.5
360-
361-
# Aggregate features using nearest neighbours algo
362-
clustered_df = getFeatureAgg(feature_map_df, rad_thresh=rad_thresh)
363-
utils.log('Clustered dataframe shape:', clustered_df.shape)
364-
no_clustered_feats = len(clustered_df)
365-
366-
#clustered_df, mols, rad_threshold, COS_threshold, writer
367-
getReverseScores(clustered_df, molecules, frag_mol_list, no_clustered_feats, 1.0, 0.50, writer)
207+
#mols, frags, COS_threshold, writer
208+
getReverseScores(molecules, frag_mol_list, 0.40, writer)
368209

369210

370211
def main():
@@ -381,7 +222,6 @@ def main():
381222
parser.add_argument('--no-gzip', action='store_true', help='Do not compress the output (STDOUT is never compressed')
382223
parser.add_argument('--metrics', action='store_true', help='Write metrics')
383224

384-
385225
args = parser.parse_args()
386226
utils.log("XCos Args: ", args)
387227

@@ -394,15 +234,11 @@ def main():
394234
clsMappings[field_XCosRefMols] = "java.lang.String"
395235
clsMappings[field_XCosNumHits] = "java.lang.Integer"
396236
clsMappings[field_XCosScore1] = "java.lang.Float"
397-
clsMappings[field_XCosScore2] = "java.lang.Float"
398-
clsMappings[field_XCosScore3] = "java.lang.Float"
237+
399238
fieldMetaProps.append({"fieldName":field_XCosRefMols, "values": {"source":source, "description":"XCos reference fragments"}})
400239
fieldMetaProps.append({"fieldName":field_XCosNumHits, "values": {"source":source, "description":"XCos number of hits"}})
401240
fieldMetaProps.append({"fieldName":field_XCosScore1, "values": {"source":source, "description":"XCos score 1"}})
402-
fieldMetaProps.append({"fieldName":field_XCosScore2, "values": {"source":source, "description":"XCos score 2"}})
403-
fieldMetaProps.append({"fieldName":field_XCosScore3, "values": {"source":source, "description":"XCos score 3"}})
404-
405-
241+
406242
frags_input,frags_suppl = rdkit_utils.default_open_input(args.fragments, args.fragments_format)
407243

408244
inputs_file, inputs_supplr = rdkit_utils.default_open_input(args.input, args.informat)

0 commit comments

Comments
 (0)