Skip to content

Commit d390e8e

Browse files
committed
allow for reading multiple dbs
1 parent 101dc69 commit d390e8e

File tree

2 files changed: +42 -8 lines changed

2 files changed: +42 -8 lines changed

src/MolecularDiffusion/data/component/dataset.py

Lines changed: 41 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,8 @@
66
import os
77
import pickle
88
from collections import defaultdict
9+
import itertools
10+
from glob import glob
911

1012
import numpy as np
1113
import torch
@@ -382,10 +384,25 @@ def load_db(
382384
self.graph_data_list = []
383385
self.n_atoms = []
384386

385-
db = connect(db_path)
386-
iterator = db.select()
387+
db_files = []
388+
if os.path.isdir(db_path):
389+
db_files.extend(glob(os.path.join(db_path, "*.db")))
390+
elif os.path.isfile(db_path):
391+
db_files.append(db_path)
392+
else:
393+
raise ValueError(
394+
f"Invalid db_path: {db_path}. It must be a .db file or a directory containing .db files."
395+
)
396+
397+
if not db_files:
398+
raise FileNotFoundError(f"No .db files found in {db_path}")
399+
400+
dbs = [connect(f) for f in db_files]
401+
total_len = sum(len(db) for db in dbs)
402+
iterator = itertools.chain.from_iterable(db.select() for db in dbs)
403+
387404
if verbose:
388-
iterator = tqdm(iterator, "Processing ASE db files", total=len(db))
405+
iterator = tqdm(iterator, "Processing ASE db files", total=total_len)
389406

390407
for i, row in enumerate(iterator):
391408
try:
@@ -1252,12 +1269,30 @@ def load_db(
12521269
self.n_atoms = []
12531270
self.atom_vocab = atom_vocab
12541271

1255-
db = connect(db_path)
1272+
db_files = []
1273+
if os.path.isdir(db_path):
1274+
db_files.extend(glob(os.path.join(db_path, "*.db")))
1275+
elif os.path.isfile(db_path):
1276+
db_files.append(db_path)
1277+
else:
1278+
raise ValueError(
1279+
f"Invalid db_path: {db_path}. It must be a .db file or a directory containing .db files."
1280+
)
1281+
1282+
if not db_files:
1283+
raise FileNotFoundError(f"No .db files found in {db_path}")
1284+
1285+
if verbose:
1286+
logger.info(f"Found {len(db_files)} .db files to load:")
1287+
for f_path in db_files:
1288+
logger.info(f" - {f_path}")
12561289

1290+
dbs = [connect(f) for f in db_files]
1291+
total_len = sum(len(db) for db in dbs)
1292+
iterator = itertools.chain.from_iterable(db.select() for db in dbs)
12571293

1258-
iterator = db.select()
12591294
if verbose:
1260-
iterator = tqdm(iterator, "Processing ASE db files", total=len(db))
1295+
iterator = tqdm(iterator, "Processing ASE db files", total=total_len)
12611296

12621297
for i, row in enumerate(iterator):
12631298
try:

src/MolecularDiffusion/runmodes/train/eval.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -343,10 +343,9 @@ def _validate_xyzs(path_save: str, logger: str, use_posebuster: bool = False, po
343343
summary = {k: v.mean().item() for k, v in metrics.items()}
344344

345345
if use_posebuster:
346-
mols = load_molecules_from_xyz(path_save)
346+
mols, _ = load_molecules_from_xyz(path_save)
347347
if mols:
348348
postbuster_results = run_postbuster(mols, timeout=postbuster_timeout)
349-
print(postbuster_results)
350349
if postbuster_results is not None:
351350
postbuster_output_path = os.path.join(path_save, "postbuster_metrics.csv")
352351
postbuster_results.to_csv(postbuster_output_path, index=False)

0 commit comments

Comments
 (0)