Commit f274d1b

implemented ensemble and minor bug fix

1 parent d73b101

File tree

5 files changed (+67, -28 lines changed)


.gitmodules

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+[submodule "faerun-python"]
+    path = faerun-python
+    url = git@github.com:Truman-Xu/faerun-python.git

faerun-python

Submodule faerun-python added at 3f95a69
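Because faerun-python is now tracked as a git submodule pinned at 3f95a69, an existing checkout will not contain it automatically: running git submodule update --init inside the repository (or cloning fresh with git clone --recurse-submodules) pulls the pinned revision.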

sampledock/SnD/tmap_plotter.py

Lines changed: 5 additions & 5 deletions

@@ -62,8 +62,8 @@ def df_to_faerun(df,x,y,s,t):
     f = Faerun(view="front", coords=False)
     f.add_scatter(
         # No space in the string allowed for the name, use underscore!!
-        # Cannot start with a number, it has to be a letter!! Weird Bug!!
-        # My guess is that the string is to be converted to a variable name,
+        # Cannot start with a number, it has to be a letter!!
+        # the string is to be converted to a variable name,
         # therefore it has to be compatible with python variable naming scheme
         "SampleDock",
         {

@@ -94,8 +94,6 @@ def df_to_faerun(df,x,y,s,t):
     )
     # The first character of the name has to be a letter!
     f.add_tree("SnD_Tree", {"from": s, "to": t}, point_helper="SampleDock")
-    f.plot("SampleDock"+'_space', # name of the .html file
-           template="smiles")
     print('Plotting finished')
     return f

@@ -126,4 +124,6 @@ def df_to_faerun(df,x,y,s,t):
     df['s'] = s
     df['t'] = t
     df.to_csv(os.path.join(outpath,"props.csv"),index = False)
-    df_to_faerun(df)
+    f = df_to_faerun(df)
+    f.plot(os.path.join(outpath,"SampleDock"+'_space'), # name of the .html file
+           template="smiles")
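With this change, df_to_faerun only builds and returns the Faerun object; the plot call moves to the caller, and the block at the bottom of the file now captures the return value and writes the HTML under outpath. A minimal usage sketch of the refactored interface; the import path is inferred from the file location, and outpath plus the df, x, y, s, t inputs are placeholders for whatever the tmap layout step already produced:

import os
from sampledock.SnD.tmap_plotter import df_to_faerun  # import path assumed from the file location

outpath = "post_proc"                              # hypothetical output directory
f = df_to_faerun(df, x, y, s, t)                   # build the Faerun scatter + tree; nothing is written yet
f.plot(os.path.join(outpath, "SampleDock_space"),  # caller now controls the .html name and location
       template="smiles")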

sampledock/__main__.py

Lines changed: 31 additions & 22 deletions

@@ -13,7 +13,7 @@
 import sys
 import subprocess
 import pickle
-from rdkit import rdBase
+from rdkit import rdBase, Chem
 ## Disable rdkit Logs
 rdBase.DisableLog('rdApp.error')
 from .jtvae import Vocab, JTNNVAE

@@ -74,27 +74,35 @@
         dock(ligs, docking_dir, prmfile, p.docking_prm, p.npose, p.prefix)
         ranked_poses = sort_pose(docking_dir, p.sort_by, p.prefix)
         save_pose(ranked_poses, design_dir)
-
+
         ## Generate new design list
-        for energy, name, mol in ranked_poses:
-            smi = mol.GetProp('SMILES')
-            design_list = []
-            try:
-                print('[INFO]: Generating new designs \t', end = '\r')
-                sys.stdout.flush()
-                design_list = jtvae.smiles_gen(smi, p.ndesign)
-            # go to the second best candidate if the best does not give any return
-            except KeyError as key:
-                print('[KeyError]',key,'is not part of the vocabulary')
-                continue
-
-            if len(design_list) != 0:
-                break
-
-            else:
-                print('Current design (%s) has no offspring; trying the next one \r'%name)
-
-        print("[INFO]: Cycle %s: %s %s kcal/mol"%(j, smi, energy)+'\t'*6)
+        if p.ensemble > 1:
+            top_smi_list = [Chem.MolToSmiles(mol) for _, _, mol in ranked_poses[:p.ensemble]]
+            smi = jtvae.find_ensemble(top_smi_list)
+            design_list = jtvae.smiles_gen(smi, p.ndesign)
+            best_score = ranked_poses[0][0]
+            print("[INFO]: Cycle %s: %s Best Score: %s kcal/mol"%(j, smi, best_score)+'\t'*6)
+        else:
+            for energy, name, mol in ranked_poses:
+                smi = mol.GetProp('SMILES')
+                try:
+                    print('[INFO]: Generating new designs \t', end = '\r')
+                    sys.stdout.flush()
+                    # get new design list for the nex cycle
+                    design_list = jtvae.smiles_gen(smi, p.ndesign)
+                # This is due to difference in parsing of SMILES (especially rings)
+                ## TODO: Convert sampledock to OOP structure and use the vectors directly
+                except KeyError as key:
+                    print('[KeyError]',key,'is not part of the vocabulary (the model was not trained with this scaffold)')
+                    continue
+                # if there are offspring designs, break the loop
+                if len(design_list) != 0:
+                    break
+                # go to the next candidate if the current one does not give any return
+                else:
+                    print('Current design (%s) has no offspring; trying the next one \r'%name)
+
+            print("[INFO]: Cycle %s: %s %s kcal/mol"%(j, smi, energy)+'\t'*6)
 
     print("\n", p.ncycle, "cycles of design finished. Starting post-processing.")
     # Create post-process working directory

@@ -117,6 +125,7 @@
         pickle.dump((x,y,s,t),f)
     # Create tmap on faerun
     f = df_to_faerun(allscores,x,y,s,t)
-
+    f.plot("SampleDock"+'_space', path = postproc_wd, # name and path of the .html file
+           template="smiles")
     with open(os.path.join(postproc_wd,'SampleDock.faerun'), 'wb') as handle:
         pickle.dump(f.create_python_data(), handle, protocol=pickle.HIGHEST_PROTOCOL)
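In short, the seeding step of each cycle now has two paths: when p.ensemble > 1 (presumably a new integer setting in the SampleDock parameter file), the top-N docked poses are collapsed into one consensus SMILES via find_ensemble; otherwise the previous best-pose-first fallback loop is kept. A condensed sketch of that selection logic, with jtvae, ranked_poses and the parameter values treated as already available; the helper name is hypothetical and only mirrors the control flow shown above:

from rdkit import Chem

def pick_seed_designs(jtvae, ranked_poses, ensemble, ndesign):
    """Sketch of the new seeding logic, not the shipped function."""
    if ensemble > 1:
        # consensus seed: canonical SMILES of the top-N poses, averaged in latent space
        top_smi_list = [Chem.MolToSmiles(mol) for _, _, mol in ranked_poses[:ensemble]]
        seed = jtvae.find_ensemble(top_smi_list)
        return seed, jtvae.smiles_gen(seed, ndesign)
    # fallback: walk down the ranking until a pose encodes and yields offspring
    for energy, name, mol in ranked_poses:
        smi = mol.GetProp('SMILES')
        try:
            designs = jtvae.smiles_gen(smi, ndesign)
        except KeyError:
            continue                      # scaffold not in the JT-VAE vocabulary
        if designs:
            return smi, designs
    return None, []                       # no ranked pose produced offspring designs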

sampledock/jtvae/jtnn_vae.py

Lines changed: 27 additions & 1 deletion

@@ -71,7 +71,7 @@ def smiles_gen(self,smiles,ndesign,prob_decode = False):
         ## Convert smiles to one-hot encoding (altered function from original code)
         x_tree, x_mol = self.encode_single_smiles(smiles)
 
-        ## Encode one-hots to z-mean and log var. Following Mueller et al.
+        ## Encode one-hots to z-mean and log variance. Following Mueller et al.
         tree_mean = self.T_mean(x_tree)
         tree_log_var = -torch.abs(self.T_var(x_tree))
         mol_mean = self.G_mean(x_mol)

@@ -89,6 +89,32 @@ def smiles_gen(self,smiles,ndesign,prob_decode = False):
             smiles_list.append(smilesout)
         return smiles_list
 
+    def find_ensemble(self,smiles_list):
+        z_tree = []
+        z_mol = []
+        for smi in smiles_list:
+            try:
+                x_tree, x_mol = self.encode_single_smiles(smi)
+            # This is due to difference in parsing of SMILES (especially rings)
+            ## TODO: Convert sampledock to OOP structure and use the vectors directly
+            except KeyError as key:
+                print('[KeyError]',key,'is not part of the vocabulary (the model was not trained with this scaffold)')
+                continue
+            tree_mean = self.T_mean(x_tree)
+            tree_log_var = -torch.abs(self.T_var(x_tree))
+            mol_mean = self.G_mean(x_mol)
+            mol_log_var = -torch.abs(self.G_var(x_mol))
+
+            z_tree.append(self.z_vec(tree_mean, tree_log_var))
+            z_mol.append(self.z_vec(mol_mean, mol_log_var))
+
+        z_tree = torch.cat(z_tree)
+        z_mol = torch.cat(z_mol)
+
+        return self.decode(z_tree.mean(0).reshape((1,self.latent_size)),
+                           z_mol.mean(0).reshape((1,self.latent_size)),
+                           False)
+
     def encode_latent(self, jtenc_holder, mpn_holder):
         tree_vecs, _ = self.jtnn(*jtenc_holder)
         mol_vecs = self.mpn(*mpn_holder)
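find_ensemble encodes every SMILES it can parse, draws a latent vector per molecule with z_vec, averages the tree and graph vectors separately, and decodes the mean (the trailing False presumably being the prob_decode flag) into a single consensus SMILES. A minimal usage sketch; jtvae is assumed to be the same trained JTNNVAE instance that __main__.py builds, and the SMILES strings are placeholders:

# `jtvae` is assumed to be a trained JTNNVAE loaded elsewhere (as in __main__.py)
top_smiles = ["CCOc1ccc(O)cc1", "CCNc1ccc(O)cc1", "CCOc1ccc(N)cc1"]  # placeholder top-ranked poses
seed = jtvae.find_ensemble(top_smiles)   # consensus SMILES decoded from the averaged latent vectors
designs = jtvae.smiles_gen(seed, 10)     # offspring designs for the next sample-and-dock cycle

Because the averaging happens in latent space rather than on the strings themselves, the consensus seed can be a molecule that none of the input SMILES contains verbatim.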
