Commit 6490fda

migrate unused openff/mm schemas + custom interchange to atomate2 (#1290)

1 parent 707fd31 commit 6490fda

File tree

12 files changed (+401, -15 lines)

src/atomate2/openff/schemas.py

Lines changed: 327 additions & 0 deletions
@@ -0,0 +1,327 @@
"""Solvent and solvation schemas for OpenFF which are not yet production ready."""

from __future__ import annotations

from io import StringIO
from typing import TYPE_CHECKING, Annotated

import pandas as pd
from MDAnalysis import Universe
from MDAnalysis.analysis.dielectric import DielectricConstant
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    PlainSerializer,
    PlainValidator,
    WithJsonSchema,
)
from solvation_analysis.solute import Solute
from transport_analysis.viscosity import ViscosityHelfand

if TYPE_CHECKING:
    from typing import Any


def data_frame_validater(o: Any) -> pd.DataFrame:
    """Define custom validator for pandas DataFrame.

    Parameters
    ----------
    o : Any
        The object to validate: a DataFrame or a CSV-formatted string.

    Returns
    -------
    pandas DataFrame
    """
    if isinstance(o, pd.DataFrame):
        return o
    if isinstance(o, str):
        return pd.read_csv(StringIO(o))
    raise ValueError(f"Invalid DataFrame: {o}")


def data_frame_serializer(df: pd.DataFrame) -> str:
    """Serialize pandas DataFrame as CSV."""
    return df.to_csv()


DataFrame = Annotated[
    pd.DataFrame,
    PlainValidator(data_frame_validater),
    PlainSerializer(data_frame_serializer),
    WithJsonSchema({"type": "string"}),
]

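The `Annotated` alias above is what lets a raw pandas `DataFrame` sit on a pydantic document while still serializing to a plain string. A minimal usage sketch (not part of the commit): the `TableDoc` model name and its field are hypothetical, and the import path assumes the new module location from this commit.

```python
import pandas as pd
from pydantic import BaseModel, ConfigDict

from atomate2.openff.schemas import DataFrame  # the Annotated alias defined above


class TableDoc(BaseModel):
    """Hypothetical document with a DataFrame-valued field."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    table: DataFrame


doc = TableDoc(table=pd.DataFrame({"frame": [0, 1], "network_size": [2, 3]}))
dumped = doc.model_dump()      # the DataFrame is serialized to a CSV string
restored = TableDoc(**dumped)  # the CSV string validates back into a DataFrame
```
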
class SolventBenchmarkingDoc(BaseModel):
    """Define document for benchmarking solvent properties."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    density: float | None = Field(None, description="Density of the solvent")

    viscosity_function_values: list[float] | None = Field(
        None, description="Viscosity function over time"
    )

    viscosity: float | None = Field(None, description="Viscosity of the solvent")

    dielectric: float | None = Field(
        None, description="Dielectric constant of the solvent"
    )

    job_uuid: str | None = Field(
        None, description="The UUID of the flow that generated this data."
    )

    flow_uuid: str | None = Field(
        None, description="The UUID of the top level host from that job."
    )

    dielectric_run_kwargs: dict | None = Field(
        None, description="kwargs passed to the DielectricConstant.run method"
    )

    viscosity_run_kwargs: dict | None = Field(
        None, description="kwargs passed to the ViscosityHelfand.run method"
    )

    tags: list[str] | None = Field(
        [], title="tag", description="Metadata tagged to the parent job."
    )

    @classmethod
    def from_universe(
        cls,
        u: Universe,
        temperature: float | None = None,
        density: float | None = None,
        job_uuid: str | None = None,
        flow_uuid: str | None = None,
        dielectric_run_kwargs: dict | None = None,
        viscosity_run_kwargs: dict | None = None,
        tags: list[str] | None = None,
    ) -> SolventBenchmarkingDoc:
        """Create a document from an MDAnalysis Universe."""
        if temperature is not None:
            dielectric = DielectricConstant(
                u.atoms, temperature=temperature, make_whole=False
            )
            dielectric_run_kwargs = dielectric_run_kwargs or {}
            dielectric.run(**dielectric_run_kwargs)
            eps = dielectric.results.eps_mean
        else:
            eps = None

        if u.atoms.ts.has_velocities:
            start, stop = int(0.2 * len(u.trajectory)), int(0.8 * len(u.trajectory))
            viscosity_helfand = ViscosityHelfand(
                u.atoms,
                temp_avg=temperature,
                linear_fit_window=(start, stop),
            )
            viscosity_run_kwargs = viscosity_run_kwargs or {}
            viscosity_helfand.run(**viscosity_run_kwargs)
            viscosity_function_values = viscosity_helfand.results.timeseries.tolist()
            viscosity = viscosity_helfand.results.viscosity

        else:
            viscosity_function_values = None
            viscosity = None

        return cls(
            density=density,
            viscosity_function_values=viscosity_function_values,
            viscosity=viscosity,
            dielectric=eps,
            job_uuid=job_uuid,
            flow_uuid=flow_uuid,
            dielectric_run_kwargs=dielectric_run_kwargs,
            viscosity_run_kwargs=viscosity_run_kwargs,
            tags=tags,
        )

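A hedged usage sketch for the classmethod above (not part of the commit): the file names, temperature, density, and UUIDs are placeholders, and the keyword arguments mirror the `from_universe` signature shown above.

```python
from MDAnalysis import Universe

from atomate2.openff.schemas import SolventBenchmarkingDoc

# Placeholder topology/trajectory; the trajectory must carry velocities
# for the Helfand viscosity branch to run.
u = Universe("system.pdb", "trajectory.dcd")

doc = SolventBenchmarkingDoc.from_universe(
    u,
    temperature=298.15,  # enables the dielectric calculation
    density=0.997,       # computed elsewhere and stored as-is
    job_uuid="example-job-uuid",
    flow_uuid="example-flow-uuid",
)
print(doc.dielectric, doc.viscosity)
```
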
# class SolvationDoc(ClassicalMDDoc, arbitrary_types_allowed=True):
class SolvationDoc(BaseModel):
    """Schematize solvation calculation."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    solute_name: str | None = Field(None, description="Name of the solute")

    solvent_names: list[str] | None = Field(None, description="Names of the solvents")

    is_electrolyte: bool | None = Field(
        None, description="Whether system is an electrolyte"
    )

    # Solute.coordination

    coordination_numbers: dict[str, float] | None = Field(
        None,
        description="A dictionary where keys are residue names and values are "
        "the mean coordination number of that residue.",
    )

    # coordination_numbers_by_frame: DataFrame | None = Field(
    #     None,
    #     description="Coordination number in each frame of the trajectory.",
    # )

    coordinating_atoms: DataFrame | None = Field(
        None,
        description="Fraction of each atom_type participating in solvation, "
        "calculated for each solvent.",
    )

    coordination_vs_random: dict[str, float] | None = Field(
        None,
        description="Coordination number relative to random coordination.",
    )

    # Solute.networking

    # TODO: In the worst case, this could be extremely large.
    # Need to consider what else we might want from this object.
    # network_df: DataFrame | None = Field(
    #     None,
    #     description="All solute-solvent networks in the system, "
    #     "indexed by the `frame` and a 'network_ix'. "
    #     "Columns are the species name and res_ix.",
    # )

    network_sizes: DataFrame | None = Field(
        None,
        description="Sizes of all networks, indexed by frame. Column headers are "
        "network sizes, e.g. the integer number of solutes + solvents in the network. "
        "The values in each column are the number of networks with that size in each "
        "frame.",
    )

    solute_status: dict[str, float] | None = Field(
        None,
        description="A dictionary where the keys are the “status” of the "
        "solute and the values are the fraction of solute with that "
        "status, averaged over all frames. “isolated” means that the solute is not "
        "coordinated with any of the networking solvents, network size is 1. "
        "“paired” means the solute is coordinated with a single networking "
        "solvent and that solvent is not coordinated to any other solutes, "
        "network size is 2. “networked” means that the solute is coordinated to "
        "more than one solvent or its solvent is coordinated to more than one "
        "solute, network size >= 3.",
    )

    # solute_status_by_frame: DataFrame | None = Field(
    #     None, description="Solute status in each frame of the trajectory."
    # )

    # Solute.pairing

    solvent_pairing: dict[str, float] | None = Field(
        None, description="Fraction of each solvent coordinated to the solute."
    )

    # pairing_by_frame: DataFrame | None = Field(
    #     None, description="Solvent pairing in each frame."
    # )

    fraction_free_solvents: dict[str, float] | None = Field(
        None, description="Fraction of each solvent not coordinated to solute."
    )

    diluent_composition: dict[str, float] | None = Field(
        None, description="Fraction of diluent constituted by each solvent."
    )

    # diluent_composition_by_frame: DataFrame | None = Field(
    #     None, description="Diluent composition in each frame."
    # )

    diluent_counts: DataFrame | None = Field(
        None, description="Solvent counts in each frame."
    )

    # Solute.residence

    residence_times: dict[str, float] | None = Field(
        None,
        description="Average residence time of each solvent. "
        "Calculated by 1/e cutoff on autocovariance function.",
    )

    residence_times_fit: dict[str, float] | None = Field(
        None,
        description="Average residence time of each solvent. "
        "Calculated by fitting the autocovariance function to an exponential decay.",
    )

    # Solute.speciation

    speciation_fraction: DataFrame | None = Field(
        None, description="Fraction of shells of each type."
    )

    solvent_co_occurrence: DataFrame | None = Field(
        None,
        description="The actual co-occurrence of solvents divided by "
        "the expected co-occurrence in randomly distributed solvation shells, "
        "i.e. given a molecule of solvent i in the shell, the probability of "
        "solvent j's presence relative to choosing a solvent at random "
        "from the pool of all coordinated solvents.",
    )

    job_uuid: str | None = Field(
        None, description="The UUID of the flow that generated this data."
    )

    flow_uuid: str | None = Field(
        None, description="The UUID of the top level host from that job."
    )

    @classmethod
    def from_solute(
        cls,
        solute: Solute,
        job_uuid: str | None = None,
        flow_uuid: str | None = None,
    ) -> SolvationDoc:
        """Create a SolvationDoc from a solvation_analysis Solute."""
        # collect the scalar metadata first, then the analyzer results
        props = {
            "solute_name": solute.solute_name,
            "solvent_names": list(solute.solvents.keys()),
            "is_electrolyte": True,
            "job_uuid": job_uuid,
            "flow_uuid": flow_uuid,
        }
        if hasattr(solute, "coordination"):
            for k in (
                "coordination_numbers",
                "coordinating_atoms",
                "coordination_vs_random",
            ):
                props[k] = getattr(solute.coordination, k, None)
        if hasattr(solute, "pairing"):
            for k in (
                "solvent_pairing",
                "fraction_free_solvents",
                "diluent_composition",
                "diluent_counts",
            ):
                props[k] = getattr(solute.pairing, k, None)
        if hasattr(solute, "speciation"):
            for k in ("speciation_fraction", "solvent_co_occurrence"):
                props[k] = getattr(solute.speciation, k, None)
        if hasattr(solute, "networking"):
            for k in ("network_sizes", "solute_status"):
                props[k] = getattr(solute.networking, k, None)
        if hasattr(solute, "residence"):
            for k, v in {
                "residence_times_cutoff": "residence_times",
                "residence_times_fit": "residence_times_fit",
            }.items():
                props[v] = getattr(solute.residence, k, None)

        return SolvationDoc(**props)
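A hedged end-to-end sketch (not part of the commit): selections and file names are placeholders, and the `Solute` setup follows the usual solvation_analysis workflow (`Solute.from_atoms` followed by `run`) so that the analyzer attributes checked above exist.

```python
from MDAnalysis import Universe
from solvation_analysis.solute import Solute

from atomate2.openff.schemas import SolvationDoc

u = Universe("system.pdb", "trajectory.dcd")         # placeholder files
cation = u.select_atoms("resname LI")                # placeholder selections
solvents = {"water": u.select_atoms("resname HOH")}

solute = Solute.from_atoms(cation, solvents, solute_name="Li")
solute.run()  # populates coordination, pairing, speciation, ...

doc = SolvationDoc.from_solute(solute, job_uuid="example-job-uuid")
```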

src/atomate2/openmm/flows/core.py

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING
 
-from emmet.core.openmm import Calculation, OpenMMInterchange, OpenMMTaskDocument
+from emmet.core.openmm import Calculation, OpenMMTaskDocument
 from jobflow import Flow, Job, Maker, Response
 from monty.json import MontyDecoder, MontyEncoder
 
@@ -18,6 +18,7 @@
 if TYPE_CHECKING:
     from openff.interchange import Interchange
 
+    from atomate2.openmm.interchange import OpenMMInterchange
     from atomate2.openmm.jobs.base import BaseOpenMMMaker
 
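The net effect of this change (and the matching one in `dynamic.py` below) is that the custom interchange class is resolved from atomate2 rather than `emmet.core.openmm`. A one-line sketch of the new import path, taken from the `TYPE_CHECKING` block above:

```python
# The custom interchange class now lives in atomate2 (per this commit's imports).
from atomate2.openmm.interchange import OpenMMInterchange
```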

src/atomate2/openmm/flows/dynamic.py

Lines changed: 3 additions & 6 deletions
@@ -13,14 +13,11 @@
 from atomate2.openmm.jobs.base import BaseOpenMMMaker, openmm_job
 
 if TYPE_CHECKING:
-    from emmet.core.openmm import (
-        Calculation,
-        OpenMMFlowMaker,
-        OpenMMInterchange,
-        OpenMMTaskDocument,
-    )
+    from emmet.core.openmm import Calculation, OpenMMFlowMaker, OpenMMTaskDocument
     from openff.interchange import Interchange
 
+    from atomate2.openmm.interchange import OpenMMInterchange
+
 
 def _get_final_jobs(input_jobs: list[Job] | Flow) -> list[Job]:
     """Unwrap nested jobs from a dynamic flow."""