55from functools import cached_property
66import hashlib
77from importlib import import_module
8+ import json
89from monty .serialization import loadfn
910import os
10- import numpy as np
1111from pathlib import Path
12- from pydantic import BaseModel , Field , model_validator , model_serializer , PrivateAttr
13- from typing import TYPE_CHECKING , Any , Optional
12+ from pydantic import (
13+ BaseModel ,
14+ Field ,
15+ model_validator ,
16+ model_serializer ,
17+ PrivateAttr ,
18+ PlainSerializer ,
19+ BeforeValidator ,
20+ )
21+ from typing import TYPE_CHECKING , Any , Annotated , TypeAlias , TypeVar
1422
1523from pymatgen .core import Structure
1624from pymatgen .io .vasp .inputs import POTCAR_STATS_PATH , Incar , Kpoints , Poscar , Potcar , PmgVaspPspDirError
2230
2331if TYPE_CHECKING :
2432 from typing_extensions import Self
33+ from monty .json import MSONable
2534
2635SETTINGS = IOValidationSettings ()
2736
2837
38+ def _msonable_from_str (obj : Any , cls : type [MSONable ]) -> MSONable :
39+ if isinstance (obj , str ):
40+ obj = json .loads (obj )
41+ if isinstance (obj , dict ):
42+ return cls .from_dict (obj )
43+ return obj
44+
45+
46+ IncarTypeVar = TypeVar ("IncarTypeVar" , Incar , str )
47+ IncarType : TypeAlias = Annotated [
48+ IncarTypeVar ,
49+ BeforeValidator (lambda x : _msonable_from_str (x , Incar )),
50+ PlainSerializer (lambda x : json .dumps (x .as_dict ()), return_type = str ),
51+ ]
52+
53+ KpointsTypeVar = TypeVar ("KpointsTypeVar" , Kpoints , str )
54+ KpointsType : TypeAlias = Annotated [
55+ KpointsTypeVar ,
56+ BeforeValidator (lambda x : _msonable_from_str (x , Kpoints )),
57+ PlainSerializer (lambda x : json .dumps (x .as_dict ()), return_type = str ),
58+ ]
59+
60+ StructureTypeVar = TypeVar ("StructureTypeVar" , Structure , str )
61+ StructureType : TypeAlias = Annotated [
62+ StructureTypeVar ,
63+ BeforeValidator (lambda x : _msonable_from_str (x , Structure )),
64+ PlainSerializer (lambda x : json .dumps (x .as_dict ()), return_type = str ),
65+ ]
66+
67+
2968class ValidationError (Exception ):
3069 """Define custom exception during validation."""
3170
@@ -62,8 +101,8 @@ class PotcarSummaryStatistics(BaseModel):
62101class PotcarSummaryStats (BaseModel ):
63102 """Schematize `PotcarSingle._summary_stats`."""
64103
65- keywords : Optional [ PotcarSummaryKeywords ] = None
66- stats : Optional [ PotcarSummaryStatistics ] = None
104+ keywords : PotcarSummaryKeywords | None = None
105+ stats : PotcarSummaryStatistics | None = None
67106 titel : str
68107 lexch : str
69108
@@ -80,23 +119,41 @@ def from_file(cls, potcar_path: os.PathLike | Potcar) -> list[Self]:
80119class LightOutcar (BaseModel ):
81120 """Schematic of pymatgen's Outcar."""
82121
83- drift : Optional [ list [list [float ]]] = Field (None , description = "The drift forces." )
84- magnetization : Optional [ list [dict [str , float ]]] = Field (
122+ drift : list [list [float ]] | None = Field (None , description = "The drift forces." )
123+ magnetization : list [dict [str , float ]] | None = Field (
85124 None , description = "The on-site magnetic moments, possibly with orbital resolution."
86125 )
87126
88127
128+ class LightElectronicStep (BaseModel ):
129+ """Lightweight representation of electronic step data from VASP."""
130+
131+ e_0_energy : float | None = None
132+ e_fr_energy : float | None = None
133+ e_wo_entrp : float | None = None
134+ eentropy : float | None = None
135+
136+
137+ class LightIonicStep (BaseModel ):
138+ """Lightweight representation of ionic step data from VASP."""
139+
140+ e_0_energy : float | None = None
141+ e_fr_energy : float | None = None
142+ forces : list [list [float ]] | None = None
143+ electronic_steps : list [LightElectronicStep ] | None = None
144+
145+
89146class LightVasprun (BaseModel ):
90147 """Lightweight version of pymatgen Vasprun."""
91148
92149 vasp_version : str = Field (description = "The dot-separated version of VASP used." )
93- ionic_steps : list [dict [str , Any ]] = Field (description = "The ionic steps in the calculation." )
94150 final_energy : float = Field (description = "The final total energy in eV." )
95- final_structure : Structure = Field (description = "The final structure." )
96- kpoints : Kpoints = Field (description = "The actual k-points used in the calculation." )
97- parameters : dict [ str , Any ] = Field (description = "The default-padded input parameters interpreted by VASP." )
151+ final_structure : StructureType = Field (description = "The final structure." )
152+ kpoints : KpointsType = Field (description = "The actual k-points used in the calculation." )
153+ parameters : IncarType = Field (description = "The default-padded input parameters interpreted by VASP." )
98154 bandgap : float = Field (description = "The bandgap - note that this field is derived from the Vasprun object." )
99- potcar_symbols : Optional [list [str ]] = Field (
155+ ionic_steps : list [LightIonicStep ] = Field ([], description = "The ionic steps in the calculation." )
156+ potcar_symbols : list [str ] | None = Field (
100157 None ,
101158 description = "Optional: if a POTCAR is unavailable, this is used to determine the functional used in the calculation." ,
102159 )
@@ -119,45 +176,18 @@ def from_vasprun(cls, vasprun: Vasprun) -> Self:
119176 bandgap = vasprun .get_band_structure (efermi = "smart" ).get_band_gap ()["energy" ],
120177 )
121178
122- @model_serializer
123- def deserialize_objects (self ) -> dict [str , Any ]:
124- """Ensure all pymatgen objects are deserialized."""
125- model_dumped = {k : getattr (self , k ) for k in self .__class__ .model_fields }
126- for k in ("final_structure" , "kpoints" ):
127- model_dumped [k ] = model_dumped [k ].as_dict ()
128- for iion , istep in enumerate (model_dumped ["ionic_steps" ]):
129- if (istruct := istep .get ("structure" )) and isinstance (istruct , Structure ):
130- model_dumped ["ionic_steps" ][iion ]["structure" ] = istruct .as_dict ()
131- for k in ("forces" , "stress" ):
132- if (val := istep .get (k )) is not None and isinstance (val , np .ndarray ):
133- model_dumped ["ionic_steps" ][iion ][k ] = val .tolist ()
134- return model_dumped
135-
136179
137180class VaspInputSafe (BaseModel ):
138181 """Stricter VaspInputSet with no POTCAR info."""
139182
140- incar : Incar = Field (description = "The INCAR used in the calculation." )
141- structure : Structure = Field (description = "The structure associated with the calculation." )
142- kpoints : Optional [Kpoints ] = Field (None , description = "The optional KPOINTS or IBZKPT file used in the calculation." )
143- potcar : Optional [list [PotcarSummaryStats ]] = Field (None , description = "The optional POTCAR used in the calculation." )
144- potcar_functional : Optional [str ] = Field (None , description = "The pymatgen-labelled POTCAR library release." )
145- _pmg_vis : Optional [VaspInputSet ] = PrivateAttr (None )
146-
147- @model_serializer
148- def deserialize_objects (self ) -> dict [str , Any ]:
149- """Ensure all pymatgen objects are deserialized."""
150- model_dumped : dict [str , Any ] = {}
151- if self .potcar :
152- model_dumped ["potcar" ] = [p .model_dump () for p in self .potcar ]
153- for k in (
154- "incar" ,
155- "structure" ,
156- "kpoints" ,
157- ):
158- if pmg_obj := getattr (self , k ):
159- model_dumped [k ] = pmg_obj .as_dict ()
160- return model_dumped
183+ incar : IncarType = Field (description = "The INCAR used in the calculation." )
184+ structure : StructureType = Field (description = "The structure associated with the calculation." )
185+ kpoints : KpointsType | None = Field (
186+ None , description = "The optional KPOINTS or IBZKPT file used in the calculation."
187+ )
188+ potcar : list [PotcarSummaryStats ] | None = Field (None , description = "The optional POTCAR used in the calculation." )
189+ potcar_functional : str | None = Field (None , description = "The pymatgen-labelled POTCAR library release." )
190+ _pmg_vis : VaspInputSet | None = PrivateAttr (None )
161191
162192 @classmethod
163193 def from_vasp_input_set (cls , vis : VaspInputSet ) -> Self :
0 commit comments