Skip to content

Commit 0fb73a9

Browse files
orionarcher and janosh
authored
Add trajectory reporter to openmm workflow (#1053)
* Add trajectory reporter to openmm workflow
* respond to janosh review
* fix test
* slightly stricter asserts in test_trajectory_reporter
---------
Co-authored-by: Janosh Riebesell <[email protected]>
1 parent 071d1c8 commit 0fb73a9

File tree

4 files changed

+285
-29
lines changed

4 files changed

+285
-29
lines changed

src/atomate2/openmm/jobs/base.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
from openmm.unit import angstrom, kelvin, picoseconds
2727
from pymatgen.core import Structure
2828

29-
from atomate2.openmm.utils import increment_name, task_reports
29+
from atomate2.openmm.utils import (
30+
PymatgenTrajectoryReporter,
31+
increment_name,
32+
task_reports,
33+
)
3034

3135
if TYPE_CHECKING:
3236
from collections.abc import Callable
@@ -232,16 +236,18 @@ def make(
232236

233237
structure = self._create_structure(sim, prev_task)
234238

235-
task_doc = self._create_task_doc(
236-
interchange, structure, elapsed_time, dir_name, prev_task
237-
)
238-
239239
# leaving the MDAReporter makes the builders fail
240240
for _ in range(len(sim.reporters)):
241241
reporter = sim.reporters.pop()
242+
if hasattr(reporter, "save"):
243+
reporter.save()
242244
del reporter
243245
del sim
244246

247+
task_doc = self._create_task_doc(
248+
interchange, structure, elapsed_time, dir_name, prev_task
249+
)
250+
245251
# write out task_doc json to output dir
246252
with open(dir_name / "taskdoc.json", "w") as file:
247253
json.dump(task_doc.model_dump(), file, cls=MontyEncoder)
@@ -308,7 +314,7 @@ def _add_reporters(
308314
if has_steps & (traj_interval > 0):
309315
writer_kwargs = {}
310316
# these are the only file types that support velocities
311-
if traj_file_type in ["h5md", "nc", "ncdf"]:
317+
if traj_file_type in ("h5md", "nc", "ncdf", "json"):
312318
writer_kwargs["velocities"] = report_velocities
313319
writer_kwargs["forces"] = False
314320
elif report_velocities and traj_file_type != "trr":
@@ -330,17 +336,20 @@ def _add_reporters(
330336
reportInterval=traj_interval,
331337
enforcePeriodicBox=wrap_traj,
332338
)
333-
if report_velocities:
334-
# assert package version
335-
336-
kwargs["writer_kwargs"] = writer_kwargs
337-
warnings.warn(
338-
"Reporting velocities is only supported with the"
339-
"development version of MDAnalysis, >= 2.8.0, "
340-
"proceed with caution.",
341-
stacklevel=1,
342-
)
343-
traj_reporter = MDAReporter(**kwargs)
339+
if traj_file_type == "json":
340+
traj_reporter = PymatgenTrajectoryReporter(**kwargs)
341+
else:
342+
if report_velocities:
343+
# assert package version
344+
345+
kwargs["writer_kwargs"] = writer_kwargs
346+
warnings.warn(
347+
"Reporting velocities is only supported with the"
348+
"development version of MDAnalysis, >= 2.8.0, "
349+
"proceed with caution.",
350+
stacklevel=1,
351+
)
352+
traj_reporter = MDAReporter(**kwargs)
344353

345354
sim.reporters.append(traj_reporter)
346355

src/atomate2/openmm/utils.py

Lines changed: 146 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
import numpy as np
1414
import openmm.unit as omm_unit
1515
from emmet.core.openmm import OpenMMInterchange
16-
from openmm import LangevinMiddleIntegrator, XmlSerializer
17-
from openmm.app import PDBFile
16+
from openmm import LangevinMiddleIntegrator, State, XmlSerializer
17+
from openmm.app import PDBFile, Simulation
18+
from pymatgen.core.trajectory import Trajectory
1819

1920
if TYPE_CHECKING:
2021
from emmet.core.openmm import OpenMMTaskDocument
@@ -73,7 +74,7 @@ def download_opls_xml(
7374
submit_button.click()
7475

7576
# Wait for the second page to load
76-
# time.sleep(2) # Adjust this delay as needed based on the loading time
77+
# time.sleep(2) # Adjust this delay as needed based on loading time
7778

7879
# Find and click the "XML" button under Downloads and OpenMM
7980
xml_button = driver.find_element(
@@ -171,3 +172,145 @@ def openff_to_openmm_interchange(
171172
state=XmlSerializer.serialize(state),
172173
topology=pdb,
173174
)
175+
176+
177+
class PymatgenTrajectoryReporter:
    """Reporter that creates a pymatgen Trajectory from an OpenMM simulation.

    Positions, velocities, box vectors and energies are accumulated in memory on
    every call to ``report``. Calling ``save`` builds a Trajectory from the
    accumulated frames, stores it on ``self.trajectory`` and writes it to ``file``
    as JSON. Nothing is written until ``save`` is called.
    """

    def __init__(
        self,
        file: str | Path,
        reportInterval: int,  # noqa: N803
        enforcePeriodicBox: bool | None = None,  # noqa: N803
    ) -> None:
        """Initialize the reporter.

        Parameters
        ----------
        file : str | Path
            The file to write the trajectory to
        reportInterval : int
            The interval (in time steps) at which to save frames
        enforcePeriodicBox : bool | None
            Whether to wrap coordinates to the periodic box. The value is handed
            back unchanged as the last element of ``describeNextReport``.
        """
        self._file = file
        self._reportInterval = reportInterval
        self._enforcePeriodicBox = enforcePeriodicBox
        self._topology = None
        self._nextModel = 0

        # Storage for trajectory data, one entry per reported frame
        self._positions: list[np.ndarray] = []
        self._velocities: list[np.ndarray] = []
        self._lattices: list[np.ndarray] = []
        self._frame_properties: list[dict] = []
        self._species: list[str] | None = None
        self._time_step: float | None = None

        # Set by save(); defined up front so that reading it before a save (or
        # after a save with no frames) yields None instead of AttributeError.
        self.trajectory: Trajectory | None = None

    def describeNextReport(  # noqa: N802
        self, simulation: Simulation
    ) -> tuple[int, bool, bool, bool, bool, bool | None]:
        """Get information about the next report this object will generate.

        Parameters
        ----------
        simulation : Simulation
            The Simulation to generate a report for

        Returns
        -------
        tuple[int, bool, bool, bool, bool, bool | None]
            A six element tuple. The first element is the number of steps until the
            next report. The next four specify whether that report will require
            positions, velocities, forces and energies (here: positions, velocities
            and energies, but not forces). The last element is the
            ``enforcePeriodicBox`` value given at construction, which may be None.
        """
        steps = self._reportInterval - simulation.currentStep % self._reportInterval
        return steps, True, True, False, True, self._enforcePeriodicBox

    def report(self, simulation: Simulation, state: State) -> None:
        """Record one frame from the current simulation state.

        Parameters
        ----------
        simulation : Simulation
            The Simulation to generate a report for
        state : State
            The current state of the simulation
        """
        if self._nextModel == 0:
            # Topology-derived data is captured once, on the first frame
            self._topology = simulation.topology
            self._species = [
                atom.element.symbol for atom in simulation.topology.atoms()
            ]
            # Time between stored frames, in femtoseconds
            self._time_step = (
                simulation.integrator.getStepSize() * self._reportInterval
            ).value_in_unit(omm_unit.femtoseconds)

        # Get positions and velocities in Angstrom and Angstrom/fs
        positions = state.getPositions(asNumpy=True).value_in_unit(omm_unit.angstrom)
        velocities = state.getVelocities(asNumpy=True).value_in_unit(
            omm_unit.angstrom / omm_unit.femtosecond
        )
        box_vectors = state.getPeriodicBoxVectors(asNumpy=True).value_in_unit(
            omm_unit.angstrom
        )

        # Get energies in eV; dividing by Avogadro's constant converts the
        # per-mole quantity into a per-system one before the unit conversion
        kinetic_energy = (
            state.getKineticEnergy() / omm_unit.AVOGADRO_CONSTANT_NA
        ).value_in_unit(omm_unit.ev)

        potential_energy = (
            state.getPotentialEnergy() / omm_unit.AVOGADRO_CONSTANT_NA
        ).value_in_unit(omm_unit.ev)

        self._positions.append(positions)
        self._velocities.append(velocities)
        self._lattices.append(box_vectors)
        self._frame_properties.append(
            {
                "kinetic_energy": kinetic_energy,
                "potential_energy": potential_energy,
                "total_energy": kinetic_energy + potential_energy,
            }
        )

        self._nextModel += 1

    def save(self) -> None:
        """Build a pymatgen Trajectory from the recorded frames and write it out.

        Does nothing if no frames have been recorded. Otherwise the trajectory is
        stored on ``self.trajectory`` and its JSON form is written to the file
        given at construction.
        """
        if not self._positions:
            return

        # One site-property dict per frame, with velocities as per-site tuples
        site_properties = [
            {"velocities": [tuple(site_vel) for site_vel in frame_vel]}
            for frame_vel in self._velocities
        ]

        trajectory = Trajectory(
            species=self._species,
            coords=self._positions,
            lattice=self._lattices,
            frame_properties=self._frame_properties,
            site_properties=site_properties,
            time_step=self._time_step,
        )

        # Keep the trajectory on the instance so callers can use it after saving
        self.trajectory = trajectory

        # write out trajectory to a file
        with open(self._file, mode="w", encoding="utf-8") as file:
            file.write(trajectory.to_json())

tests/openmm_md/jobs/test_core.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from collections.abc import Callable
12
from pathlib import Path
23

34
import numpy as np
4-
from emmet.core.openmm import OpenMMInterchange
5+
from emmet.core.openmm import OpenMMInterchange, OpenMMTaskDocument
6+
from monty.serialization import loadfn
57
from openmm import XmlSerializer
68

79
from atomate2.openmm.jobs import (
@@ -12,7 +14,9 @@
1214
)
1315

1416

15-
def test_energy_minimization_maker(interchange, run_job):
17+
def test_energy_minimization_maker(
18+
interchange: OpenMMInterchange, run_job: Callable
19+
) -> None:
1620
state = XmlSerializer.deserialize(interchange.state)
1721
start_positions = state.getPositions(asNumpy=True)
1822

@@ -28,7 +32,7 @@ def test_energy_minimization_maker(interchange, run_job):
2832
assert (Path(task_doc.calcs_reversed[0].output.dir_name) / "state.csv").exists()
2933

3034

31-
def test_npt_maker(interchange, run_job):
35+
def test_npt_maker(interchange: OpenMMInterchange, run_job: Callable) -> None:
3236
state = XmlSerializer.deserialize(interchange.state)
3337
start_positions = state.getPositions(asNumpy=True)
3438
start_box = state.getPeriodicBoxVectors()
@@ -47,11 +51,11 @@ def test_npt_maker(interchange, run_job):
4751
assert not np.all(new_box == start_box)
4852

4953

50-
def test_nvt_maker(interchange, run_job):
54+
def test_nvt_maker(interchange: OpenMMInterchange, run_job: Callable) -> None:
5155
state = XmlSerializer.deserialize(interchange.state)
5256
start_positions = state.getPositions(asNumpy=True)
5357

54-
maker = NVTMaker(n_steps=10, state_interval=1)
58+
maker = NVTMaker(n_steps=10, state_interval=1, traj_interval=5)
5559
base_job = maker.make(interchange)
5660
task_doc = run_job(base_job)
5761

@@ -70,7 +74,7 @@ def test_nvt_maker(interchange, run_job):
7074
assert calc_output.steps_reported == list(range(1, 11))
7175

7276

73-
def test_temp_change_maker(interchange, run_job):
77+
def test_temp_change_maker(interchange: OpenMMInterchange, run_job: Callable):
7478
state = XmlSerializer.deserialize(interchange.state)
7579
start_positions = state.getPositions(asNumpy=True)
7680

@@ -88,3 +92,38 @@ def test_temp_change_maker(interchange, run_job):
8892
# test that temperature was updated correctly in the input
8993
assert task_doc.calcs_reversed[0].input.temperature == 310
9094
assert task_doc.calcs_reversed[0].input.starting_temperature == 298
95+
96+
97+
def test_trajectory_reporter_json(
    interchange: OpenMMInterchange, tmp_path: Path, run_job: Callable
) -> None:
    """Test that an NVT job with traj_file_type="json" writes a loadable trajectory."""
    # Create simulation using NVTMaker
    maker = NVTMaker(
        temperature=300,
        friction_coefficient=1.0,
        step_size=0.002,
        platform_name="CPU",
        traj_interval=1,
        n_steps=3,
        traj_file_type="json",
    )

    job = maker.make(interchange)
    task_doc = run_job(job)

    # Test serialization/deserialization of the task document
    json_str = task_doc.model_dump_json()
    new_doc = OpenMMTaskDocument.model_validate_json(json_str)

    # The trajectory file location must survive the round trip; the trajectory
    # itself is loaded from that file on disk
    calc_output = new_doc.calcs_reversed[0].output
    traj_file = Path(calc_output.dir_name) / calc_output.traj_file
    traj = loadfn(traj_file)

    # traj_interval=1 with n_steps=3 should give one frame per step
    assert len(traj) == 3
    assert traj.coords.max() < traj.lattice.max()
    assert "kinetic_energy" in traj.frame_properties[0]

    # Check that trajectory file was written
    assert (tmp_path / "trajectory.json").exists()

0 commit comments

Comments
 (0)