Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions analysis/gnina.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
from rdkit import Chem

def parse_gnina_log(filename, multiple=False):
if not multiple:
d={'Affinity':[],'RMSD':[],'CNNscore':[],'CNNaffinity':[],'CNNvariance':[]}
with open(filename,'r') as f:
for line in f:
for key in d:
if line[:len(key)]==key:
d[key].append(float(line.strip().split(' (kcal/mol)')[0].split(' ')[-1]))
return d
else:
d={'affinity':[],'intramol':[],'CNNscore':[],'CNNaffinity':[]}
with open(filename,'r') as f:
for line in f:
if line[:5]==' 1':
for i, key in enumerate(d):
d[key].append(float(line[5+i*13:18+i*13]))
return d

class Gnina:
def __init__(self, pdb_file):
self.gnina='/home/domain/data/prog/micromamba/envs/drugflow/bin/gnina'
self.tmp_ligand='/tmp/tmp_gnina.sdf'
self.pdb_file=pdb_file

def calculate_metrics(self,rdmol):

writer = Chem.SDWriter(self.tmp_ligand)
writer.write(mol=rdmol)
writer.close()

cmd=f'{self.gnina} -r {self.pdb_file} -l {self.tmp_ligand} --minimize'
output = os.popen(cmd, 'r')
d={'Affinity':[],'RMSD':[],'CNNscore':[],'CNNaffinity':[],'CNNvariance':[]}
for line in output:
for key in d:
if line.startswith(key):
d[key].append(float(line.strip().split(' ')[1]))
return d

def affinity(self,rdmol):
d=self.calculate_metrics(rdmol)
return -d['Affinity'][0]

def CNNaffinity(self,rdmol):
d=self.calculate_metrics(rdmol)
return d['CNNaffinity'][0]
1 change: 1 addition & 0 deletions inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def inpaint_ligand(model, pdb_file, n_samples, ligand, fix_atoms,
# Build mol objects
x = xh_lig[:, :model.x_dims].detach().cpu()
atom_type = xh_lig[:, model.x_dims:].argmax(1).detach().cpu()
lig_mask=lig_mask.detach().cpu()

molecules = []
for mol_pc in zip(utils.batch_to_list(x, lig_mask),
Expand Down
20 changes: 14 additions & 6 deletions optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def diversify_ligands(model, pocket, mols, timesteps,
# Build mol objects
x = out_lig[:, :model.x_dims].detach().cpu()
atom_type = out_lig[:, model.x_dims:].argmax(1).detach().cpu()
lig_mask=lig_mask.detach().cpu()

molecules = []
for mol_pc in zip(utils.batch_to_list(x, lig_mask),
Expand All @@ -153,7 +154,7 @@ def diversify_ligands(model, pocket, mols, timesteps,
parser.add_argument('--checkpoint', type=Path, default='checkpoints/crossdocked_fullatom_cond.ckpt')
parser.add_argument('--pdbfile', type=str, default='example/5ndu.pdb')
parser.add_argument('--ref_ligand', type=str, default='example/5ndu_linked_mols.sdf')
parser.add_argument('--objective', type=str, default='sa', choices={'qed', 'sa'})
parser.add_argument('--objective', type=str, default='sa', choices={'qed', 'sa','gnina'})
parser.add_argument('--timesteps', type=int, default=100)
parser.add_argument('--population_size', type=int, default=100)
parser.add_argument('--evolution_steps', type=int, default=10)
Expand Down Expand Up @@ -188,31 +189,38 @@ def diversify_ligands(model, pocket, mols, timesteps,
objective_function = MoleculeProperties().calculate_qed
elif args.objective == 'sa':
objective_function = MoleculeProperties().calculate_sa
elif args.objective == 'gnina':
from analysis.gnina import Gnina
objective_function = Gnina(args.pdbfile).affinity
else:
### IMPLEMENT YOUR OWN OBJECTIVE
### FUNCTIONS HERE
raise ValueError(f"Objective function {args.objective} not recognized.")

ref_mol = Chem.SDMolSupplier(args.ref_ligand)[0]
ref_mols = Chem.SDMolSupplier(args.ref_ligand)

# Store molecules in history dataframe
buffer = pd.DataFrame(columns=['generation', 'score', 'fate' 'mol', 'smiles'])

# Population initialization
buffer = buffer.append({'generation': 0,
for ref_mol in ref_mols:
buffer=buffer._append({'generation': 0,
'score': objective_function(ref_mol),
'fate': 'initial', 'mol': ref_mol,
'smiles': Chem.MolToSmiles(ref_mol)}, ignore_index=True)

for generation_idx in range(evolution_steps):

if generation_idx == 0:
molecules = buffer['mol'].tolist() * population_size
top_k_molecules=buffer.sort_values(by='score', ascending=False)['mol'].tolist()[:population_size]
molecules = top_k_molecules * ((population_size-1)//len(top_k_molecules)+1)
molecules=molecules[:population_size]
else:
# Select top k molecules from previous generation
previous_gen = buffer[buffer['generation'] == generation_idx]
top_k_molecules = previous_gen.nlargest(top_k, 'score')['mol'].tolist()
molecules = top_k_molecules * (population_size // top_k)
molecules = top_k_molecules * ((population_size-1)//top_k+1)
molecules=molecules[:population_size]

# Update the fate of selected top k molecules in the buffer
buffer.loc[buffer['generation'] == generation_idx, 'fate'] = 'survived'
Expand All @@ -235,7 +243,7 @@ def diversify_ligands(model, pocket, mols, timesteps,

# Evaluate and save molecules
for mol in molecules:
buffer = buffer.append({'generation': generation_idx + 1,
buffer = buffer._append({'generation': generation_idx + 1,
'score': objective_function(mol),
'fate': 'purged',
'mol': mol,
Expand Down