Skip to content

Commit 372af12

Browse files
committed
refactor: update diffusion logic, metrics utilities, and configs
- Model: Fix tensor comparison in EnVariationalDiffusion and enhance outpainting with new config parameters (d_threshold_f, w_b, all_frozen, etc.). - Generation: Update sampling logic in tasks_generate.py for conditional and hybrid guidance. - Metrics: Organize imports in geom_metrics.py and feature.py; add batch processing to PoseBusters runner; improve docstrings. - Scripts: Add CSV filtering to xtb_optimization.py; enhance compute_metrics.py with new validity checks; add standalone energy score computation script. - Configs: Update generation and outpainting configurations; refine dependencies in README.
1 parent e43bc71 commit 372af12

File tree

11 files changed

+384
-93
lines changed

11 files changed

+384
-93
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ For a more detailed installation, including setting up a conda environment and i
4747
conda install conda-forge::openbabel
4848
conda install xtb==6.7.1
4949
# install other libraries
50-
pip install fire seaborn decorator numpy==1.26.4 scipy rdkit-pypi posebusters==0.5.1 networkx matplotlib pandas scikit-learn tqdm pyyaml omegaconf ase morfeus cosymlib morfeus-ml wandb rmsd
50+
pip install fire seaborn decorator numpy scipy rdkit-pypi posebusters==0.5.1 networkx matplotlib pandas scikit-learn tqdm pyyaml omegaconf ase morfeus-ml morfeus-ml wandb rmsd
5151

5252
pip install hydra-core==1.* hydra-colorlog rootutils
5353

@@ -61,6 +61,9 @@ For a more detailed installation, including setting up a conda environment and i
6161
# Install the package. Use editable mode (-e) to make the MolCraftDiff CLI tool available.
6262
pip install -e .
6363

64+
# optional for some featurizer/metrics
65+
# this require numpy==1.24.*
66+
pip install cosymlib
6467

6568
Usage
6669
-----

configs/generate.yaml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ defaults:
1010

1111
# run name, eg. for wandb logging
1212
name: "akatsuki"
13-
chkpt_directory: "/home/pregabalin/RF/MolecularDiffusion/trained_models/edm_pretrained/"
13+
chkpt_directory: "/home/pregabalin/RF/MolecularDiffusion/trained_models/edm_qm9vqm24//"
1414
atom_vocab: [H,B,C,N,O,F,Al,Si,P,S,Cl,As,Se,Br,I,Hg,Bi]
15-
diffusion_steps: 600
15+
diffusion_steps: 300
1616

1717
# tags to help you identify your experiments
1818
# you can overwrite this in experiment configs
@@ -21,3 +21,9 @@ diffusion_steps: 600
2121
# seed for random number generators in pytorch, numpy and python.random
2222
seed: 9
2323

24+
interference:
25+
batch_size: 1
26+
num_generate: 12
27+
output_path: gen_test
28+
n_frames: 30
29+
# mol_size: [0,0]

configs/interference/gen_outpaint.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@ condition_configs:
2323
scale_factor: 1.1
2424
noise_initial_mask: False
2525
connector_dicts:
26-
1: [3]
27-
2: [3]
28-
3: [3]
2926
n_retrys: 3
3027
t_retry: 180
3128

scripts/applications/utils/compute_metrics.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def runner(args):
180180
for xyz in tqdm(xyz_passed, desc="Checking neutrality of molecules", total=len(xyz_passed)):
181181
neutral_mols.append(check_neutrality(xyz))
182182

