Skip to content

Commit 294d5ca

Browse files
make structure/molecule aliases of mol_or_struct, allow for removal from ionic step properties
1 parent 1822f9d commit 294d5ca

File tree

1 file changed

+33
-20
lines changed

1 file changed

+33
-20
lines changed

src/atomate2/ase/schemas.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from emmet.core.structure import MoleculeMetadata, StructureMetadata
2121
from emmet.core.trajectory import AtomTrajectory
2222
from emmet.core.types.enums import StoreTrajectoryOption, TaskState, ValueEnum
23-
from pydantic import BaseModel, Field
23+
from pydantic import AliasChoices, BaseModel, Field
2424
from pymatgen.core import Molecule, Structure
2525
from pymatgen.core.trajectory import Trajectory as PmgTrajectory
2626
from pymatgen.entries.computed_entries import ComputedEntry
@@ -116,17 +116,24 @@ class AseBaseModel(BaseModel):
116116
"""Base document class for ASE input and output."""
117117

118118
mol_or_struct: Structure | Molecule | None = Field(
119-
None, description="The molecule or structure at this step."
119+
None,
120+
description="The molecule or structure at this step.",
121+
validation_alias=AliasChoices("mol_or_struct", "structure", "molecule"),
120122
)
121-
structure: Structure | None = Field(None, description="The structure at this step.")
122-
molecule: Molecule | None = Field(None, description="The molecule at this step.")
123123

124-
def model_post_init(self, context: Any, /) -> None:
125-
"""Establish alias to structure and molecule fields."""
126-
if self.structure is None and isinstance(self.mol_or_struct, Structure):
127-
self.structure = self.mol_or_struct
128-
elif self.molecule is None and isinstance(self.mol_or_struct, Molecule):
129-
self.molecule = self.mol_or_struct
124+
@property
125+
def structure(self) -> Structure | None:
126+
"""Retrieve the structure associated with this document, if applicable."""
127+
if isinstance(self.mol_or_struct, Structure):
128+
return self.mol_or_struct
129+
return None
130+
131+
@property
132+
def molecule(self) -> Molecule | None:
133+
"""Retrieve the molecule associated with this document, if applicable."""
134+
if isinstance(self.mol_or_struct, Molecule):
135+
return self.mol_or_struct
136+
return None
130137

131138

132139
class IonicStep(AseBaseModel):
@@ -476,8 +483,18 @@ def from_ase_compatible_result(
476483
final_stress = None
477484
ionic_steps = None
478485

486+
if "mol_or_struct" not in (
487+
user_ionic_step_data := set(ionic_step_data or tuple())
488+
):
489+
for ms_alias in ("molecule", "structure"):
490+
if ms_alias in user_ionic_step_data:
491+
user_ionic_step_data.add("mol_or_struct")
492+
479493
if trajectory:
480494
ionic_step_props = {"energy", "forces"}
495+
if save_atoms := "mol_or_struct" in user_ionic_step_data:
496+
user_ionic_step_data.remove("mol_or_struct")
497+
481498
if isinstance(trajectory, AtomTrajectory):
482499
final_energy = trajectory.energy[-1]
483500
final_forces = trajectory.forces[-1]
@@ -501,21 +518,17 @@ def from_ase_compatible_result(
501518
ionic_step_props.add("magmoms")
502519

503520
ionic_steps = []
504-
if (
505-
len(
506-
use_ionic_step_props := ionic_step_props.intersection(
507-
ionic_step_data or set()
508-
)
509-
)
510-
> 0
511-
):
521+
use_ionic_step_props = ionic_step_props.intersection(user_ionic_step_data)
522+
if len(use_ionic_step_props) > 0:
512523
if isinstance(trajectory, AtomTrajectory):
513524
ionic_steps = [
514525
IonicStep(
515526
mol_or_struct=trajectory.to_pmg(
516527
frame_props=tuple(),
517528
indices=idx,
518-
)[0],
529+
)[0]
530+
if save_atoms
531+
else None,
519532
**{
520533
key: getattr(trajectory, key)[idx]
521534
for key in use_ionic_step_props
@@ -527,7 +540,7 @@ def from_ase_compatible_result(
527540
else:
528541
ionic_steps = [
529542
IonicStep(
530-
mol_or_struct=atoms,
543+
mol_or_struct=atoms if save_atoms else None,
531544
**{
532545
key: convert_stress_from_voigt_to_symm(
533546
trajectory.frame_properties[idx].get(key)

0 commit comments

Comments
 (0)