Skip to content

Commit 6b1f4f7

Browse files
committed
extended ensemble feature
1 parent cf41799 commit 6b1f4f7

File tree

3 files changed

+77
-31
lines changed

3 files changed

+77
-31
lines changed

sampledock/SnD/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .docking import dock, sort_pose, save_pose
22
from .pocket_prepare import prep_prm
33
from .sampler_util import hyperparam_loader, create_wd, smiles_to_sdfile
4+
from .generator import single_generator, distributed_generator
45
from .post_process import mkdf, combine_designs
56
from .tmap_plotter import LSH_Convert, tree_coords, df_to_faerun

sampledock/SnD/generator.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import sys
2-
import os
3-
4-
from jtvae import Vocab, JTNNVAE
2+
from sampledock import Vocab, JTNNVAE
53
import torch
64

75
def jtvae_loader(params):
@@ -11,6 +9,61 @@ def jtvae_loader(params):
119
jtvae.load_state_dict(torch.load(params.model_loc, map_location=device))
1210
return jtvae
1311

12+
def single_generator(ranked_poses, ndesign, jtvae):
    '''
    Generate next cycle of designs based on the top scoring pose.

    The poses are tried in the order given; the first one that yields at
    least one design is used and the remaining poses are not visited.

    arg:
    - ranked_poses: iterable of tuples (energy:float, design:str, best_pose:rdkit.Mol)
    - ndesign:int Number of designs to be generated
    - jtvae: model object handed through to _generator

    return design_list:list(SMILES:str), empty if no pose produced any design
    '''
    # Bound up front so an empty ranked_poses returns [] instead of raising
    design_list = []
    for energy, name, mol in ranked_poses:
        smi = mol.GetProp('SMILES')
        # Treat a None or empty result as "no offspring" so that a seed the
        # model cannot handle does not crash the loop
        design_list = _generator(smi, ndesign, jtvae) or []
        if design_list:
            break
        # go to the next candidate if the current one does not give any return
        print('Current design (%s) has no offspring; trying the next one \r'%name)

    return design_list # return a list of SMILES
31+
32+
def _generator(smi, ndesign, jtvae):
    '''
    Request new designs from the model for a single seed SMILES.

    arg:
    - smi:str SMILES of the seed molecule
    - ndesign:int Number of designs requested
    - jtvae: object exposing smiles_gen(smi, ndesign)

    return design_list: whatever jtvae.smiles_gen returns (expected to be a
    list of SMILES), or an empty list when smiles_gen raises KeyError
    '''
    print('[INFO]: Generating new designs \t', end = '\r')
    sys.stdout.flush()
    try:
        # get new design list for the next cycle
        return jtvae.smiles_gen(smi, ndesign)

    # This is due to difference in parsing of SMILES (especially rings)
    ## TODO: Convert sampledock to OOP structure and use the vectors directly
    except KeyError as key:
        print('[KeyError]',key,'is not part of the vocabulary (the model was not trained with this scaffold)')
        # Always hand back a list so callers can test length / extend safely
        return []
44+
45+
46+
47+
def distributed_generator(ranked_poses, nseeds, sub_ndesign, jtvae):
    '''
    Generate next cycle of designs based on the top n scoring poses.

    arg:
    - ranked_poses: sliceable sequence of tuples (energy:float, design:str, best_pose:rdkit.Mol)
    - nseeds:int Number of the top designs being used
    - sub_ndesign:int Number of designs to be generated for each seed
    - jtvae: model object handed through to _generator

    return design_list:list(SMILES:str), the per-seed results concatenated in
    seed order; seeds that yield nothing are skipped
    '''
    design_list = []
    # Slice rather than index so that having fewer poses than nseeds uses
    # what is available instead of raising IndexError
    n_top = max(int(nseeds), 0)
    for energy, name, mol in ranked_poses[:n_top]:
        smi = mol.GetProp('SMILES')
        cur_design_list = _generator(smi, sub_ndesign, jtvae)
        # A None or empty result means this seed produced no offspring
        if cur_design_list:
            design_list.extend(cur_design_list)

    return design_list
66+
1467
if __name__ == "__main__":
1568
from sampler import hyperparam_loader
1669
import argparse
@@ -29,5 +82,4 @@ def jtvae_loader(params):
2982
with open(a.output,'w') as f:
3083
for smi in design_list:
3184
print(smi)
32-
f.write(smi+'\n')
33-
f.close()
85+
f.write(smi+'\n')

