|
6 | 6 | import os |
7 | 7 | import pickle |
8 | 8 | from collections import defaultdict |
| 9 | +import itertools |
| 10 | +from glob import glob |
9 | 11 |
|
10 | 12 | import numpy as np |
11 | 13 | import torch |
@@ -382,10 +384,25 @@ def load_db( |
382 | 384 | self.graph_data_list = [] |
383 | 385 | self.n_atoms = [] |
384 | 386 |
|
385 | | - db = connect(db_path) |
386 | | - iterator = db.select() |
| 387 | + db_files = [] |
| 388 | + if os.path.isdir(db_path): |
| 389 | + db_files.extend(glob(os.path.join(db_path, "*.db"))) |
| 390 | + elif os.path.isfile(db_path): |
| 391 | + db_files.append(db_path) |
| 392 | + else: |
| 393 | + raise ValueError( |
| 394 | + f"Invalid db_path: {db_path}. It must be a .db file or a directory containing .db files." |
| 395 | + ) |
| 396 | + |
| 397 | + if not db_files: |
| 398 | + raise FileNotFoundError(f"No .db files found in {db_path}") |
| 399 | + |
| 400 | + dbs = [connect(f) for f in db_files] |
| 401 | + total_len = sum(len(db) for db in dbs) |
| 402 | + iterator = itertools.chain.from_iterable(db.select() for db in dbs) |
| 403 | + |
387 | 404 | if verbose: |
388 | | - iterator = tqdm(iterator, "Processing ASE db files", total=len(db)) |
| 405 | + iterator = tqdm(iterator, "Processing ASE db files", total=total_len) |
389 | 406 |
|
390 | 407 | for i, row in enumerate(iterator): |
391 | 408 | try: |
@@ -1252,12 +1269,30 @@ def load_db( |
1252 | 1269 | self.n_atoms = [] |
1253 | 1270 | self.atom_vocab = atom_vocab |
1254 | 1271 |
|
1255 | | - db = connect(db_path) |
| 1272 | + db_files = [] |
| 1273 | + if os.path.isdir(db_path): |
| 1274 | + db_files.extend(glob(os.path.join(db_path, "*.db"))) |
| 1275 | + elif os.path.isfile(db_path): |
| 1276 | + db_files.append(db_path) |
| 1277 | + else: |
| 1278 | + raise ValueError( |
| 1279 | + f"Invalid db_path: {db_path}. It must be a .db file or a directory containing .db files." |
| 1280 | + ) |
| 1281 | + |
| 1282 | + if not db_files: |
| 1283 | + raise FileNotFoundError(f"No .db files found in {db_path}") |
| 1284 | + |
| 1285 | + if verbose: |
| 1286 | + logger.info(f"Found {len(db_files)} .db files to load:") |
| 1287 | + for f_path in db_files: |
| 1288 | + logger.info(f" - {f_path}") |
1256 | 1289 |
|
| 1290 | + dbs = [connect(f) for f in db_files] |
| 1291 | + total_len = sum(len(db) for db in dbs) |
| 1292 | + iterator = itertools.chain.from_iterable(db.select() for db in dbs) |
1257 | 1293 |
|
1258 | | - iterator = db.select() |
1259 | 1294 | if verbose: |
1260 | | - iterator = tqdm(iterator, "Processing ASE db files", total=len(db)) |
| 1295 | + iterator = tqdm(iterator, "Processing ASE db files", total=total_len) |
1261 | 1296 |
|
1262 | 1297 | for i, row in enumerate(iterator): |
1263 | 1298 | try: |
|
0 commit comments