11"""Check common issues with VASP calculations."""
22
33from __future__ import annotations
4- from dataclasses import dataclass , field
4+ from pydantic import Field , field_validator
55import numpy as np
6-
7- from typing import TYPE_CHECKING
6+ from functools import cached_property
7+ from typing import Any
88
99from emmet .core .vasp .calc_types .enums import TaskType
1010from pymatgen .core import Structure
1111
1212from pymatgen .io .validation .common import BaseValidator
1313
14- if TYPE_CHECKING :
15- from emmet .core .tasks import TaskDoc
16- from emmet .core .vasp .calc_types .enums import RunType
17- from emmet .core .vasp .task_valid import TaskDocument
18- from pymatgen .io .vasp .inputs import Incar
19- from typing import Sequence
20- from numpy .typing import ArrayLike
14+ from emmet .core .tasks import TaskDoc
15+ from emmet .core .vasp .calc_types .enums import RunType
16+ from emmet .core .vasp .task_valid import TaskDocument
17+ from pymatgen .io .vasp .inputs import Incar
18+ from collections .abc import Sequence
19+ from numpy .typing import ArrayLike
2120
2221
23- @dataclass
2422class CheckCommonErrors (BaseValidator ):
2523 """
2624 Check for common calculation errors.
@@ -59,22 +57,33 @@ class CheckCommonErrors(BaseValidator):
5957
6058 reasons : list [str ]
6159 warnings : list [str ]
62- task_doc : TaskDoc | TaskDocument = None
60+ task_doc : dict = None
6361 parameters : dict = None
6462 structure : Structure = None
6563 run_type : RunType = None
6664 name : str = "Check common errors"
6765 fast : bool = False
6866 defaults : dict | None = None
6967 # TODO: make this also work for elements Gd and Eu, which have magmoms >5 in at least one of their pure structures
70- valid_max_magmoms : dict [str , float ] = field (default_factory = lambda : {"Gd" : 10.0 , "Eu" : 10.0 })
71- exclude_elements : set [str ] = field (default_factory = lambda : {"Am" , "Po" })
68+ valid_max_magmoms : dict [str , float ] = Field (default_factory = lambda : {"Gd" : 10.0 , "Eu" : 10.0 })
69+ exclude_elements : set [str ] = Field (default_factory = lambda : {"Am" , "Po" })
7270 valid_max_allowed_scf_gradient : float | None = None
7371 num_ionic_steps_to_avg_drift_over : int | None = None
7472
75- def __post_init__ (self ):
76- self .incar = self .task_doc ["calcs_reversed" ][0 ]["input" ]["incar" ]
77- self .ionic_steps = self .task_doc ["calcs_reversed" ][0 ]["output" ]["ionic_steps" ]
73+ @field_validator ("task_doc" ,mode = "before" )
74+ @classmethod
75+ def deserialize_task_doc (cls ,val : Any ) -> dict :
76+ if hasattr (val ,"model_dump" ):
77+ return val .model_dump ()
78+ return val
79+
80+ @cached_property
81+ def incar (self ) -> dict :
82+ return self .task_doc ["calcs_reversed" ][0 ]["input" ]["incar" ]
83+
84+ @cached_property
85+ def ionic_steps (self ) -> list :
86+ return self .task_doc ["calcs_reversed" ][0 ]["output" ]["ionic_steps" ]
7887
7988 def _check_run_type (self ) -> None :
8089 if f"{ self .run_type } " .upper () not in {"GGA" , "GGA+U" , "PBE" , "PBE+U" , "R2SCAN" }:
@@ -214,8 +223,6 @@ def _check_unused_elements(self) -> None:
214223 "which are not currently being accepted."
215224 )
216225
217-
218- @dataclass
219226class CheckVaspVersion (BaseValidator ):
220227 """
221228 Check for common errors related to the version of VASP used.
@@ -280,27 +287,33 @@ def _check_vasp_version(self) -> None:
280287 "but we only allow versions 5.4.4 and >=6.0.0 (as of July 2023)."
281288 )
282289
283-
284- @dataclass
285290class CheckStructureProperties (BaseValidator ):
286291 """Check structure for options that are not suitable for thermodynamic calculations."""
287292
288293 reasons : list [str ]
289294 warnings : list [str ]
290- structures : list [dict | Structure | None ] = None
295+ structures : list [Structure ]
291296 task_type : TaskType = None
292297 name : str = "VASP POSCAR properties validator"
293298 site_properties_to_check : tuple [str , ...] = ("selective_dynamics" , "velocities" )
294299
295- def __post_init__ (self ) -> None :
300+ @field_validator ("structures" ,mode = "before" )
301+ @classmethod
302+ def serialize_structures (cls , val : list [Structure | dict | None ]) -> list [Structure ]:
296303 """Extract required structure site properties."""
297-
298- for idx , struct in enumerate (self .structures ):
299- if isinstance (struct , dict ):
300- self .structures [idx ] = Structure .from_dict (struct )
301-
302- self ._site_props = {
303- k : [struct .site_properties .get (k ) for struct in self .structures if struct ] # type: ignore[union-attr]
304+
305+ out_val = []
306+ for struct in val :
307+ if struct :
308+ if isinstance (struct , dict ):
309+ struct = Structure .from_dict (struct )
310+ out_val .append (struct )
311+ return out_val
312+
313+ @cached_property
314+ def site_properties (self ) -> dict [str , Any ]:
315+ return {
316+ k : [struct .site_properties .get (k ) for struct in self .structures ]
304317 for k in self .site_properties_to_check
305318 }
306319
@@ -314,7 +327,7 @@ def _has_frozen_degrees_of_freedom(selective_dynamics_array: ArrayLike[bool] | N
314327 def _check_selective_dynamics (self ) -> None :
315328 """Check structure for inappropriate site properties."""
316329
317- if (selec_dyn := self ._site_props .get ("selective_dynamics" )) is not None and self .task_type in {
330+ if (selec_dyn := self .site_properties .get ("selective_dynamics" )) is not None and self .task_type in {
318331 TaskType .Structure_Optimization ,
319332 TaskType .Deformation ,
320333 }:
@@ -334,7 +347,7 @@ def _has_nonzero_velocities(velocities: ArrayLike | None, tol: float = 1.0e-8) -
334347 def _check_velocities (self ) -> None :
335348 """Check structure for non-zero velocities."""
336349
337- if (velos := self ._site_props .get ("velocities" )) is not None and self .task_type != TaskType .Molecular_Dynamics :
350+ if (velos := self .site_properties .get ("velocities" )) is not None and self .task_type != TaskType .Molecular_Dynamics :
338351 if any (self ._has_nonzero_velocities (velo ) for velo in velos ):
339352 self .warnings .append (
340353 "At least one of the structures had non-zero velocities. "
0 commit comments