sampledock/__main__.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
## TODO: this will definitely need to be restructured as the additional features have cluttered the code.
12
import argparse
23

34
parser = argparse.ArgumentParser(prog='sampledock')
@@ -21,6 +22,7 @@
2122
from .SnD import prep_prm
2223
from .SnD import dock, sort_pose, save_pose
2324
from .SnD import hyperparam_loader, create_wd, smiles_to_sdfile
25+
from .SnD import single_generator, distributed_generator
2426
from .SnD import combine_designs, mkdf
2527
from .SnD import LSH_Convert, tree_coords, df_to_faerun
2628
# Load hyper parameters
@@ -44,10 +46,14 @@
4446
design_list = jtvae.smiles_gen(p.seed_smi, p.ndesign)
4547
except KeyError as err:
4648
print('[KeyError]',err,
47-
'does not exist in the current JTVAE model vocabulary "%s" (the training set of the model did not contain this structure),'%p.vocab_loc,
49+
'does not exist in the current JTVAE model vocabulary "%s" \
50+
(the training set of the model did not contain this structure),'%p.vocab_loc,
4851
'thus "%s" failed to initialize the model as seeding molecule!'%p.seed_smi)
4952
exit()
5053

54+
# Check if design generation is distributed among the top n designs (nseeds)
55+
sub_ndesign = int(p.ndesign)//int(p.nseeds) if p.nseeds > 1 else False
56+
5157
# Create working directory
5258
wd = create_wd(a.output,p.receptor_name)
5359

@@ -75,35 +81,22 @@
7581
ranked_poses = sort_pose(docking_dir, p.sort_by, p.prefix)
7682
save_pose(ranked_poses, design_dir)
7783

84+
## Report the top design
85+
top_energy, top_name, top_mol = ranked_poses[0]
86+
top_smi = top_mol.GetProp('SMILES')
87+
print("[INFO]: Cycle %s: %s %s kcal/mol"%(j, top_smi, top_energy)+'\t'*6)
88+
7889
## Generate new design list
79-
if p.ensemble > 1:
80-
top_smi_list = [Chem.MolToSmiles(mol) for _, _, mol in ranked_poses[:p.ensemble]]
90+
if sub_ndesign:
91+
design_list = distributed_generator(ranked_poses, p.nseeds, sub_ndesign, jtvae)
92+
elif p.ensemble > 1:
93+
top_smi_list = [mol.GetProp('SMILES') for _, _, mol in ranked_poses[:p.ensemble]]
8194
smi = jtvae.find_ensemble(top_smi_list)
8295
design_list = jtvae.smiles_gen(smi, p.ndesign)
83-
best_score = ranked_poses[0][0]
84-
print("[INFO]: Cycle %s: %s Best Score: %s kcal/mol"%(j, smi, best_score)+'\t'*6)
8596
else:
86-
for energy, name, mol in ranked_poses:
87-
smi = mol.GetProp('SMILES')
88-
try:
89-
print('[INFO]: Generating new designs \t', end = '\r')
90-
sys.stdout.flush()
91-
# get new design list for the nex cycle
92-
design_list = jtvae.smiles_gen(smi, p.ndesign)
93-
# This is due to difference in parsing of SMILES (especially rings)
94-
## TODO: Convert sampledock to OOP structure and use the vectors directly
95-
except KeyError as key:
96-
print('[KeyError]',key,'is not part of the vocabulary (the model was not trained with this scaffold)')
97-
continue
98-
# if there are offspring designs, break the loop
99-
if len(design_list) != 0:
100-
break
101-
# go to the next candidate if the current one does not give any return
102-
else:
103-
print('Current design (%s) has no offspring; trying the next one \r'%name)
104-
105-
print("[INFO]: Cycle %s: %s %s kcal/mol"%(j, smi, energy)+'\t'*6)
106-
97+
design_list = single_generator(ranked_poses, p.ndesign, jtvae)
98+
99+
107100
print("\n", p.ncycle, "cycles of design finished. Starting post-processing.")
108101
# Create post-process working directory
109102
postproc_wd = os.path.join(wd, "All_Designs_Processed")

0 commit comments

Comments
 (0)