Skip to content

Commit 3d3430a

Browse files
feat: add neb benchmark (#57)
* feat: add option for ASE simulation engine to two benchmarks * feat: remove dummy calculator and implement option to load in external model in main.py * feat: adjust all benchmarks to work also with ASE * fix: actually return something from the run_inference() method * test: add unit tests for ase calculator pathway * test: fix bug in tests and add more asserts * docs: update docs with new ase calc pathway * fix: tiny updates related to PR comments * feat: neb engine * feat: run_model() * feat: benchmark implementation * feat: ui * docs: neb * address comments * chore: restore main * chore restore main * add neb to init * fix: simple test * test analyze * linters * linters * chore: linters * feat: ui wrapper * docs: address comments * feat: address comments in benchmark * feat: change inherinance * feat: rename engine file --------- Co-authored-by: Christoph Brunken <[email protected]>
1 parent 325423d commit 3d3430a

File tree

11 files changed

+1438
-0
lines changed

11 files changed

+1438
-0
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
.. _nudged_elastic_band_api:
2+
3+
Nudged Elastic Band
4+
===================
5+
6+
.. module:: mlipaudit.benchmarks.nudged_elastic_band.nudged_elastic_band
7+
8+
.. autoclass:: NudgedElasticBandBenchmark
9+
10+
.. automethod:: __init__
11+
12+
.. automethod:: run_model
13+
14+
.. automethod:: analyze
15+
16+
.. autoclass:: NEBResult
17+
18+
.. autoclass:: NEBModelOutput

docs/source/benchmarks/small_molecules/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ sampling, stability, and interactions with other molecules.
1818
Minimization <minimization>
1919
Bond length distribution <bond_length_distribution>
2020
Reactivity <reactivity>
21+
Nudged elastic band <nudged_elastic_band>
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
.. _nudged_elastic_band:
2+
3+
Nudged Elastic Band
4+
===================
5+
6+
Purpose
7+
-------
8+
9+
The nudged elastic band (NEB) is a method to relax a mean energy path between
10+
a reactant and a product structure and thereby find a good guess for the
11+
transition state of the reaction between these two structures. This benchmark assesses
12+
the **MLIP**'s ability to converge NEB calculations where the transition state is already known,
13+
meaning it is a stability benchmark tailored to the NEB method.
14+
15+
16+
Description
17+
-----------
18+
19+
This benchmark uses a custom simulation engine, based on the `ASESimulationEngine` from the `mlip <https://github.com/instadeepai/mlip>`_ library
20+
to run NEB calculations. Before running the NEB calculations, the structures of reactants and products are energy minimized using
21+
the **MLIP** and the **BFGS** optimizer with `alpha=70` and `maxstep=0.03`.
22+
Subsequently, an initial guess for the mean energy path is constructed using the Image Dependent Pair Potential (IDPP),
23+
placing the known transition state structure in the middle with 10 images between the reactant and product structures.
24+
The path is then relaxed using two NEB runs. The first run is a standard NEB calculation with a force convergence threshold of 0.5 eV/Å.
25+
The second run is a NEB calculation with the climbing image method, with a force convergence threshold of 0.05 eV/Å. Both NEB calculations are run for a maximum of 500 steps.
26+
The technical specifications are chosen to resemble those used in the generation of the **Transition1X** \ [#f2]_ dataset.
27+
28+
Dataset
29+
-------
30+
31+
The dataset used for this benchmark is are 100 reactions sampled from the **Grambow** \ [#f1]_ dataset which contains
32+
the reactants, products and transition states of 11960 reactions.
33+
34+
Interpretation
35+
--------------
36+
37+
This benchmarks tests the ability of the model to converge the NEB calculations. The higher the convergence rate, the better.
38+
39+
References
40+
----------
41+
42+
.. [#f1] C. A. Grambow, L. Pattanaik, W. H. Green, Scientific Data 2020. DOI: https://doi.org/10.1038/s41597-020-0460-4
43+
.. [#f2] M. Schreiner, A. Bhowmik, T. Vegge, J. Busk, Ole Winther, Scientific Data 2022. DOI: https://doi.org/10.1038/s41597-022-01870-w.

src/mlipaudit/benchmarks/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@
3838
NoncovalentInteractionsModelOutput,
3939
NoncovalentInteractionsResult,
4040
)
41+
from mlipaudit.benchmarks.nudged_elastic_band.nudged_elastic_band import (
42+
NEBModelOutput,
43+
NEBResult,
44+
NudgedElasticBandBenchmark,
45+
)
4146
from mlipaudit.benchmarks.reactivity.reactivity import (
4247
ReactivityBenchmark,
4348
ReactivityModelOutput,
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
# Copyright 2025 InstaDeep Ltd
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
import time
17+
from typing import Callable
18+
19+
import ase
20+
import numpy as np
21+
from ase.calculators.calculator import Calculator as ASECalculator
22+
from ase.mep import NEB
23+
from ase.optimize import BFGS
24+
from mlip.models import ForceField
25+
from mlip.simulation import SimulationState
26+
from mlip.simulation.ase.mlip_ase_calculator import MLIPForceFieldASECalculator
27+
from mlip.simulation.configs import ASESimulationConfig
28+
from mlip.simulation.simulation_engine import SimulationEngine
29+
30+
logger = logging.getLogger("mlip")
31+
32+
33+
class NEBSimulationConfig(ASESimulationConfig):
34+
"""Configuration for the NEB simulations.
35+
Also includes the attributes of the parent class
36+
:ASESimulationConfig.
37+
"""
38+
39+
simulation_type: str = "neb"
40+
num_images: int = 7
41+
neb_k: float | None = 10.0
42+
max_force_convergence_threshold: float | None = 0.1
43+
continue_from_previous_run: bool = False
44+
climb: bool = False
45+
46+
47+
class NEBSimulationEngine(SimulationEngine):
48+
"""Simulation engine handling NEB simulations with the ASE backend."""
49+
50+
Config = NEBSimulationConfig
51+
52+
def __init__(
53+
self,
54+
atoms_initial: ase.Atoms,
55+
atoms_final: ase.Atoms,
56+
force_field: ForceField | ASECalculator,
57+
config: NEBSimulationConfig,
58+
images: list[ase.Atoms] | None = None,
59+
transition_state: ase.Atoms | None = None,
60+
) -> None:
61+
"""Initialize the NEB simulation engine."""
62+
self._initialize(
63+
atoms_initial,
64+
atoms_final,
65+
force_field,
66+
config,
67+
images,
68+
transition_state,
69+
)
70+
71+
def _initialize(
72+
self,
73+
atoms_initial: ase.Atoms,
74+
atoms_final: ase.Atoms,
75+
force_field: ForceField | ASECalculator,
76+
config: NEBSimulationConfig,
77+
images: list[ase.Atoms] | None = None,
78+
transition_state: ase.Atoms | None = None,
79+
) -> None:
80+
"""Initialize the NEB simulation."""
81+
self.state = SimulationState()
82+
self.loggers: list[Callable[[SimulationState], None]] = []
83+
84+
self._config = config
85+
self.atoms = atoms_initial
86+
positions = atoms_initial.get_positions()
87+
self._num_atoms = positions.shape[0]
88+
self.state.atomic_numbers = atoms_initial.numbers
89+
self.force_field = force_field
90+
91+
self.model_calculator = self._get_model_calculator()
92+
93+
self.atoms_final = atoms_final
94+
95+
self._init_box_neb(self.atoms)
96+
self._init_box_neb(self.atoms_final)
97+
98+
self.atoms.calc = self._get_model_calculator()
99+
self.atoms_final.calc = self._get_model_calculator()
100+
101+
self.neb = NEB([])
102+
self.images = images
103+
self.transition_state = transition_state
104+
105+
def run(self) -> None:
106+
"""Run the NEB simulation.
107+
108+
Raises:
109+
ValueError: If continue_from_previous_run is True
110+
and images are not provided.
111+
"""
112+
if not self._config.continue_from_previous_run:
113+
self._init_neb()
114+
else:
115+
if not self.images:
116+
raise ValueError(
117+
"Images must be provided if continue_from_previous_run is True"
118+
)
119+
120+
for image in self.images:
121+
image.calc = self._get_model_calculator()
122+
123+
self.neb = NEB(
124+
self.images,
125+
k=self._config.neb_k,
126+
climb=self._config.climb,
127+
parallel=True,
128+
)
129+
130+
dyn = BFGS(self.neb, alpha=70, maxstep=0.03)
131+
132+
def log_to_console() -> None:
133+
"""Logs info to console."""
134+
step = dyn.get_number_of_steps()
135+
compute_time = time.perf_counter() - self.self_start_interval_time
136+
self._log_to_console(step, compute_time)
137+
138+
def set_beginning_interval_time() -> None:
139+
self.self_start_interval_time = time.perf_counter()
140+
141+
def update_state() -> None:
142+
"""Update the internal SimulationState object."""
143+
step = dyn.get_number_of_steps()
144+
compute_time = time.perf_counter() - self.self_start_interval_time
145+
self._update_state_neb(step, compute_time)
146+
147+
dyn.attach(log_to_console, interval=self._config.log_interval)
148+
dyn.attach(self._call_loggers, interval=self._config.log_interval)
149+
dyn.attach(update_state, interval=self._config.snapshot_interval)
150+
dyn.attach(set_beginning_interval_time, interval=self._config.log_interval)
151+
self.self_start_interval_time = time.perf_counter()
152+
153+
dyn.run(
154+
steps=self._config.num_steps,
155+
fmax=self._config.max_force_convergence_threshold,
156+
)
157+
158+
def _init_neb(self) -> None:
159+
if not self.transition_state:
160+
num_images = max(self._config.num_images, 2)
161+
images = [self.atoms]
162+
images.extend([self.atoms.copy() for _ in range(num_images - 2)])
163+
images.append(self.atoms_final)
164+
else:
165+
num_images = max(self._config.num_images, 3)
166+
num_images_1 = num_images // 2 + 1
167+
num_images_2 = num_images - num_images_1 + 1
168+
169+
images_1 = [self.atoms]
170+
images_1.extend([self.atoms.copy() for _ in range(num_images_1 - 2)])
171+
images_1.append(self.transition_state)
172+
173+
images_2 = [self.transition_state.copy()]
174+
images_2.extend([self.atoms_final.copy() for _ in range(num_images_2 - 2)])
175+
images_2.append(self.atoms_final)
176+
177+
for image in images_1:
178+
image.calc = self._get_model_calculator()
179+
for image in images_2:
180+
image.calc = self._get_model_calculator()
181+
182+
neb1 = NEB(
183+
images_1, k=self._config.neb_k, climb=self._config.climb, parallel=True
184+
)
185+
neb2 = NEB(
186+
images_2, k=self._config.neb_k, climb=self._config.climb, parallel=True
187+
)
188+
189+
neb1.interpolate(method="idpp")
190+
neb2.interpolate(method="idpp")
191+
192+
images = neb1.images + neb2.images[1:]
193+
194+
for image in images:
195+
image.calc = self._get_model_calculator()
196+
197+
self.neb = NEB(
198+
images, k=self._config.neb_k, climb=self._config.climb, parallel=True
199+
)
200+
201+
if not self.transition_state:
202+
self.neb.interpolate(method="idpp")
203+
204+
def _init_box_neb(self, atoms: ase.Atoms) -> None:
205+
if isinstance(self._config.box, float):
206+
atoms.cell = np.eye(3) * self._config.box
207+
atoms.pbc = True
208+
elif isinstance(self._config.box, list):
209+
atoms.cell = np.diag(np.array(self._config.box))
210+
atoms.pbc = True
211+
else:
212+
atoms.cell = None
213+
atoms.pbc = False
214+
215+
def _update_state_neb(
216+
self,
217+
step: int,
218+
compute_time: float,
219+
) -> None:
220+
"""Update the internal state of the simulation.
221+
Here, the positions, forces and potential energy for every image
222+
are updated and not concatenated, as for the MD simulations and energy
223+
minimizations.
224+
225+
Args:
226+
step: The current step.
227+
compute_time: The compute time.
228+
"""
229+
self.state.positions = np.zeros((
230+
len(self.neb.images),
231+
len(self.neb.images[0].positions),
232+
3,
233+
))
234+
self.state.potential_energy = np.zeros(len(self.neb.images))
235+
236+
for i, image in enumerate(self.neb.images):
237+
self.state.positions[i] = image.positions
238+
self.state.potential_energy[i] = image.get_potential_energy()
239+
240+
self.state.forces = self.neb.get_forces()
241+
242+
self.state.step = step
243+
self.state.compute_time_seconds += compute_time
244+
245+
def _get_model_calculator(self) -> MLIPForceFieldASECalculator | ASECalculator:
246+
if isinstance(self.force_field, ForceField):
247+
return MLIPForceFieldASECalculator(
248+
self.atoms,
249+
self._config.edge_capacity_multiplier,
250+
self.force_field,
251+
)
252+
else:
253+
return self.force_field
254+
255+
def _call_loggers(self) -> None:
256+
for _logger in self.loggers:
257+
_logger(self.state)
258+
259+
def _log_to_console(self, step: int, compute_time: float) -> None:
260+
"""Logs timing information to console via our logger."""
261+
if step == 0:
262+
logger.debug(
263+
"Initialization took %.2f seconds.",
264+
compute_time,
265+
)
266+
else:
267+
logger.info(
268+
"Steps %s to %s completed in %.2f seconds.",
269+
self.state.step,
270+
step,
271+
compute_time,
272+
)

0 commit comments

Comments
 (0)