Skip to content

Commit 8e036b9

Browse files
authored
feat: use ASE for all minimizations (#73)
1 parent 3d3430a commit 8e036b9

File tree

3 files changed

+27
-20
lines changed

3 files changed

+27
-20
lines changed

src/mlipaudit/benchmarks/small_molecule_minimization/small_molecule_minimization.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,14 @@
5151
"simulation_type": "minimization",
5252
"num_steps": 1000,
5353
"snapshot_interval": 10,
54-
"num_episodes": 10,
55-
"timestep_fs": 0.1,
54+
"max_force_convergence_threshold": 0.01,
5655
}
5756

5857
SIMULATION_CONFIG_FAST = {
5958
"simulation_type": "minimization",
6059
"num_steps": 10,
6160
"snapshot_interval": 1,
62-
"num_episodes": 1,
63-
"timestep_fs": 0.1,
61+
"max_force_convergence_threshold": 0.01,
6462
}
6563

6664
RMSD_SCORE_THRESHOLD = 0.075

src/mlipaudit/utils/simulation.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727

2828
logger = logging.getLogger("mlipaudit")
2929

30-
DEFAULT_ASE_MAX_FORCE_CONV_THRESH = 0.01
31-
3230

3331
class ASESimulationEngineWithCalculator(ASESimulationEngine):
3432
"""Class derived from mlip's ASE simulation engine but allowing for a passed
@@ -73,9 +71,15 @@ def __init__(
7371

7472
def get_simulation_engine(
7573
atoms: ase.Atoms, force_field: ForceField | ASECalculator, **kwargs
76-
) -> JaxMDSimulationEngine | ASESimulationEngineWithCalculator:
74+
) -> JaxMDSimulationEngine | ASESimulationEngineWithCalculator | ASESimulationEngine:
7775
"""Returns the correct simulation engine based on the input force field type.
7876
77+
For MD simulations with `mlip.models.ForceField` objects, we return a
78+
`JaxMDSimulationEngine`. For energy minimizations with those objects, we return
79+
a `ASESimulationEngine`. For any type of simulations with ASE calculator objects,
80+
we return a `ASESimulationEngineWithCalculator`, which is a custom class of
81+
the `mlipaudit` library.
82+
7983
Args:
8084
atoms: The ASE atoms.
8185
force_field: The force field, either an `mlip.models.ForceField`
@@ -89,21 +93,26 @@ def get_simulation_engine(
8993
Raises:
9094
ValueError: If force field type is not compatible.
9195
"""
92-
if isinstance(force_field, ForceField):
96+
# Case 1: MD simulations with ForceField objects -> use JAX-MD
97+
if (
98+
isinstance(force_field, ForceField)
99+
and kwargs.get("simulation_type", "md") == "md"
100+
):
93101
md_config = JaxMDSimulationEngine.Config(**kwargs)
94102
return JaxMDSimulationEngine(atoms, force_field, md_config)
95103

96-
elif isinstance(force_field, ASECalculator):
97-
kwargs_copy = deepcopy(kwargs)
98-
kwargs_copy.pop("num_episodes", None) # remove this if exists
104+
kwargs_copy = deepcopy(kwargs)
105+
kwargs_copy.pop("num_episodes", None) # remove this if exists
99106

100-
# for minimization:
101-
kwargs_copy["max_force_convergence_threshold"] = (
102-
DEFAULT_ASE_MAX_FORCE_CONV_THRESH
103-
)
107+
# Case 2: Minimization with ForceField objects -> use ASE
108+
if isinstance(force_field, ForceField):
109+
minimization_config = ASESimulationEngine.Config(**kwargs_copy)
110+
return ASESimulationEngine(atoms, force_field, minimization_config)
104111

105-
md_config = ASESimulationEngine.Config(**kwargs_copy)
106-
return ASESimulationEngineWithCalculator(atoms, force_field, md_config)
112+
# Case 3: MD or minimization with ASECalculator objects -> use ASE
113+
if isinstance(force_field, ASECalculator):
114+
sim_config = ASESimulationEngine.Config(**kwargs_copy)
115+
return ASESimulationEngineWithCalculator(atoms, force_field, sim_config)
107116

108117
raise ValueError(
109118
"Provided force field must be either a mlip-compatible "

tests/small_molecule_minimization/test_minimization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,13 @@ def _gen_sim_output(mol_name: str, num_atoms: int) -> MoleculeSimulationOutput:
7373
"small_mol_minimization_benchmark", [True, False], indirect=True
7474
)
7575
def test_full_run_with_mocked_engine(
76-
small_mol_minimization_benchmark, mock_jaxmd_simulation_engine
76+
small_mol_minimization_benchmark, mock_ase_simulation_engine
7777
):
7878
"""Integration test testing a full run of the benchmark."""
7979
benchmark = small_mol_minimization_benchmark
80-
mock_engine = mock_jaxmd_simulation_engine()
80+
mock_engine = mock_ase_simulation_engine()
8181
with patch(
82-
"mlipaudit.utils.simulation.JaxMDSimulationEngine",
82+
"mlipaudit.utils.simulation.ASESimulationEngine",
8383
return_value=mock_engine,
8484
) as mock_engine_class:
8585
benchmark.run_model()

0 commit comments

Comments
 (0)