Skip to content

Commit a582242

Browse files
authored
feat: three options for how fast benchmarks can run (#42)
* refactor: use run_mode instead of fast_dev_run flag * feat: add changes to some benchmarks to support RunMode.FAST flag * feat: allow for run mode to be passed as string, too * test: add unit test for run mode conversion * fix: path was wrong in tests for data dir
1 parent 0ae2a39 commit a582242

File tree

36 files changed

+267
-93
lines changed

36 files changed

+267
-93
lines changed

docs/source/api_reference/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Base classes and utilities
1313

1414
benchmark
1515
io
16+
run_mode
1617
utils/trajectory_helpers
1718

1819
Benchmark implementations
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
.. _run_mode:
2+
3+
.. module:: mlipaudit.run_mode
4+
5+
Run Mode
6+
========
7+
8+
.. autoclass:: RunMode

docs/source/tutorials/cli/index.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,11 @@ The tool has the following command line options:
3939
list of benchmark names (e.g., ``dihedral_scan``, ``ring_planarity``) or ``all`` to
4040
run all available benchmarks which is also the default which means that if this flag
4141
is not used, all benchmarks will be run.
42-
* ``--fast-dev-run``: *Optional* setting that allows to run a very minimal version of
43-
each benchmark for development and testing purposes. The default behavior is that it
44-
is not set.
42+
* ``--run-mode``: *Optional* setting that allows running faster versions of the
43+
benchmark suite. The default option is ``standard``, which runs the entire suite.
44+
The option ``fast`` runs a slightly faster version for some of the very long-running
45+
benchmarks. The option ``dev`` runs a very minimal version of each benchmark for
46+
development and testing purposes.
4547

4648
For example, if you want to run the entire benchmark suite for two models, say
4749
``visnet_1`` and ``mace_2``, use this command:

docs/source/tutorials/new_benchmark/index.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ members to override are:
7777
Hence, to add a new benchmark, three classes must be implemented, the benchmark, model
7878
output, and results class.
7979

80+
Note that we also recommend that a new benchmark implements a very minimal version
81+
of itself that is run if ``self.run_mode == RunMode.DEV``. For very long-running
82+
benchmarks, we also recommend implementing a version for
83+
``self.run_mode == RunMode.FAST`` that may differ
84+
from ``self.run_mode == RunMode.STANDARD``, however, for most benchmarks this may
85+
not be necessary.
86+
8087
Minimal example implementation
8188
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
8289

src/mlipaudit/benchmark.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@
1616
import zipfile
1717
from abc import ABC, abstractmethod
1818
from pathlib import Path
19-
from typing import Any
19+
from typing import Any, Literal, TypeAlias
2020

2121
from ase import Atom
2222
from huggingface_hub import hf_hub_download
2323
from mlip.models import ForceField
2424
from pydantic import BaseModel
2525

2626
from mlipaudit.exceptions import ChemicalElementsMissingError
27+
from mlipaudit.run_mode import RunMode
28+
29+
RunModeAsString: TypeAlias = Literal["dev", "fast", "standard"]
2730

2831

2932
class BenchmarkResult(BaseModel):
@@ -75,7 +78,7 @@ def __init__(
7578
self,
7679
force_field: ForceField,
7780
data_input_dir: str | os.PathLike = "./data",
78-
fast_dev_run: bool = False,
81+
run_mode: RunMode | RunModeAsString = RunMode.STANDARD,
7982
) -> None:
8083
"""Initializes the benchmark.
8184
@@ -85,19 +88,26 @@ def __init__(
8588
"./data". If the subdirectory "{data_input_dir}/{benchmark_name}"
8689
exists, the benchmark expects the relevant data to be in there,
8790
otherwise it will download it from HuggingFace.
88-
fast_dev_run: Whether to do a fast developer run. Subclasses
89-
should ensure that when `True`, their benchmark runs in a
91+
run_mode: Whether to run the standard benchmark length, a faster version,
92+
or a very fast development version. Subclasses
93+
should ensure that when `RunMode.DEV`, their benchmark runs in a
9094
much shorter timeframe, by running on a reduced number of
91-
test cases, for instance.
95+
test cases, for instance. Implementing `RunMode.FAST` being different
96+
from `RunMode.STANDARD` is optional and only recommended for very
97+
long-running benchmarks. This argument can also be passed as a string
98+
"dev", "fast", or "standard".
9299
93100
Raises:
94101
ChemicalElementsMissingError: If initialization is attempted
95102
with a force field that cannot perform inference on the
96103
required elements.
97104
"""
105+
self.run_mode = run_mode
106+
if not isinstance(self.run_mode, RunMode):
107+
self.run_mode = RunMode(run_mode)
108+
98109
self.force_field = force_field
99110
self._handle_missing_element_types()
100-
self.fast_dev_run = fast_dev_run
101111
self.data_input_dir = Path(data_input_dir)
102112

103113
self.model_output: ModelOutput | None = None

src/mlipaudit/bond_length_distribution/bond_length_distribution.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pydantic import BaseModel, ConfigDict, TypeAdapter
2323

2424
from mlipaudit.benchmark import Benchmark, BenchmarkResult, ModelOutput
25+
from mlipaudit.run_mode import RunMode
2526

2627
logger = logging.getLogger("mlipaudit")
2728

@@ -152,7 +153,7 @@ def run_model(self) -> None:
152153
"""
153154
molecule_outputs = []
154155

155-
if self.fast_dev_run:
156+
if self.run_mode == RunMode.DEV:
156157
md_config = JaxMDSimulationEngine.Config(**SIMULATION_CONFIG_FAST)
157158
else:
158159
md_config = JaxMDSimulationEngine.Config(**SIMULATION_CONFIG)
@@ -231,7 +232,7 @@ def _bond_length_distribution_data(self) -> dict[str, Molecule]:
231232
) as f:
232233
dataset = Molecules.validate_json(f.read())
233234

234-
if self.fast_dev_run:
235+
if self.run_mode == RunMode.DEV:
235236
dataset = dict(list(dataset.items())[:2])
236237

237238
return dataset

src/mlipaudit/conformer_selection/conformer_selection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from sklearn.metrics import mean_absolute_error, root_mean_squared_error
2525

2626
from mlipaudit.benchmark import Benchmark, BenchmarkResult, ModelOutput
27+
from mlipaudit.run_mode import RunMode
2728

2829
logger = logging.getLogger("mlipaudit")
2930

@@ -255,7 +256,7 @@ def _wiggle150_data(self) -> list[Conformer]:
255256
) as f:
256257
wiggle150_data = Conformers.validate_json(f.read())
257258

258-
if self.fast_dev_run:
259+
if self.run_mode == RunMode.DEV:
259260
wiggle150_data = wiggle150_data[:1]
260261

261262
return wiggle150_data

src/mlipaudit/dihedral_scan/dihedral_scan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sklearn.metrics import mean_absolute_error, root_mean_squared_error
2626

2727
from mlipaudit.benchmark import Benchmark, BenchmarkResult, ModelOutput
28+
from mlipaudit.run_mode import RunMode
2829

2930
logger = logging.getLogger("mlipaudit")
3031

@@ -287,7 +288,7 @@ def _torsion_net_500(self) -> dict[str, Fragment]:
287288
) as f:
288289
dataset = Fragments.validate_json(f.read())
289290

290-
if self.fast_dev_run:
291+
if self.run_mode == RunMode.DEV:
291292
dataset = {
292293
"fragment_001": dataset["fragment_001"],
293294
"fragment_002": dataset["fragment_002"],

src/mlipaudit/folding_stability/folding_stability.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
compute_tm_scores_and_rmsd_values,
2828
get_match_secondary_structure,
2929
)
30+
from mlipaudit.run_mode import RunMode
3031
from mlipaudit.utils import (
3132
create_ase_trajectory_from_simulation_state,
3233
create_mdtraj_trajectory_from_simulation_state,
@@ -161,8 +162,10 @@ def run_model(self) -> None:
161162
simulation_states=[],
162163
)
163164

164-
structure_names = STRUCTURE_NAMES[:1] if self.fast_dev_run else STRUCTURE_NAMES
165-
if self.fast_dev_run:
165+
structure_names = (
166+
STRUCTURE_NAMES[:1] if self.run_mode == RunMode.DEV else STRUCTURE_NAMES
167+
)
168+
if self.run_mode == RunMode.DEV:
166169
md_config = JaxMDSimulationEngine.Config(**SIMULATION_CONFIG_FAST)
167170
else:
168171
md_config = JaxMDSimulationEngine.Config(**SIMULATION_CONFIG)

src/mlipaudit/main.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from mlipaudit.noncovalent_interactions import NoncovalentInteractionsBenchmark
3030
from mlipaudit.reactivity import ReactivityBenchmark
3131
from mlipaudit.ring_planarity import RingPlanarityBenchmark
32+
from mlipaudit.run_mode import RunMode
3233
from mlipaudit.sampling import SamplingBenchmark
3334
from mlipaudit.scaling import ScalingBenchmark
3435
from mlipaudit.small_molecule_minimization import SmallMoleculeMinimizationBenchmark
@@ -89,14 +90,15 @@ def _parser() -> ArgumentParser:
8990
help="List of benchmarks to run.",
9091
)
9192
parser.add_argument(
92-
"--fast-dev-run",
93-
action="store_true",
94-
help="run the benchmarks in fast-dev-run mode",
93+
"--run-mode",
94+
required=False,
95+
choices=[mode.value for mode in RunMode],
96+
default=RunMode.STANDARD.value,
97+
help="mode to run the benchmarks in",
9598
)
9699
return parser
97100

98101

99-
# TODO: We should probably handle this in a different (nicer) way
100102
def _model_class_from_name(model_name: str) -> type[MLIPNetwork]:
101103
if "visnet" in model_name:
102104
return Visnet
@@ -170,7 +172,7 @@ def main():
170172
benchmark = benchmark_class(
171173
force_field=force_field,
172174
data_input_dir=args.input,
173-
fast_dev_run=args.fast_dev_run,
175+
run_mode=args.run_mode,
174176
)
175177
benchmark.run_model()
176178
result = benchmark.analyze()

0 commit comments

Comments
 (0)