|
13 | 13 | import numpy as np
|
14 | 14 | import openmm.unit as omm_unit
|
15 | 15 | from emmet.core.openmm import OpenMMInterchange
|
16 |
| -from openmm import LangevinMiddleIntegrator, XmlSerializer |
17 |
| -from openmm.app import PDBFile |
| 16 | +from openmm import LangevinMiddleIntegrator, State, XmlSerializer |
| 17 | +from openmm.app import PDBFile, Simulation |
| 18 | +from pymatgen.core.trajectory import Trajectory |
18 | 19 |
|
19 | 20 | if TYPE_CHECKING:
|
20 | 21 | from emmet.core.openmm import OpenMMTaskDocument
|
@@ -73,7 +74,7 @@ def download_opls_xml(
|
73 | 74 | submit_button.click()
|
74 | 75 |
|
75 | 76 | # Wait for the second page to load
|
76 |
| - # time.sleep(2) # Adjust this delay as needed based on the loading time |
| 77 | + # time.sleep(2) # Adjust this delay as needed based on loading time |
77 | 78 |
|
78 | 79 | # Find and click the "XML" button under Downloads and OpenMM
|
79 | 80 | xml_button = driver.find_element(
|
@@ -171,3 +172,145 @@ def openff_to_openmm_interchange(
|
171 | 172 | state=XmlSerializer.serialize(state),
|
172 | 173 | topology=pdb,
|
173 | 174 | )
|
| 175 | + |
| 176 | + |
| 177 | +class PymatgenTrajectoryReporter: |
| 178 | + """Reporter that creates a pymatgen Trajectory from an OpenMM simulation. |
| 179 | +
|
| 180 | + Accumulates structures and velocities during the simulation and writes them to a |
| 181 | + Trajectory object when the reporter is deleted. |
| 182 | + """ |
| 183 | + |
| 184 | + def __init__( |
| 185 | + self, |
| 186 | + file: str | Path, |
| 187 | + reportInterval: int, # noqa: N803 |
| 188 | + enforcePeriodicBox: bool | None = None, # noqa: N803 |
| 189 | + ) -> None: |
| 190 | + """Initialize the reporter. |
| 191 | +
|
| 192 | + Parameters |
| 193 | + ---------- |
| 194 | + file : str | Path |
| 195 | + The file to write the trajectory to |
| 196 | + reportInterval : int |
| 197 | + The interval (in time steps) at which to save frames |
| 198 | + enforcePeriodicBox : bool | None |
| 199 | + Whether to wrap coordinates to the periodic box. If None, determined from |
| 200 | + simulation settings. |
| 201 | + """ |
| 202 | + self._file = file |
| 203 | + self._reportInterval = reportInterval |
| 204 | + self._enforcePeriodicBox = enforcePeriodicBox |
| 205 | + self._topology = None |
| 206 | + self._nextModel = 0 |
| 207 | + |
| 208 | + # Storage for trajectory data |
| 209 | + self._positions: list[np.ndarray] = [] |
| 210 | + self._velocities: list[np.ndarray] = [] |
| 211 | + self._lattices: list[np.ndarray] = [] |
| 212 | + self._frame_properties: list[dict] = [] |
| 213 | + self._species: list[str] | None = None |
| 214 | + self._time_step: float | None = None |
| 215 | + |
| 216 | + def describeNextReport( # noqa: N802 |
| 217 | + self, simulation: Simulation |
| 218 | + ) -> tuple[int, bool, bool, bool, bool, bool]: |
| 219 | + """Get information about the next report this object will generate. |
| 220 | +
|
| 221 | + Parameters |
| 222 | + ---------- |
| 223 | + simulation : Simulation |
| 224 | + The Simulation to generate a report for |
| 225 | +
|
| 226 | + Returns |
| 227 | + ------- |
| 228 | + tuple[int, bool, bool, bool, bool, bool] |
| 229 | + A six element tuple. The first element is the number of steps until the |
| 230 | + next report. The remaining elements specify whether that report will |
| 231 | + require positions, velocities, forces, energies, and periodic box info. |
| 232 | + """ |
| 233 | + steps = self._reportInterval - simulation.currentStep % self._reportInterval |
| 234 | + return steps, True, True, False, True, self._enforcePeriodicBox |
| 235 | + |
| 236 | + def report(self, simulation: Simulation, state: State) -> None: |
| 237 | + """Generate a report. |
| 238 | +
|
| 239 | + Parameters |
| 240 | + ---------- |
| 241 | + simulation : Simulation |
| 242 | + The Simulation to generate a report for |
| 243 | + state : State |
| 244 | + The current state of the simulation |
| 245 | + """ |
| 246 | + if self._nextModel == 0: |
| 247 | + self._topology = simulation.topology |
| 248 | + self._species = [ |
| 249 | + atom.element.symbol for atom in simulation.topology.atoms() |
| 250 | + ] |
| 251 | + self._time_step = ( |
| 252 | + simulation.integrator.getStepSize() * self._reportInterval |
| 253 | + ).value_in_unit(omm_unit.femtoseconds) |
| 254 | + |
| 255 | + # Get positions and velocities in Angstrom and Angstrom/fs |
| 256 | + positions = state.getPositions(asNumpy=True).value_in_unit(omm_unit.angstrom) |
| 257 | + velocities = state.getVelocities(asNumpy=True).value_in_unit( |
| 258 | + omm_unit.angstrom / omm_unit.femtosecond |
| 259 | + ) |
| 260 | + box_vectors = state.getPeriodicBoxVectors(asNumpy=True).value_in_unit( |
| 261 | + omm_unit.angstrom |
| 262 | + ) |
| 263 | + |
| 264 | + # Get energies in eV |
| 265 | + kinetic_energy = ( |
| 266 | + state.getKineticEnergy() / omm_unit.AVOGADRO_CONSTANT_NA |
| 267 | + ).value_in_unit(omm_unit.ev) |
| 268 | + |
| 269 | + potential_energy = ( |
| 270 | + state.getPotentialEnergy() / omm_unit.AVOGADRO_CONSTANT_NA |
| 271 | + ).value_in_unit(omm_unit.ev) |
| 272 | + |
| 273 | + self._positions.append(positions) |
| 274 | + self._velocities.append(velocities) |
| 275 | + self._lattices.append(box_vectors) |
| 276 | + self._frame_properties.append( |
| 277 | + { |
| 278 | + "kinetic_energy": kinetic_energy, |
| 279 | + "potential_energy": potential_energy, |
| 280 | + "total_energy": kinetic_energy + potential_energy, |
| 281 | + } |
| 282 | + ) |
| 283 | + |
| 284 | + self._nextModel += 1 |
| 285 | + |
| 286 | + def save(self) -> None: |
| 287 | + """Write accumulated trajectory data to a pymatgen Trajectory object.""" |
| 288 | + if not self._positions: |
| 289 | + return |
| 290 | + |
| 291 | + velocities = [ |
| 292 | + [tuple(site_vel) for site_vel in frame_vel] |
| 293 | + for frame_vel in self._velocities |
| 294 | + ] |
| 295 | + |
| 296 | + # Format site properties as list of dicts, one per frame |
| 297 | + site_properties = [] |
| 298 | + n_frames = len(self._positions) |
| 299 | + site_properties = [{"velocities": velocities[i]} for i in range(n_frames)] |
| 300 | + |
| 301 | + # Create trajectory with positions and lattices |
| 302 | + trajectory = Trajectory( |
| 303 | + species=self._species, |
| 304 | + coords=self._positions, |
| 305 | + lattice=self._lattices, |
| 306 | + frame_properties=self._frame_properties, |
| 307 | + site_properties=site_properties, # Now properly formatted as list of dicts |
| 308 | + time_step=self._time_step, |
| 309 | + ) |
| 310 | + |
| 311 | + # Store trajectory as a class attribute so it can be accessed after deletion |
| 312 | + self.trajectory = trajectory |
| 313 | + |
| 314 | + # write out trajectory to a file |
| 315 | + with open(self._file, mode="w") as file: |
| 316 | + file.write(trajectory.to_json()) |
0 commit comments