Skip to content

Commit 1e8a47d

Browse files
authored
feat: update elements scripts (#67)
* feat: update imports * feat: add tqdm * feat: update required elements lists
1 parent 39ddce9 commit 1e8a47d

File tree

4 files changed

+62
-41
lines changed

4 files changed

+62
-41
lines changed

scripts/fetch_element_types.py

Lines changed: 59 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
`required_elements`. If adding a new benchmark class, we encourage
2020
users to complete this script with a custom function calculating
2121
the required element types for their new benchmark.
22+
23+
Usage:
24+
uv run scripts/fetch_element_types.py
2225
"""
2326

2427
import json
@@ -27,43 +30,60 @@
2730

2831
from ase import Atoms
2932
from ase.io import read as ase_read
30-
from mlipaudit.bond_length_distribution.bond_length_distribution import (
33+
from pydantic import BaseModel
34+
from tqdm import tqdm
35+
36+
from mlipaudit.benchmarks.bond_length_distribution.bond_length_distribution import (
3137
BOND_LENGTH_DISTRIBUTION_DATASET_FILENAME,
3238
)
33-
from mlipaudit.bond_length_distribution.bond_length_distribution import (
39+
from mlipaudit.benchmarks.bond_length_distribution.bond_length_distribution import (
3440
Molecules as BLDMolecules,
3541
)
36-
from mlipaudit.conformer_selection.conformer_selection import (
42+
from mlipaudit.benchmarks.conformer_selection.conformer_selection import (
3743
WIGGLE_DATASET_FILENAME,
3844
Conformers,
3945
)
40-
from mlipaudit.dihedral_scan.dihedral_scan import TORSIONNET_DATASET_FILENAME, Fragments
41-
from mlipaudit.folding_stability.folding_stability import (
46+
from mlipaudit.benchmarks.dihedral_scan.dihedral_scan import (
47+
TORSIONNET_DATASET_FILENAME,
48+
Fragments,
49+
)
50+
from mlipaudit.benchmarks.folding_stability.folding_stability import (
4251
STRUCTURE_NAMES as FS_STRUCTURE_NAMES,
4352
)
44-
from mlipaudit.noncovalent_interactions.noncovalent_interactions import (
53+
from mlipaudit.benchmarks.noncovalent_interactions.noncovalent_interactions import (
4554
NCI_ATLAS_FILENAME,
4655
Systems,
4756
)
48-
from mlipaudit.reactivity.reactivity import GRAMBOW_DATASET_FILENAME, Reactions
49-
from mlipaudit.ring_planarity.ring_planarity import RING_PLANARITY_DATASET
50-
from mlipaudit.ring_planarity.ring_planarity import Molecules as RPMolecules
51-
from mlipaudit.sampling.sampling import STRUCTURE_NAMES as SAMPLING_STRUCTURE_NAMES
52-
from mlipaudit.small_molecule_minimization.small_molecule_minimization import (
57+
from mlipaudit.benchmarks.reactivity.reactivity import (
58+
GRAMBOW_DATASET_FILENAME,
59+
Reactions,
60+
)
61+
from mlipaudit.benchmarks.ring_planarity.ring_planarity import RING_PLANARITY_DATASET
62+
from mlipaudit.benchmarks.ring_planarity.ring_planarity import Molecules as RPMolecules
63+
from mlipaudit.benchmarks.sampling.sampling import (
64+
STRUCTURE_NAMES as SAMPLING_STRUCTURE_NAMES,
65+
)
66+
from mlipaudit.benchmarks.small_molecule_minimization.small_molecule_minimization import ( # noqa: E501
5367
OPENFF_CHARGED_FILENAME,
5468
OPENFF_NEUTRAL_FILENAME,
55-
QM9_CHARGED_FILENAME,
56-
QM9_NEUTRAL_FILENAME,
5769
)
58-
from mlipaudit.small_molecule_minimization.small_molecule_minimization import (
70+
from mlipaudit.benchmarks.small_molecule_minimization.small_molecule_minimization import ( # noqa: E501
5971
Molecules as SMMMolecules,
6072
)
61-
from mlipaudit.solvent_radial_distribution.solvent_radial_distribution import BOX_CONFIG
62-
from mlipaudit.stability.stability import STRUCTURE_NAMES as STABILITY_STRUCTURE_NAMES
63-
from mlipaudit.stability.stability import STRUCTURES as STABILITY_STRUCTURES
64-
from mlipaudit.tautomers.tautomers import TAUTOMERS_DATASET_FILENAME, TautomerPairs
65-
from mlipaudit.water_radial_distribution.water_radial_distribution import WATERBOX_N500
66-
from pydantic import BaseModel
73+
from mlipaudit.benchmarks.solvent_radial_distribution.solvent_radial_distribution import ( # noqa: E501
74+
BOX_CONFIG,
75+
)
76+
from mlipaudit.benchmarks.stability.stability import (
77+
STRUCTURE_NAMES as STABILITY_STRUCTURE_NAMES,
78+
)
79+
from mlipaudit.benchmarks.stability.stability import STRUCTURES as STABILITY_STRUCTURES
80+
from mlipaudit.benchmarks.tautomers.tautomers import (
81+
TAUTOMERS_DATASET_FILENAME,
82+
TautomerPairs,
83+
)
84+
from mlipaudit.benchmarks.water_radial_distribution.water_radial_distribution import (
85+
WATERBOX_N500,
86+
)
6787

6888
DATA_LOCATION = "data"
6989

@@ -270,8 +290,6 @@ def get_element_types_for_smm(data_dir: os.PathLike | str) -> set[str]:
270290
"""
271291
atom_element_types = set()
272292
for dataset_filename in [
273-
QM9_NEUTRAL_FILENAME,
274-
QM9_CHARGED_FILENAME,
275293
OPENFF_NEUTRAL_FILENAME,
276294
OPENFF_CHARGED_FILENAME,
277295
]:
@@ -364,24 +382,27 @@ def main():
364382
data location, so these data files must be added manually
365383
beforehand, either manually or by running the benchmarks.
366384
"""
385+
BENCHMARK_FUNCTIONS = {
386+
"bld": get_element_types_for_bld,
387+
"cs": get_element_types_for_cs,
388+
"ds": get_element_types_for_ds,
389+
"fs": get_element_types_for_fs,
390+
"nci": get_element_types_for_nci,
391+
"r": get_element_types_for_r,
392+
"rp": get_element_types_for_rp,
393+
"smm": get_element_types_for_smm,
394+
"srd": get_element_types_for_srd,
395+
"sampling": get_element_types_for_sampling,
396+
"scaling": get_element_types_for_scaling,
397+
"stability": get_element_types_for_stability,
398+
"t": get_element_types_for_t,
399+
"wrd": get_element_types_for_wrd,
400+
}
367401
data_path = Path(__file__).parent.parent / DATA_LOCATION
368402

369-
element_types_data = {
370-
"bld": list(get_element_types_for_bld(data_path)),
371-
"cs": list(get_element_types_for_cs(data_path)),
372-
"ds": list(get_element_types_for_ds(data_path)),
373-
"fs": list(get_element_types_for_fs(data_path)),
374-
"nci": list(get_element_types_for_nci(data_path)),
375-
"r": list(get_element_types_for_r(data_path)),
376-
"rp": list(get_element_types_for_rp(data_path)),
377-
"smm": list(get_element_types_for_smm(data_path)),
378-
"srd": list(get_element_types_for_srd(data_path)),
379-
"sampling": list(get_element_types_for_sampling(data_path)),
380-
"scaling": list(get_element_types_for_scaling(data_path)),
381-
"stability": list(get_element_types_for_stability(data_path)),
382-
"t": list(get_element_types_for_t(data_path)),
383-
"wrd": list(get_element_types_for_wrd(data_path)),
384-
}
403+
element_types_data = {}
404+
for key, func in tqdm(BENCHMARK_FUNCTIONS.items(), desc="Processing Benchmarks"):
405+
element_types_data[key] = list(func(data_path))
385406

386407
output_file = "element_types_data.json"
387408
with open(output_file, "w", encoding="utf-8") as f:

src/mlipaudit/benchmarks/ring_planarity/ring_planarity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ class RingPlanarityBenchmark(Benchmark):
178178
result_class = RingPlanarityResult
179179
model_output_class = RingPlanarityModelOutput
180180

181-
required_elements = {"H", "C", "O", "N"}
181+
required_elements = {"H", "C", "O", "N", "F"}
182182

183183
def run_model(self) -> None:
184184
"""Run an MD simulation for each structure.

src/mlipaudit/benchmarks/sampling/sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ class SamplingBenchmark(Benchmark):
285285
result_class = SamplingResult
286286
model_output_class = SamplingModelOutput
287287

288-
required_elements = {"N", "Cl", "H", "O", "S", "F", "P", "C", "Br"}
288+
required_elements = {"N", "H", "O", "S", "C"}
289289

290290
def run_model(self) -> None:
291291
"""Run an MD simulation for each system."""

src/mlipaudit/benchmarks/stability/stability.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ class StabilityBenchmark(Benchmark):
383383
result_class = StabilityResult
384384
model_output_class = StabilityModelOutput
385385

386-
required_elements = {"N", "H", "O", "S", "P", "C", "Cl", "F"}
386+
required_elements = {"N", "H", "O", "S", "C", "Cl", "F"}
387387

388388
def run_model(self) -> None:
389389
"""Run MD for each structure.

0 commit comments

Comments
 (0)