Skip to content

Commit 20b758c

Browse files
[Breaking] ML forcefields - use emmet trajectory by default (#1219)
1 parent a804a9a commit 20b758c

File tree

10 files changed

+127
-120
lines changed

10 files changed

+127
-120
lines changed

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ dependencies = [
2828
"PyYAML",
2929
"click",
3030
"custodian>=2024.4.18",
31-
"emmet-core>=0.84.9",
31+
"emmet-core>=v0.84.10rc2",
3232
"jobflow>=0.1.11",
3333
"monty>=2024.12.10",
3434
"numpy",
@@ -100,7 +100,7 @@ strict = [
100100
"click==8.2.1",
101101
"custodian==2025.8.13",
102102
"dscribe==2.1.1",
103-
"emmet-core==0.84.9",
103+
"emmet-core==v0.84.10rc2",
104104
"ijson==3.4.0",
105105
"jobflow==0.2.0",
106106
"lobsterpy==0.5.7",
@@ -111,7 +111,7 @@ strict = [
111111
"pydantic-settings==2.10.1",
112112
"pydantic==2.11.7",
113113
"pymatgen-analysis-defects==2025.1.18",
114-
"pymatgen==2025.2.18",
114+
"pymatgen==2025.6.14",
115115
"pymongo==4.10.1",
116116
"python-ulid==3.0.0",
117117
"seekpath==2.1.0",

src/atomate2/ase/md.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,13 @@ class AseMDMaker(AseMaker, metaclass=ABCMeta):
139139
traj_file : str | Path | None = None
140140
If a str or Path, the name of the file to save the MD trajectory to.
141141
If None, the trajectory is not written to disk
142-
traj_file_fmt : Literal["ase","pmg","xdatcar"]
142+
traj_file_fmt : Literal["ase","pmg","xdatcar", "parquet"]
143143
The format of the trajectory file to write.
144144
If "ase", writes an ASE .Trajectory.
145145
If "pmg", writes a Pymatgen .Trajectory.
146-
If "xdatcar, writes a VASP-style XDATCAR
146+
If "xdatcar", writes a VASP-style XDATCAR.
147+
If "parquet", uses emmet.core's AtomTrajectory object to write the
148+
trajectory to a high-efficiency Parquet-format file.
147149
traj_interval : int
148150
The step interval for saving the trajectories.
149151
mb_velocity_seed : int or None
@@ -415,7 +417,7 @@ def _callback(dyn: MolecularDynamics = md_runner) -> None:
415417

416418
return AseResult(
417419
final_mol_or_struct=mol_or_struct,
418-
trajectory=md_observer.to_pymatgen_trajectory(filename=None),
420+
trajectory=md_observer.to_emmet_trajectory(filename=None),
419421
dir_name=os.getcwd(),
420422
elapsed_time=t_f - t_i,
421423
)

src/atomate2/ase/schemas.py

Lines changed: 25 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,14 @@
1313
from pathlib import Path
1414
from typing import Any
1515

16-
from ase.stress import voigt_6_to_full_3x3_stress
17-
from ase.units import GPa
1816
from emmet.core.math import Matrix3D, Vector3D
1917
from emmet.core.structure import MoleculeMetadata, StructureMetadata
2018
from emmet.core.tasks import TaskState
19+
from emmet.core.trajectory import AtomTrajectory
2120
from emmet.core.utils import ValueEnum
2221
from emmet.core.vasp.calculation import StoreTrajectoryOption
2322
from pydantic import BaseModel, Field
2423
from pymatgen.core import Molecule, Structure
25-
from pymatgen.core.trajectory import Trajectory as PmgTrajectory
2624

2725
_task_doc_translation_keys = {
2826
"input",
@@ -49,7 +47,7 @@ class AseResult(BaseModel):
4947
None, description="The final total energy from the calculation."
5048
)
5149

52-
trajectory: PmgTrajectory | None = Field(
50+
trajectory: AtomTrajectory | None = Field(
5351
None, description="The relaxation or molecular dynamics trajectory."
5452
)
5553

@@ -146,7 +144,7 @@ class OutputDoc(AseBaseModel):
146144
# NOTE: units for stresses were converted to kbar (* -10 from standard output)
147145
# to comply with MP convention
148146
stress: Matrix3D | None = Field(
149-
None, description="The stress on the cell in units of kbar (in Voigt notation)."
147+
None, description="The stress on the cell in units of kbar."
150148
)
151149

152150
# NOTE: the ionic_steps can also be a dict when these are in blob storage and
@@ -417,22 +415,7 @@ def from_ase_compatible_result(
417415
input_mol_or_struct = None
418416
if trajectory:
419417
n_steps = len(trajectory)
420-
421-
# NOTE: convert stress units from eV/A³ to kBar (* -1 from standard output)
422-
# and to 3x3 matrix to comply with MP convention
423-
if n_steps:
424-
for idx in range(n_steps):
425-
if trajectory.frame_properties[idx].get("stress") is not None:
426-
trajectory.frame_properties[idx]["stress"] = (
427-
voigt_6_to_full_3x3_stress(
428-
[
429-
val * -10 / GPa
430-
for val in trajectory.frame_properties[idx]["stress"]
431-
]
432-
)
433-
)
434-
435-
input_mol_or_struct = trajectory[0]
418+
input_mol_or_struct = trajectory.to_pmg(frame_props=tuple(), indices=0)[0]
436419

437420
input_doc = InputDoc(
438421
mol_or_struct=input_mol_or_struct,
@@ -450,63 +433,43 @@ def from_ase_compatible_result(
450433
steps = 1
451434
n_steps = 1
452435

453-
if isinstance(input_mol_or_struct, Structure):
454-
traj_method = "from_structures"
455-
elif isinstance(input_mol_or_struct, Molecule):
456-
traj_method = "from_molecules"
457-
458-
trajectory = getattr(PmgTrajectory, traj_method)(
459-
[input_mol_or_struct],
460-
frame_properties=[trajectory.frame_properties[0]],
461-
constant_lattice=False,
462-
)
436+
if trajectory:
437+
trajectory = trajectory[-1]
463438
output_mol_or_struct = input_mol_or_struct
464439
else:
465440
output_mol_or_struct = result.final_mol_or_struct
466441

467-
if trajectory is None:
468-
final_energy = result.final_energy
469-
final_forces = None
470-
final_stress = None
471-
ionic_steps = None
442+
final_energy = result.final_energy
443+
final_forces = None
444+
final_stress = None
445+
ionic_steps = None
472446

473-
else:
474-
final_energy = trajectory.frame_properties[-1]["energy"]
475-
final_forces = trajectory.frame_properties[-1]["forces"]
476-
final_stress = trajectory.frame_properties[-1].get("stress")
447+
if trajectory:
448+
final_energy = trajectory.energy[-1]
449+
final_forces = trajectory.forces[-1]
450+
ionic_step_props = ["energy", "forces"]
451+
if trajectory.stress:
452+
final_stress = trajectory.stress[-1]
453+
ionic_step_props.append("stress")
454+
455+
if trajectory.magmoms:
456+
ionic_step_props.append("magmoms")
477457

478458
ionic_steps = []
479459
if ionic_step_data is not None and len(ionic_step_data) > 0:
480460
for idx in range(n_steps):
481461
_ionic_step_data = {
482462
key: (
483-
trajectory.frame_properties[idx].get(key)
463+
getattr(trajectory, key)[idx]
484464
if key in ionic_step_data
485465
else None
486466
)
487-
for key in ("energy", "forces", "stress")
467+
for key in ionic_step_props
488468
}
489469

490-
current_mol_or_struct = (
491-
trajectory[idx]
492-
if any(
493-
v in ionic_step_data
494-
for v in ("mol_or_struct", "structure", "molecule")
495-
)
496-
else None
497-
)
498-
499-
# include "magmoms" in `ionic_step` if the trajectory has "magmoms"
500-
if "magmoms" in trajectory.frame_properties[idx]:
501-
_ionic_step_data.update(
502-
{
503-
"magmoms": (
504-
trajectory.frame_properties[idx]["magmoms"]
505-
if "magmoms" in ionic_step_data
506-
else None
507-
)
508-
}
509-
)
470+
current_mol_or_struct = trajectory.to_pmg(
471+
frame_props=tuple(), indices=-1
472+
)[0]
510473

511474
ionic_step = IonicStep(
512475
mol_or_struct=current_mol_or_struct,

src/atomate2/ase/utils.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
from ase.mep.neb import NEB
2323
from ase.optimize import BFGS, FIRE, LBFGS, BFGSLineSearch, LBFGSLineSearch, MDMin
2424
from ase.optimize.sciopt import SciPyFminBFGS, SciPyFminCG
25+
from ase.stress import voigt_6_to_full_3x3_stress
26+
from ase.units import GPa
2527
from emmet.core.neb import NebMethod, NebResult
28+
from emmet.core.trajectory import AtomTrajectory
2629
from monty.serialization import dumpfn
2730
from pymatgen.core.structure import Molecule, Structure
2831
from pymatgen.core.trajectory import Trajectory as PmgTrajectory
@@ -83,7 +86,7 @@ def __init__(self, atoms: Atoms, store_md_outputs: bool = False) -> None:
8386
self.forces: list[np.ndarray] = []
8487

8588
self._calc_kwargs = {
86-
"stress": (
89+
"stresses": (
8790
"stress" in self.atoms.calc.implemented_properties and self._is_periodic
8891
),
8992
"magmoms": True,
@@ -113,7 +116,7 @@ def __call__(self) -> None:
113116
# When _store_md_outputs is True, ideal gas contribution to
114117
# stress is included.
115118
# Only store stress for periodic systems.
116-
if self._calc_kwargs["stress"]:
119+
if self._calc_kwargs["stresses"]:
117120
self.stresses.append(
118121
self.atoms.get_stress(include_ideal_gas=self._store_md_outputs)
119122
)
@@ -144,7 +147,7 @@ def compute_energy(self) -> float:
144147
def save(
145148
self,
146149
filename: str | PathLike | None,
147-
fmt: Literal["pmg", "ase", "xdatcar"] = "ase",
150+
fmt: Literal["pmg", "ase", "xdatcar", "parquet"] = "ase",
148151
) -> None:
149152
"""
150153
Save the trajectory file using monty.serialization.
@@ -162,6 +165,8 @@ def save(
162165
self.to_pymatgen_trajectory(filename=filename, file_format=fmt) # type: ignore[arg-type]
163166
elif fmt == "ase":
164167
self.to_ase_trajectory(filename=filename)
168+
elif fmt == "parquet":
169+
self.to_emmet_trajectory(filename=filename)
165170
else:
166171
raise ValueError(f"Unknown trajectory format {fmt}.")
167172

@@ -189,7 +194,7 @@ def to_ase_trajectory(
189194
"energy": self.energies[idx],
190195
"forces": self.forces[idx],
191196
}
192-
if self._calc_kwargs["stress"]:
197+
if self._calc_kwargs["stresses"]:
193198
kwargs["stress"] = self.stresses[idx]
194199
if self._calc_kwargs["magmoms"]:
195200
kwargs["magmom"] = self.magmoms[idx]
@@ -218,7 +223,7 @@ def to_pymatgen_trajectory(
218223
If "xdatcar", writes a VASP-format XDATCAR object to file
219224
"""
220225
frame_property_keys = ["energy", "forces"]
221-
for k in ("stress", "magmoms", "velocities", "temperature"):
226+
for k in ("stresses", "magmoms", "velocities", "temperature"):
222227
if self._calc_kwargs[k]:
223228
frame_property_keys += [k]
224229

@@ -276,12 +281,47 @@ def to_pymatgen_trajectory(
276281

277282
return pmg_traj
278283

284+
def to_emmet_trajectory(
285+
self, filename: str | PathLike | None = None
286+
) -> AtomTrajectory:
287+
"""Create an emmet.core.AtomTrajectory."""
288+
frame_props = {
289+
"cells": "lattice",
290+
"energies": "energy",
291+
"forces": "forces",
292+
"stresses": "stress",
293+
"magmoms": "magmoms",
294+
"velocities": "velocities",
295+
"temperatures": "temperature",
296+
}
297+
for k in ("stresses", "magmoms"):
298+
if not self._calc_kwargs[k]:
299+
frame_props.pop(k)
300+
301+
ionic_step_data = {v: getattr(self, k) for k, v in frame_props.items()}
302+
if self._calc_kwargs["stresses"]:
303+
# NOTE: convert stress from eV/A³ to kBar with flipped sign (val * -10 / GPa)
304+
# and expand the Voigt vector to a 3x3 matrix to comply with MP convention
305+
ionic_step_data["stress"] = [
306+
voigt_6_to_full_3x3_stress(val * -10 / GPa) for val in self.stresses
307+
]
308+
309+
traj = AtomTrajectory(
310+
elements=self.atoms.get_atomic_numbers(),
311+
cart_coords=self.atom_positions,
312+
num_ionic_steps=len(self.atom_positions),
313+
**ionic_step_data,
314+
)
315+
if filename:
316+
traj.to(file_name=filename)
317+
return traj
318+
279319
def as_dict(self) -> dict:
280320
"""Make JSONable dict representation of the Trajectory."""
281321
traj_dict = {
282322
"energy": self.energies,
283323
"forces": self.forces,
284-
"stress": self.stresses,
324+
"stresses": self.stresses,
285325
"atom_positions": self.atom_positions,
286326
"cells": self.cells,
287327
"atoms": self.atoms,
@@ -413,9 +453,9 @@ def relax(
413453
struct = self.ase_adaptor.get_structure(
414454
atoms, cls=Molecule if is_mol else Structure
415455
)
416-
traj = obs.to_pymatgen_trajectory(None)
456+
traj = obs.to_emmet_trajectory(filename=None)
417457
is_force_conv = all(
418-
np.linalg.norm(traj.frame_properties[-1]["forces"][idx]) < abs(fmax)
458+
np.linalg.norm(traj.forces[-1][idx]) < abs(fmax)
419459
for idx in range(len(struct))
420460
)
421461

@@ -434,9 +474,7 @@ def relax(
434474
trajectory=traj,
435475
converged=converged,
436476
is_force_converged=is_force_conv,
437-
energy_downhill=(
438-
traj.frame_properties[-1]["energy"] < traj.frame_properties[0]["energy"]
439-
),
477+
energy_downhill=traj.energy[-1] < traj.energy[0],
440478
dir_name=os.getcwd(),
441479
elapsed_time=t_f - t_i,
442480
)

src/atomate2/forcefields/md.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,13 @@ class ForceFieldMDMaker(ForceFieldMixin, AseMDMaker):
8888
traj_file : str | Path | None = None
8989
If a str or Path, the name of the file to save the MD trajectory to.
9090
If None, the trajectory is not written to disk
91-
traj_file_fmt : Literal["ase","pmg","xdatcar"]
91+
traj_file_fmt : Literal["ase","pmg","xdatcar","parquet"]
9292
The format of the trajectory file to write.
9393
If "ase", writes an ASE .Trajectory.
9494
If "pmg", writes a Pymatgen .Trajectory.
95-
If "xdatcar, writes a VASP-style XDATCAR
95+
If "xdatcar", writes a VASP-style XDATCAR.
96+
If "parquet", uses emmet.core's AtomTrajectory object to write the
97+
trajectory to a high-efficiency Parquet-format file.
9698
traj_interval : int
9799
The step interval for saving the trajectories.
98100
mb_velocity_seed : int or None

src/atomate2/forcefields/schemas.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ def from_ase_compatible_result(
144144
MLFF.MACE_MP_0B3: "mace-torch",
145145
MLFF.GAP: "quippy-ase",
146146
MLFF.Nequip: "nequip",
147+
MLFF.MATPES_PBE: "matgl",
148+
MLFF.MATPES_R2SCAN: "matgl",
147149
}
148150

149151
if pkg_name := {str(k): v for k, v in model_to_pkg_map.items()}.get(

tests/ase/test_md.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import pytest
1111
from jobflow import run_locally
12+
from pymatgen.io.vasp.outputs import Xdatcar
1213

1314
from atomate2.ase.md import GFNxTBMDMaker, LennardJonesMDMaker
1415
from atomate2.ase.schemas import AseStructureTaskDoc
@@ -137,7 +138,9 @@ def test_ase_npt_maker(calculator_name, lj_fcc_ne_pars, fcc_ne_structure, tmp_di
137138
reference_energies_per_atom[calculator_name]
138139
)
139140

140-
# TODO: improve XDATCAR parsing test when class is fixed in pmg
141141
assert os.path.isfile("XDATCAR")
142+
xdatcar = Xdatcar("XDATCAR")
143+
assert len(xdatcar.structures) == len(output.objects["trajectory"])
144+
assert len(xdatcar.structures) == len(output.output.ionic_steps)
142145

143146
assert len(output.objects["trajectory"]) == n_steps + 1

0 commit comments

Comments
 (0)