Skip to content

Commit fb9a6e8

Browse files
authored
Allow energy minimization maker to report energies (#1004)
* Make energy minimization reporter report a state file when it runs. This allows us to see energies of minimized configuration. * Only report energy minimization state if state_interval > 0
1 parent 3c0be95 commit fb9a6e8

File tree

6 files changed

+50
-12
lines changed

6 files changed

+50
-12
lines changed

src/atomate2/openmm/jobs/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def make(
225225

226226
# Run the simulation
227227
start = time.time()
228-
self.run_openmm(sim)
228+
self.run_openmm(sim, dir_name)
229229
elapsed_time = time.time() - start
230230

231231
self._update_interchange(interchange, sim, prev_task)
@@ -303,6 +303,7 @@ def _add_reporters(
303303
traj_file_name = self._resolve_attr("traj_file_name", prev_task)
304304
traj_file_type = self._resolve_attr("traj_file_type", prev_task)
305305
report_velocities = self._resolve_attr("report_velocities", prev_task)
306+
wrap_traj = self._resolve_attr("wrap_traj", prev_task)
306307

307308
if has_steps & (traj_interval > 0):
308309
writer_kwargs = {}
@@ -327,7 +328,7 @@ def _add_reporters(
327328
kwargs = dict(
328329
file=str(dir_name / f"{self.traj_file_name}.{traj_file_type}"),
329330
reportInterval=traj_interval,
330-
enforcePeriodicBox=self._resolve_attr("wrap_traj", prev_task),
331+
enforcePeriodicBox=wrap_traj,
331332
)
332333
if report_velocities:
333334
# assert package version
@@ -364,7 +365,7 @@ def _add_reporters(
364365
)
365366
sim.reporters.append(state_reporter)
366367

367-
def run_openmm(self, simulation: Simulation) -> NoReturn:
368+
def run_openmm(self, sim: Simulation, dir_name: Path) -> NoReturn:
368369
"""Abstract method for running the OpenMM simulation.
369370
370371
This method should be implemented by subclasses to

src/atomate2/openmm/jobs/core.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77

88
import numpy as np
99
from openmm import Integrator, LangevinMiddleIntegrator, MonteCarloBarostat
10+
from openmm.app import StateDataReporter
1011
from openmm.unit import atmosphere, kelvin, kilojoules_per_mole, nanometer, picoseconds
1112

1213
from atomate2.openmm.jobs.base import BaseOpenMMMaker
1314
from atomate2.openmm.utils import create_list_summing_to
1415

1516
if TYPE_CHECKING:
17+
from pathlib import Path
18+
1619
from emmet.core.openmm import OpenMMTaskDocument
1720
from openmm.app import Simulation
1821

@@ -41,7 +44,7 @@ class EnergyMinimizationMaker(BaseOpenMMMaker):
4144
tolerance: float = 10
4245
max_iterations: int = 0
4346

44-
def run_openmm(self, sim: Simulation) -> None:
47+
def run_openmm(self, sim: Simulation, dir_name: Path) -> None:
4548
"""Run the energy minimization with OpenMM.
4649
4750
This method performs energy minimization on the molecular system using
@@ -62,6 +65,28 @@ def run_openmm(self, sim: Simulation) -> None:
6265
maxIterations=self.max_iterations,
6366
)
6467

68+
if self.state_interval > 0:
69+
state = sim.context.getState(
70+
getPositions=True,
71+
getVelocities=True,
72+
getForces=True,
73+
getEnergy=True,
74+
enforcePeriodicBox=self.wrap_traj,
75+
)
76+
77+
state_reporter = StateDataReporter(
78+
file=f"{dir_name / self.state_file_name}.csv",
79+
reportInterval=0,
80+
step=True,
81+
potentialEnergy=True,
82+
kineticEnergy=True,
83+
totalEnergy=True,
84+
temperature=True,
85+
volume=True,
86+
density=True,
87+
)
88+
state_reporter.report(sim, state)
89+
6590

6691
@dataclass
6792
class NPTMaker(BaseOpenMMMaker):
@@ -87,7 +112,7 @@ class NPTMaker(BaseOpenMMMaker):
87112
pressure: float = 1
88113
pressure_update_frequency: int = 10
89114

90-
def run_openmm(self, sim: Simulation) -> None:
115+
def run_openmm(self, sim: Simulation, dir_name: Path) -> None:
91116
"""Evolve the simulation for self.n_steps in the NPT ensemble.
92117
93118
This adds a Monte Carlo barostat to the system to put it into NPT, runs the
@@ -138,7 +163,7 @@ class NVTMaker(BaseOpenMMMaker):
138163
name: str = "nvt simulation"
139164
n_steps: int = 1_000_000
140165

141-
def run_openmm(self, sim: Simulation) -> None:
166+
def run_openmm(self, sim: Simulation, dir_name: Path) -> None:
142167
"""Evolve the simulation with OpenMM for self.n_steps.
143168
144169
Parameters
@@ -177,7 +202,7 @@ class TempChangeMaker(BaseOpenMMMaker):
177202
temp_steps: int | None = None
178203
starting_temperature: float | None = None
179204

180-
def run_openmm(self, sim: Simulation) -> None:
205+
def run_openmm(self, sim: Simulation, dir_name: Path) -> None:
181206
"""Evolve the simulation while gradually changing the temperature.
182207
183208
self.temperature is the final temperature. self.temp_steps

tests/openmm_md/flows/test_core.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def test_flow_maker(interchange, run_job):
114114
name="test_production",
115115
tags=["test"],
116116
makers=[
117-
EnergyMinimizationMaker(max_iterations=1),
117+
EnergyMinimizationMaker(max_iterations=1, state_interval=1),
118118
NPTMaker(n_steps=5, pressure=1.0, state_interval=1, traj_interval=1),
119119
OpenMMFlowMaker.anneal_flow(anneal_temp=400, final_temp=300, n_steps=5),
120120
NVTMaker(n_steps=5),
@@ -157,6 +157,15 @@ def test_flow_maker(interchange, run_job):
157157
calc_output = task_doc.calcs_reversed[0].output
158158
assert len(calc_output.steps_reported) == 5
159159

160+
all_steps = [calc.output.steps_reported for calc in task_doc.calcs_reversed]
161+
assert all_steps == [
162+
[11, 12, 13, 14, 15],
163+
[10],
164+
[8, 9],
165+
[6, 7],
166+
[1, 2, 3, 4, 5],
167+
[0],
168+
]
160169
# Test that the state interval is respected
161170
assert calc_output.steps_reported == list(range(11, 16))
162171
assert calc_output.traj_file == "trajectory5.dcd"

tests/openmm_md/jobs/test_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def test_make(interchange, tmp_path, run_job):
133133

134134
# monkey patch to allow running the test without openmm
135135

136-
def do_nothing(self, sim):
136+
def do_nothing(self, sim, dir_name):
137137
pass
138138

139139
BaseOpenMMMaker.run_openmm = do_nothing
@@ -170,7 +170,7 @@ def do_nothing(self, sim):
170170

171171
def test_make_w_velocities(interchange, run_job):
172172
# monkey patch to allow running the test without openmm
173-
def do_nothing(self, sim):
173+
def do_nothing(self, sim, dir_name):
174174
pass
175175

176176
BaseOpenMMMaker.run_openmm = do_nothing
@@ -215,7 +215,7 @@ def test_make_from_prev(run_job):
215215
maker = BaseOpenMMMaker(n_steps=10)
216216

217217
# monkey patch to allow running the test without openmm
218-
def do_nothing(self, sim):
218+
def do_nothing(self, sim, dir_name):
219219
pass
220220

221221
BaseOpenMMMaker.run_openmm = do_nothing

tests/openmm_md/jobs/test_core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from pathlib import Path
2+
13
import numpy as np
24
from emmet.core.openmm import OpenMMInterchange
35
from openmm import XmlSerializer
@@ -23,6 +25,7 @@ def test_energy_minimization_maker(interchange, run_job):
2325
new_positions = new_state.getPositions(asNumpy=True)
2426

2527
assert not np.all(new_positions == start_positions)
28+
assert (Path(task_doc.calcs_reversed[0].output.dir_name) / "state.csv").exists()
2629

2730

2831
def test_npt_maker(interchange, run_job):

tests/openmm_md/jobs/test_generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_make_from_prev(openmm_data, run_job):
138138
maker = BaseOpenMMMaker(n_steps=10)
139139

140140
# monkey patch to allow running the test without openmm
141-
def do_nothing(self, sim):
141+
def do_nothing(self, sim, dir_name):
142142
pass
143143

144144
BaseOpenMMMaker.run_openmm = do_nothing

0 commit comments

Comments
 (0)