Skip to content

Commit 9487bf6

Browse files
shehan807 and Shehan M Parmar authored
Add DynamicOpenMMFlowMaker for dynamic OpenMM Simulations (#1115)
* created DynamicOpenMMFlowMaker; moved BaseOpenMMMaker import outside of TYPE_CHECKING to use as callable in default_factory * use of lambda function is more appropriate for default_factory to avoid shared mutable defaults * added dynamic flow logic and apply_flow_control instance classes; added placeholder should_continue function * restructured dynamic_flow and removed apply_flow_control * removed dynamic_collect_outputs * added and passed DynamicOpenMMFlowMaker tests; fixed dynamic flow logic * undo incorrect change * created dynamic.py * created test_should_continue * mypy_extensions passes mypy, but adds additional dependency * mypy_extensions passes mypy, but adds additional dependency * replaced mypy_extensions with protocol, mypy tests pass. --------- Co-authored-by: Shehan M Parmar <[email protected]>
1 parent 4a4bfd9 commit 9487bf6

File tree

5 files changed

+381
-72
lines changed

5 files changed

+381
-72
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
"""Core Flows for OpenMM module."""
22

33
from atomate2.openmm.flows.core import OpenMMFlowMaker
4+
from atomate2.openmm.flows.dynamic import DynamicOpenMMFlowMaker
Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
"""Dynamic flows for OpenMM module."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass, field
6+
from typing import TYPE_CHECKING, Protocol, runtime_checkable
7+
8+
import numpy as np
9+
from jobflow import CURRENT_JOB, Flow, Job, Maker, Response, job
10+
from scipy.signal import savgol_filter
11+
12+
from atomate2.openmm.flows.core import _get_calcs_reversed, collect_outputs
13+
from atomate2.openmm.jobs.base import BaseOpenMMMaker, openmm_job
14+
15+
if TYPE_CHECKING:
16+
from emmet.core.openmm import (
17+
Calculation,
18+
OpenMMFlowMaker,
19+
OpenMMInterchange,
20+
OpenMMTaskDocument,
21+
)
22+
from openff.interchange import Interchange
23+
24+
25+
def _get_final_jobs(input_jobs: list[Job] | Flow) -> list[Job]:
    """Return the innermost job list reachable through trailing ``maker.jobs``.

    Starting from ``input_jobs`` (a ``Flow`` contributes its ``.jobs``), the
    last entry is inspected; while that entry exposes ``maker.jobs`` as a
    ``list``, the search descends into that list. An empty level yields ``[]``.
    """
    current = input_jobs
    while True:
        candidates = current.jobs if isinstance(current, Flow) else current
        if not candidates:
            return []

        # Only a real ``list`` stored on the last entry's maker counts as a
        # nested level; anything else means this level is already the final one.
        tail_maker = getattr(candidates[-1], "maker", None)
        nested = getattr(tail_maker, "jobs", None)
        if not isinstance(nested, list):
            return candidates
        current = nested
41+
42+
43+
@openmm_job
def default_should_continue(
    task_docs: list[OpenMMTaskDocument],
    stage_index: int,
    max_stages: int,
    physical_property: str = "potential_energy",
    target: float | None = None,
    threshold: float = 1e-3,
    burn_in_ratio: float = 0.2,
) -> Response:
    """Decide whether a dynamic flow should run another stage.

    This serves as a template for any bespoke "should_continue" function
    written by the user. The selected property is concatenated over all
    ``task_docs``, the leading ``burn_in_ratio`` fraction is discarded, and the
    remainder is judged either against ``target`` or by the stability of its
    normalized time derivative.

    Parameters
    ----------
    task_docs : list[OpenMMTaskDocument]
        Task documents of the stages run so far; ``calcs_reversed[0]`` of each
        is read. The last document is the one returned.
    stage_index : int
        Index of the stage that was just run.
    max_stages : int
        Stage index above which the derivative-based criterion stops the flow.
    physical_property : str
        Either ``"potential_energy"`` or ``"density"``.
    target : float | None
        If truthy, continue until the relative deviation of the mean from
        ``target`` drops below ``threshold``. ``None`` and ``0.0`` both select
        the derivative-based criterion (a relative deviation from zero is
        undefined).
    threshold : float
        Convergence tolerance for whichever criterion is used.
    burn_in_ratio : float
        Fraction of the leading samples to discard before the analysis.

    Returns
    -------
    Response
        Response whose output is the last task document with
        ``should_continue`` set.

    Raises
    ------
    ValueError
        If ``physical_property`` is not one of the supported names.
    """
    if physical_property not in ("potential_energy", "density"):
        # Previously an unknown name left ``values`` unbound and failed later
        # with an UnboundLocalError.
        raise ValueError(
            f"Unsupported physical_property {physical_property!r}; "
            "expected 'potential_energy' or 'density'."
        )

    task_doc = task_docs[-1]

    # Gather only the requested time series so that a missing series for the
    # other property cannot break the analysis.
    series: list[float] = []
    for doc in task_docs:
        calc = doc.calcs_reversed[0]
        series.extend(getattr(calc.output, physical_property))
        # Sample spacing used for the derivative; the last document wins.
        dt = calc.input.state_interval
    values = np.array(series)

    # toss out first X% of values, default 20%
    burn_in = int(burn_in_ratio * len(values))
    values = values[burn_in:]
    # Filter window: roughly the burn-in length, bumped to an odd number and
    # never smaller than 5.
    # NOTE(review): savgol_filter rejects windows longer than the data, so very
    # short series (or large burn_in_ratio) will raise inside scipy.
    window_length = max(5, burn_in + 1) if burn_in % 2 == 0 else max(5, burn_in)

    avg = np.mean(values)
    dvalue_dt = savgol_filter(
        values / avg, window_length, polyorder=3, deriv=1, delta=dt
    )
    decay_rate = np.max(np.abs(dvalue_dt))
    job = CURRENT_JOB.job

    if target:
        delta = np.abs((avg - target) / target)
        should_continue = not delta < threshold
        job.append_name(
            f" [Stage {stage_index}, delta={delta:.3e}"
            f"-> should_continue={should_continue}]"
        )
    elif stage_index > max_stages or decay_rate < threshold:  # max_stages exceeded
        should_continue = False
    else:  # decay_rate not stable
        should_continue = True

    # NOTE(review): this suffix is appended in every branch, so the target
    # branch ends up with both a delta and a decay_rate suffix — confirm
    # that is intended.
    job.append_name(
        f" [Stage {stage_index}, decay_rate={decay_rate:.3e}"
        f"-> should_continue={should_continue}]"
    )

    task_doc.should_continue = should_continue
    return Response(output=task_doc)
105+
106+
107+
@runtime_checkable
class ShouldContinueProtocol(Protocol):
    """Structural type of the callback stored in ``should_continue``.

    Any callable accepting the keyword arguments below satisfies the
    protocol; the signature mirrors ``default_should_continue``.
    """

    def __call__(
        self,
        task_docs: list[OpenMMTaskDocument],
        stage_index: int,
        max_stages: int,
        physical_property: str = "potential_energy",
        target: float | None = None,
        threshold: float = 1e-3,
        burn_in_ratio: float = 0.2,
    ) -> Response:
        """Accept the same keyword arguments as default_should_continue()."""
        ...
123+
124+
125+
@dataclass
class DynamicOpenMMFlowMaker(Maker):
    """Run a dynamic equilibration or production simulation.

    Create a dynamic flow out of an existing OpenMM simulation
    job or a linear sequence of linked jobs, i.e., an OpenMM
    flow. After every stage a ``should_continue`` job inspects the
    results and the flow either schedules another stage or stops.

    Attributes
    ----------
    name : str
        The name of the dynamic OpenMM job or flow. Default is the name
        of the wrapped maker with "dynamic" prepended.
    tags : list[str]
        Tags to apply to the final job. Will only be applied if
        collect_outputs is True.
    maker: Union[BaseOpenMMMaker, OpenMMFlowMaker]
        A single (either job or flow) maker to make dynamic.
    max_stages: int
        Maximum number of stages to run consecutively before terminating
        dynamic flow logic.
    collect_outputs : bool
        If True, a final job is added that collects all jobs into a single
        task document.
    should_continue: Callable
        A general function for evaluating properties in `calcs_reversed`
        to determine simulation flow logic (i.e., termination, pausing,
        or continuing).
    jobs: list
        A running list of the jobs and flows created so far (stage jobs,
        should_continue jobs and dynamic_flow jobs).
    job_uuids: list
        A running list of job uuids in simulation flow.
    calcs_reversed: list[Calculation]
        A running list of Calculations in simulation flow.
    stage_task_type : str
        Passed as ``task_type`` to the final ``collect_outputs`` job.

    Notes
    -----
    ``jobs``, ``job_uuids`` and ``calcs_reversed`` are appended to by
    ``make`` and ``dynamic_flow`` and are never cleared, so reusing one
    instance for several ``make`` calls accumulates entries.
    """

    name: str = field(default=None)
    tags: list[str] = field(default_factory=list)
    maker: BaseOpenMMMaker | OpenMMFlowMaker = field(
        default_factory=lambda: BaseOpenMMMaker()
    )
    max_stages: int = field(default=5)
    collect_outputs: bool = True
    # The lambda is required here: default_factory must *return* the callback,
    # not call it.
    should_continue: ShouldContinueProtocol = field(
        default_factory=lambda: default_should_continue
    )

    jobs: list = field(default_factory=list)
    job_uuids: list = field(default_factory=list)
    calcs_reversed: list[Calculation] = field(default_factory=list)
    stage_task_type: str = "collect"

    def __post_init__(self) -> None:
        """Post init formatting of arguments."""
        if self.name is None:
            self.name = f"dynamic {self.maker.name}"

    def _track_stage(self, stage: Job | Flow) -> None:
        """Record a stage job/flow, its uuid(s) and its calcs for the collect job."""
        self.jobs.append(stage)
        if isinstance(stage, Flow):
            self.job_uuids.extend(stage.job_uuids)
        else:
            self.job_uuids.append(stage.uuid)
        self.calcs_reversed.append(_get_calcs_reversed(stage))

    def make(
        self,
        interchange: Interchange | OpenMMInterchange | str,
        prev_dir: str | None = None,
    ) -> Flow:
        """Run the dynamic simulation using the provided Interchange object.

        Parameters
        ----------
        interchange : Interchange | OpenMMInterchange | str
            The Interchange object containing the system
            to simulate.
        prev_dir : str | None
            The directory of the previous task; forwarded to the wrapped
            maker.

        Returns
        -------
        Flow
            A Flow containing the first stage, its should_continue job and
            the dynamic_flow job that schedules any further stages.
        """
        # Run initial stage job
        stage_job_0 = self.maker.make(
            interchange=interchange,
            prev_dir=prev_dir,
        )
        self._track_stage(stage_job_0)

        # Determine stage job control logic
        control_stage_0 = self.should_continue(
            task_docs=[stage_job_0.output],
            stage_index=0,
            max_stages=self.max_stages,
        )
        self.jobs.append(control_stage_0)

        stage_n = self.dynamic_flow(
            prev_stage_index=0,
            prev_docs=[control_stage_0.output],
        )
        self.jobs.append(stage_n)
        return Flow([stage_job_0, control_stage_0, stage_n], output=stage_n.output)

    @job
    def dynamic_flow(
        self,
        prev_stage_index: int,
        prev_docs: list[OpenMMTaskDocument],
    ) -> Response | None:
        """Run stage n and dynamically decide to continue or terminate flow.

        Parameters
        ----------
        prev_stage_index : int
            Index of the stage that has just been evaluated.
        prev_docs : list[OpenMMTaskDocument]
            Documents produced by the should_continue jobs so far; the last
            one supplies ``should_continue``, ``interchange`` and ``dir_name``.

        Returns
        -------
        Response
            On termination, either a replacement with the collect job or the
            last document as output; otherwise a replacement flow holding the
            next stage, its should_continue job and the next dynamic_flow job.
        """
        prev_stage = prev_docs[-1]

        # Terminate when the stage budget is used up or the last check said stop.
        if prev_stage_index >= self.max_stages or not prev_stage.should_continue:
            if self.collect_outputs:
                collect_job = collect_outputs(
                    prev_stage.dir_name,
                    tags=self.tags or None,
                    job_uuids=self.job_uuids,
                    calcs_reversed=self.calcs_reversed,
                    task_type=self.stage_task_type,
                )
                return Response(replace=collect_job, output=collect_job.output)
            return Response(output=prev_stage)

        stage_index = prev_stage_index + 1

        # Continue from where the previous stage ended.
        stage_job_n = self.maker.make(
            interchange=prev_stage.interchange,
            prev_dir=prev_stage.dir_name,
        )
        self._track_stage(stage_job_n)

        control_stage_n = self.should_continue(
            task_docs=[*prev_docs, stage_job_n.output],
            stage_index=stage_index,
            max_stages=self.max_stages,
        )
        self.jobs.append(control_stage_n)

        # Schedule the next decision point; it replaces this job at run time.
        next_stage_n = self.dynamic_flow(
            prev_stage_index=stage_index,
            prev_docs=[*prev_docs, control_stage_n.output],
        )
        self.jobs.append(next_stage_n)
        stage_n_flow = Flow([stage_job_n, control_stage_n, next_stage_n])

        return Response(replace=stage_n_flow, output=next_stage_n.output)

tests/openmm_md/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33
from jobflow import run_locally
44

55

6+
@pytest.fixture
def run_dynamic_job(tmp_path):
    """Provide a runner that executes a job or flow locally under ``tmp_path``.

    The returned callable gives back the ``output`` of the entry stored under
    key ``2`` of the last value in the ``run_locally`` result (presumably the
    job index of the final replaced job — confirm against jobflow).
    """

    def _run(job):
        responses = run_locally(job, ensure_success=True, root_dir=tmp_path)
        final_entry = tuple(responses.values())[-1]
        return final_entry[2].output

    return _run
13+
14+
615
@pytest.fixture
716
def run_job(tmp_path):
817
def run_job(job):

tests/openmm_md/flows/test_core.py

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
from __future__ import annotations
22

33
import io
4-
import json
54
from pathlib import Path
65

7-
import numpy as np
8-
import pytest
96
from emmet.core.openmm import OpenMMInterchange, OpenMMTaskDocument
107
from jobflow import Flow
118
from MDAnalysis import Universe
12-
from monty.json import MontyDecoder
139
from openmm.app import PDBFile
1410

1511
from atomate2.openmm.flows.core import OpenMMFlowMaker
@@ -176,71 +172,3 @@ def test_flow_maker(interchange, run_job):
176172
u = Universe(topology, str(Path(task_doc.dir_name) / "trajectory5.dcd"))
177173

178174
assert len(u.trajectory) == 5
179-
180-
181-
def test_traj_blob_embed(interchange, run_job, tmp_path):
    """Check that an embedded trajectory blob round-trips through the task doc."""
    maker = NVTMaker(n_steps=2, traj_interval=1, embed_traj=True)
    task_doc = run_job(maker.make(interchange))
    run_dir = Path(task_doc.dir_name)

    interchange = OpenMMInterchange.model_validate_json(task_doc.interchange)
    topology = PDBFile(io.StringIO(interchange.topology)).getTopology()

    # Trajectory written by the simulation itself.
    u = Universe(topology, str(run_dir / "trajectory.dcd"))
    assert len(u.trajectory) == 2

    calc_output = task_doc.calcs_reversed[0].output
    assert calc_output.traj_blob is not None

    # Decode the hex blob back into a DCD file and compare coordinates.
    blob_path = tmp_path / "doc_trajectory.dcd"
    blob_path.write_bytes(bytes.fromhex(calc_output.traj_blob))
    u2 = Universe(topology, str(blob_path))
    assert np.all(u.atoms.positions == u2.atoms.positions)

    # The serialized task document must carry the same blob.
    task_dict = json.loads((run_dir / "taskdoc.json").read_text(), cls=MontyDecoder)
    task_doc_parsed = OpenMMTaskDocument.model_validate(task_dict)
    parsed_output = task_doc_parsed.calcs_reversed[0].output

    assert parsed_output.traj_blob == calc_output.traj_blob
213-
214-
215-
@pytest.mark.skip("for local testing and debugging")
def test_fireworks(interchange):
    """Submit a small production flow to the default FireWorks LaunchPad."""
    from fireworks import LaunchPad
    from jobflow.managers.fireworks import flow_to_workflow

    # Short minimization -> NPT -> anneal -> NVT sequence.
    stages = [
        EnergyMinimizationMaker(max_iterations=1),
        NPTMaker(n_steps=5, pressure=1.0, state_interval=1, traj_interval=1),
        OpenMMFlowMaker.anneal_flow(anneal_temp=400, final_temp=300, n_steps=5),
        NVTMaker(n_steps=5),
    ]
    production_maker = OpenMMFlowMaker(
        name="test_production",
        tags=["test"],
        makers=stages,
    )

    # The flow is built from the JSON form of the interchange.
    production_flow = production_maker.make(interchange.json())

    # Only queues the workflow; launching a rocket is left to the developer.
    lpad = LaunchPad.auto_load()
    lpad.add_wf(flow_to_workflow(production_flow))

0 commit comments

Comments
 (0)