183-
postbuster_results = run_postbuster(mols, timeout=300)
183+
postbuster_results = run_postbuster(mols, timeout=3000)
184184
if postbuster_results is not None:
185185
num_atoms_list = [mol.GetNumAtoms() for mol in mols]
186186
postbuster_results['num_atoms'] = num_atoms_list
@@ -191,7 +191,8 @@ def runner(args):
191191
'double_bond_flatness', 'internal_energy'
192192
]
193193
postbuster_results['valid_posebuster'] = postbuster_results[posebuster_checks].all(axis=1)
194-
194+
posebuster_checks_connected = posebuster_checks + ['all_atoms_connected']
195+
postbuster_results['valid_posebuster_connected'] = postbuster_results[posebuster_checks_connected].all(axis=1)
195196
if args.output is None:
196197
postbuster_output_path = f"{xyz_dir}/postbuster_metrics.csv"
197198
hist_path = f"{xyz_dir}/postbuster_molecular_size_histogram.png"
@@ -201,6 +202,7 @@ def runner(args):
201202
hist_path = f"{base}_postbuster_molecular_size_histogram.png"
202203

203204
postbuster_results['neutral_molecule'] = neutral_mols
205+
postbuster_results["filename"] = [os.path.basename(xyz) for xyz in xyz_passed]
204206
postbuster_results.to_csv(postbuster_output_path, index=False)
205207

206208
logging.info(f"Molecular size mean: {postbuster_results['num_atoms'].mean():.2f}")
@@ -227,6 +229,7 @@ def runner(args):
227229
logging.info(f"Double Bond Flatness: {postbuster_results['double_bond_flatness'].mean():.2f}")
228230
logging.info(f"Internal Energy: {postbuster_results['internal_energy'].mean():.2f}")
229231
logging.info(f"Valid Posebuster: {postbuster_results['valid_posebuster'].mean() * 100:.2f}%")
232+
logging.info(f"Valid Posebuster Connected: {postbuster_results['valid_posebuster_connected'].mean() * 100:.2f}%")
230233
logging.info(f"Neutral Molecule: {sum(neutral_mols) / len(neutral_mols) * 100:.2f}%")
231234

232235

scripts/applications/utils/xtb_optimization.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from tqdm import tqdm
66
import argparse
77
import torch
8+
import pandas as pd
89

