|
19 | 19 | `required_elements`. If adding a new benchmark class, we encourage |
20 | 20 | users to complete this script with a custom function calculating |
21 | 21 | the required element types for their new benchmark. |
| 22 | +
|
| 23 | +Usage: |
| 24 | + uv run scripts/fetch_element_types.py |
22 | 25 | """ |
23 | 26 |
|
24 | 27 | import json |
|
27 | 30 |
|
28 | 31 | from ase import Atoms |
29 | 32 | from ase.io import read as ase_read |
30 | | -from mlipaudit.bond_length_distribution.bond_length_distribution import ( |
| 33 | +from pydantic import BaseModel |
| 34 | +from tqdm import tqdm |
| 35 | + |
| 36 | +from mlipaudit.benchmarks.bond_length_distribution.bond_length_distribution import ( |
31 | 37 | BOND_LENGTH_DISTRIBUTION_DATASET_FILENAME, |
32 | 38 | ) |
33 | | -from mlipaudit.bond_length_distribution.bond_length_distribution import ( |
| 39 | +from mlipaudit.benchmarks.bond_length_distribution.bond_length_distribution import ( |
34 | 40 | Molecules as BLDMolecules, |
35 | 41 | ) |
36 | | -from mlipaudit.conformer_selection.conformer_selection import ( |
| 42 | +from mlipaudit.benchmarks.conformer_selection.conformer_selection import ( |
37 | 43 | WIGGLE_DATASET_FILENAME, |
38 | 44 | Conformers, |
39 | 45 | ) |
40 | | -from mlipaudit.dihedral_scan.dihedral_scan import TORSIONNET_DATASET_FILENAME, Fragments |
41 | | -from mlipaudit.folding_stability.folding_stability import ( |
| 46 | +from mlipaudit.benchmarks.dihedral_scan.dihedral_scan import ( |
| 47 | + TORSIONNET_DATASET_FILENAME, |
| 48 | + Fragments, |
| 49 | +) |
| 50 | +from mlipaudit.benchmarks.folding_stability.folding_stability import ( |
42 | 51 | STRUCTURE_NAMES as FS_STRUCTURE_NAMES, |
43 | 52 | ) |
44 | | -from mlipaudit.noncovalent_interactions.noncovalent_interactions import ( |
| 53 | +from mlipaudit.benchmarks.noncovalent_interactions.noncovalent_interactions import ( |
45 | 54 | NCI_ATLAS_FILENAME, |
46 | 55 | Systems, |
47 | 56 | ) |
48 | | -from mlipaudit.reactivity.reactivity import GRAMBOW_DATASET_FILENAME, Reactions |
49 | | -from mlipaudit.ring_planarity.ring_planarity import RING_PLANARITY_DATASET |
50 | | -from mlipaudit.ring_planarity.ring_planarity import Molecules as RPMolecules |
51 | | -from mlipaudit.sampling.sampling import STRUCTURE_NAMES as SAMPLING_STRUCTURE_NAMES |
52 | | -from mlipaudit.small_molecule_minimization.small_molecule_minimization import ( |
| 57 | +from mlipaudit.benchmarks.reactivity.reactivity import ( |
| 58 | + GRAMBOW_DATASET_FILENAME, |
| 59 | + Reactions, |
| 60 | +) |
| 61 | +from mlipaudit.benchmarks.ring_planarity.ring_planarity import RING_PLANARITY_DATASET |
| 62 | +from mlipaudit.benchmarks.ring_planarity.ring_planarity import Molecules as RPMolecules |
| 63 | +from mlipaudit.benchmarks.sampling.sampling import ( |
| 64 | + STRUCTURE_NAMES as SAMPLING_STRUCTURE_NAMES, |
| 65 | +) |
| 66 | +from mlipaudit.benchmarks.small_molecule_minimization.small_molecule_minimization import ( # noqa: E501 |
53 | 67 | OPENFF_CHARGED_FILENAME, |
54 | 68 | OPENFF_NEUTRAL_FILENAME, |
55 | | - QM9_CHARGED_FILENAME, |
56 | | - QM9_NEUTRAL_FILENAME, |
57 | 69 | ) |
58 | | -from mlipaudit.small_molecule_minimization.small_molecule_minimization import ( |
| 70 | +from mlipaudit.benchmarks.small_molecule_minimization.small_molecule_minimization import ( # noqa: E501 |
59 | 71 | Molecules as SMMMolecules, |
60 | 72 | ) |
61 | | -from mlipaudit.solvent_radial_distribution.solvent_radial_distribution import BOX_CONFIG |
62 | | -from mlipaudit.stability.stability import STRUCTURE_NAMES as STABILITY_STRUCTURE_NAMES |
63 | | -from mlipaudit.stability.stability import STRUCTURES as STABILITY_STRUCTURES |
64 | | -from mlipaudit.tautomers.tautomers import TAUTOMERS_DATASET_FILENAME, TautomerPairs |
65 | | -from mlipaudit.water_radial_distribution.water_radial_distribution import WATERBOX_N500 |
66 | | -from pydantic import BaseModel |
| 73 | +from mlipaudit.benchmarks.solvent_radial_distribution.solvent_radial_distribution import ( # noqa: E501 |
| 74 | + BOX_CONFIG, |
| 75 | +) |
| 76 | +from mlipaudit.benchmarks.stability.stability import ( |
| 77 | + STRUCTURE_NAMES as STABILITY_STRUCTURE_NAMES, |
| 78 | +) |
| 79 | +from mlipaudit.benchmarks.stability.stability import STRUCTURES as STABILITY_STRUCTURES |
| 80 | +from mlipaudit.benchmarks.tautomers.tautomers import ( |
| 81 | + TAUTOMERS_DATASET_FILENAME, |
| 82 | + TautomerPairs, |
| 83 | +) |
| 84 | +from mlipaudit.benchmarks.water_radial_distribution.water_radial_distribution import ( |
| 85 | + WATERBOX_N500, |
| 86 | +) |
67 | 87 |
|
68 | 88 | DATA_LOCATION = "data" |
69 | 89 |
|
@@ -270,8 +290,6 @@ def get_element_types_for_smm(data_dir: os.PathLike | str) -> set[str]: |
270 | 290 | """ |
271 | 291 | atom_element_types = set() |
272 | 292 | for dataset_filename in [ |
273 | | - QM9_NEUTRAL_FILENAME, |
274 | | - QM9_CHARGED_FILENAME, |
275 | 293 | OPENFF_NEUTRAL_FILENAME, |
276 | 294 | OPENFF_CHARGED_FILENAME, |
277 | 295 | ]: |
@@ -364,24 +382,27 @@ def main(): |
364 | 382 | data location, so these data files must be added |
365 | 383 | beforehand, either manually or by running the benchmarks. |
366 | 384 | """ |
| 385 | + BENCHMARK_FUNCTIONS = { |
| 386 | + "bld": get_element_types_for_bld, |
| 387 | + "cs": get_element_types_for_cs, |
| 388 | + "ds": get_element_types_for_ds, |
| 389 | + "fs": get_element_types_for_fs, |
| 390 | + "nci": get_element_types_for_nci, |
| 391 | + "r": get_element_types_for_r, |
| 392 | + "rp": get_element_types_for_rp, |
| 393 | + "smm": get_element_types_for_smm, |
| 394 | + "srd": get_element_types_for_srd, |
| 395 | + "sampling": get_element_types_for_sampling, |
| 396 | + "scaling": get_element_types_for_scaling, |
| 397 | + "stability": get_element_types_for_stability, |
| 398 | + "t": get_element_types_for_t, |
| 399 | + "wrd": get_element_types_for_wrd, |
| 400 | + } |
367 | 401 | data_path = Path(__file__).parent.parent / DATA_LOCATION |
368 | 402 |
|
369 | | - element_types_data = { |
370 | | - "bld": list(get_element_types_for_bld(data_path)), |
371 | | - "cs": list(get_element_types_for_cs(data_path)), |
372 | | - "ds": list(get_element_types_for_ds(data_path)), |
373 | | - "fs": list(get_element_types_for_fs(data_path)), |
374 | | - "nci": list(get_element_types_for_nci(data_path)), |
375 | | - "r": list(get_element_types_for_r(data_path)), |
376 | | - "rp": list(get_element_types_for_rp(data_path)), |
377 | | - "smm": list(get_element_types_for_smm(data_path)), |
378 | | - "srd": list(get_element_types_for_srd(data_path)), |
379 | | - "sampling": list(get_element_types_for_sampling(data_path)), |
380 | | - "scaling": list(get_element_types_for_scaling(data_path)), |
381 | | - "stability": list(get_element_types_for_stability(data_path)), |
382 | | - "t": list(get_element_types_for_t(data_path)), |
383 | | - "wrd": list(get_element_types_for_wrd(data_path)), |
384 | | - } |
| 403 | + element_types_data = {} |
| 404 | + for key, func in tqdm(BENCHMARK_FUNCTIONS.items(), desc="Processing Benchmarks"): |
| 405 | + element_types_data[key] = list(func(data_path)) |
385 | 406 |
|
386 | 407 | output_file = "element_types_data.json" |
387 | 408 | with open(output_file, "w", encoding="utf-8") as f: |
|
0 commit comments