910
from MolecularDiffusion.utils import create_pyg_graph, correct_edges
1011
from MolecularDiffusion.utils.geom_utils import read_xyz_file
@@ -159,7 +160,9 @@ def get_xtb_optimized_xyz(
159160
level: str = "gfn1",
160161
timeout: int = 240,
161162
scale_factor: float = 1.3,
162-
optimize_all: bool = True
163+
optimize_all: bool = True,
164+
csv_path: str = None,
165+
filter_column: str = None
163166
) -> list[str]:
164167
"""
165168
Optimizes all XYZ files in a given input directory using xTB and saves them
@@ -180,6 +183,8 @@ def get_xtb_optimized_xyz(
180183
timeout (int, optional): The maximum time in seconds to wait for each xTB process. Defaults to 240.
181184
scale_factor (float, optional): The scaling factor for covalent radii in edge correction. Defaults to 1.3.
182185
optimize_all (bool, optional): If True, optimizes all files regardless of existing optimized versions.
186+
csv_path (str, optional): Path to a CSV file to filter which XYZ files to optimize.
187+
filter_column (str, optional): The column name in the CSV to filter by (values must be 1).
183188
184189
Returns:
185190
list[str]: A list of paths to the successfully optimized XYZ files.
@@ -189,7 +194,49 @@ def get_xtb_optimized_xyz(
189194

190195
os.makedirs(output_directory, exist_ok=True)
191196

192-
xyz_files = glob.glob(os.path.join(input_directory, "*.xyz"))
197+
xyz_files = []
198+
if csv_path:
199+
if not os.path.exists(csv_path):
200+
raise FileNotFoundError(f"CSV file not found: {csv_path}")
201+
202+
df = pd.read_csv(csv_path)
203+
204+
fname_col = None
205+
for col in ["xyz_file", "filename", "filepath"]:
206+
if col in df.columns:
207+
fname_col = col
208+
break
209+
210+
if fname_col is None:
211+
raise ValueError("CSV must contain 'xyz_file', 'filename', or 'filepath' column.")
212+
213+
if filter_column:
214+
if filter_column not in df.columns:
215+
raise ValueError(f"Filter column '{filter_column}' not found in CSV.")
216+
# Filter rows where the value is 1 (as integer or string)
217+
filtered_df = df[df[filter_column].isin(['1', '1.0', True, 1])]
218+
else:
219+
filtered_df = df
220+
221+
for _, row in filtered_df.iterrows():
222+
fname = str(row[fname_col])
223+
# Handle potential missing extension if it's just a name
224+
if not fname.lower().endswith('.xyz'):
225+
fname += '.xyz'
226+
227+
if os.path.isabs(fname):
228+
full_path = fname
229+
else:
230+
full_path = os.path.join(input_directory, fname)
231+
232+
if os.path.exists(full_path):
233+
xyz_files.append(full_path)
234+
else:
235+
print(f"Warning: File from CSV not found: {full_path}")
236+
237+
else:
238+
xyz_files = glob.glob(os.path.join(input_directory, "*.xyz"))
239+
193240
optimized_files = []
194241

195242
for xyz_file in tqdm(xyz_files, desc="Optimizing XYZ files", total=len(xyz_files)):
@@ -260,6 +307,18 @@ def get_xtb_optimized_xyz(
260307
default=1.3,
261308
help="Scaling factor for covalent radii in edge correction. Defaults to 1.3."
262309
)
310+
parser.add_argument(
311+
"--csv_path",
312+
type=str,
313+
default=None,
314+
help="Path to CSV file for filtering which files to optimize."
315+
)
316+
parser.add_argument(
317+
"--filter_column",
318+
type=str,
319+
default=None,
320+
help="Column name in CSV to filter by (values must be 1 to process)."
321+
)
263322

264323
args = parser.parse_args()
265324

@@ -272,7 +331,9 @@ def get_xtb_optimized_xyz(
272331
charge=args.charge,
273332
level=args.level,
274333
timeout=args.timeout,
275-
scale_factor=args.scale_factor
334+
scale_factor=args.scale_factor,
335+
csv_path=args.csv_path,
336+
filter_column=args.filter_column
276337
)
277338

278339
print(f"Successfully optimized {len(optimized_files)} XYZ files and saved them in '{output_dir}'.")

scripts/gradient_guidance/sf_energy_score.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
import sys
2+
import os
23
from math import sqrt
34
from typing import List
45

56
import numpy as np
67
import torch
8+
import rootutils
9+
10+
# Setup root directory
11+
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
712

813
# from navicatGA.timeout import timeout
914
from numpy import dot
1015
from MolecularDiffusion.core import Engine
16+
from MolecularDiffusion.utils.plot_function import (
17+
plot_kde_distribution,
18+
plot_histogram_distribution,
19+
)
1120
from torch_geometric.data import Data
1221
from torch_geometric.nn import radius_graph
1322
from torch_geometric.data import Batch
@@ -308,3 +317,82 @@ def __call__(self, xh, t):
308317
target = -energy_score(preds[1], preds[0])
309318

310319
return target
320+
321+
if __name__ == "__main__":
322+
import argparse
323+
import pandas as pd
324+
325+
parser = argparse.ArgumentParser(description="Compute energy scores from a CSV file.")
326+
parser.add_argument("--input_csv", type=str, required=True, help="Path to the input CSV file.")
327+
parser.add_argument("--output_csv", type=str, required=True, help="Path to save the output CSV file.")
328+
parser.add_argument("--threshold", type=float, help="Threshold to count entries with score higher than this value.")
329+
args = parser.parse_args()
330+
331+
df = pd.read_csv(args.input_csv)
332+
333+
# Identify columns
334+
s1_col = next((col for col in ["S1", "S1_exc", "s1"] if col in df.columns), None)
335+
t1_col = next((col for col in ["T1", "T1_exc", "t1"] if col in df.columns), None)
336+
337+
if not s1_col or not t1_col:
338+
print(f"Error: Could not find S1 or T1 columns. Available columns: {df.columns}")
339+
sys.exit(1)
340+
341+
energy_scores = []
342+
for index, row in df.iterrows():
343+
try:
344+
s1_val = float(row[s1_col])
345+
t1_val = float(row[t1_col])
346+
347+
# energy_score takes torch tensors
348+
t1_tensor = torch.tensor(t1_val)
349+
s1_tensor = torch.tensor(s1_val)
350+
351+
# energy_score(x, y) where x=t1, y=s1
352+
score = energy_score(t1_tensor, s1_tensor).item()
353+
energy_scores.append(score)
354+
except Exception as e:
355+
print(f"Error processing row {index}: {e}")
356+
energy_scores.append(np.nan)
357+
358+
df["energy_score"] = energy_scores
359+
360+
# Save
361+
df.to_csv(args.output_csv, index=False)
362+
363+
# Statistics
364+
valid_scores = [s for s in energy_scores if not np.isnan(s)]
365+
if valid_scores:
366+
print(f"Mean energy score: {np.mean(valid_scores)}")
367+
print(f"Max energy score: {np.max(valid_scores)}")
368+
print(f"Min energy score: {np.min(valid_scores)}")
369+
370+
if args.threshold is not None:
371+
count_above = sum(1 for s in valid_scores if s > args.threshold)
372+
portion_above = count_above / len(valid_scores)
373+
print(f"Entries with score > {args.threshold}: {count_above} ({portion_above:.2%})")
374+
375+
# Plotting
376+
output_dir = os.path.dirname(args.output_csv)
377+
if not output_dir:
378+
output_dir = "."
379+
380+
print(f"Plotting distributions to {output_dir}")
381+
try:
382+
# Drop NaNs for plotting
383+
plot_series = df["energy_score"].dropna()
384+
385+
plot_kde_distribution(
386+
plot_series,
387+
"Energy Score",
388+
os.path.join(output_dir, "energy_score_kde.png")
389+
)
390+
plot_histogram_distribution(
391+
plot_series,
392+
"Energy Score",
393+
os.path.join(output_dir, "energy_score_hist.png")
394+
)
395+
except Exception as e:
396+
print(f"Error plotting distributions: {e}")
397+
else:
398+
print("No valid energy scores computed.")

src/MolecularDiffusion/data/component/feature.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,28 @@
11
import warnings
2-
3-
import torch
42
from itertools import combinations
53

64
import ase
7-
from ase import Atoms, neighborlist
8-
from ase.data import covalent_radii
9-
from ase.data.vdw_alvarez import vdw_radii
10-
from cosymlib import Geometry
115
import networkx as nx
126
import numpy as np
13-
from rdkit import Chem
14-
from rdkit.Chem import AllChem
15-
import numpy as np
167
import scipy.spatial
17-
from ase.io.extxyz import read_xyz
8+
import torch
9+
from ase import Atoms, neighborlist
10+
from ase.data import covalent_radii
11+
from ase.data.vdw_alvarez import vdw_radii
1812
from morfeus import SASA
1913
from networkx.algorithms import community as nx_comm
14+
from rdkit import Chem
15+
from rdkit.Chem import AllChem
2016

2117
from cell2mol.elementdata import ElementData
2218

19+
try:
20+
from cosymlib import Geometry
21+
is_cosymlib_available = True
22+
except ImportError:
23+
is_cosymlib_available = False
24+
Geometry = None
25+
2326
# less than 4 bonds
2427
# 0 for S, 180 for SP, 120 for SP2, 109.5 for SP3
2528
hybridization_dicts = {
@@ -730,6 +733,8 @@ def atom_geom_opt(z, coords, scale_factor = 1.3):
730733

731734
def atom_geom_shape(z, coords, scale_factor = 1.3):
732735

736+
if not(is_cosymlib_available):
737+
raise ImportError("Cosymlib is not available, do use different featurizer")
733738
device = coords.device
734739
N = coords.size(0)
735740

0 commit comments

Comments
 (0)