Skip to content

Commit ca7f3d7

Browse files
begin refactor to isolate pmg / emmet dependences
1 parent c14f403 commit ca7f3d7

File tree

8 files changed

+1079
-272
lines changed

8 files changed

+1079
-272
lines changed

config.yaml.gz

32 Bytes
Binary file not shown.

pymatgen/io/validation/check_common_errors.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,24 @@
11
"""Check common issues with VASP calculations."""
22

33
from __future__ import annotations
4-
from dataclasses import dataclass, field
4+
from pydantic import Field, field_validator
55
import numpy as np
6-
7-
from typing import TYPE_CHECKING
6+
from functools import cached_property
7+
from typing import Any
88

99
from emmet.core.vasp.calc_types.enums import TaskType
1010
from pymatgen.core import Structure
1111

1212
from pymatgen.io.validation.common import BaseValidator
1313

14-
if TYPE_CHECKING:
15-
from emmet.core.tasks import TaskDoc
16-
from emmet.core.vasp.calc_types.enums import RunType
17-
from emmet.core.vasp.task_valid import TaskDocument
18-
from pymatgen.io.vasp.inputs import Incar
19-
from typing import Sequence
20-
from numpy.typing import ArrayLike
14+
from emmet.core.tasks import TaskDoc
15+
from emmet.core.vasp.calc_types.enums import RunType
16+
from emmet.core.vasp.task_valid import TaskDocument
17+
from pymatgen.io.vasp.inputs import Incar
18+
from collections.abc import Sequence
19+
from numpy.typing import ArrayLike
2120

2221

23-
@dataclass
2422
class CheckCommonErrors(BaseValidator):
2523
"""
2624
Check for common calculation errors.
@@ -59,22 +57,33 @@ class CheckCommonErrors(BaseValidator):
5957

6058
reasons: list[str]
6159
warnings: list[str]
62-
task_doc: TaskDoc | TaskDocument = None
60+
task_doc: dict = None
6361
parameters: dict = None
6462
structure: Structure = None
6563
run_type: RunType = None
6664
name: str = "Check common errors"
6765
fast: bool = False
6866
defaults: dict | None = None
6967
# TODO: make this also work for elements Gd and Eu, which have magmoms >5 in at least one of their pure structures
70-
valid_max_magmoms: dict[str, float] = field(default_factory=lambda: {"Gd": 10.0, "Eu": 10.0})
71-
exclude_elements: set[str] = field(default_factory=lambda: {"Am", "Po"})
68+
valid_max_magmoms: dict[str, float] = Field(default_factory=lambda: {"Gd": 10.0, "Eu": 10.0})
69+
exclude_elements: set[str] = Field(default_factory=lambda: {"Am", "Po"})
7270
valid_max_allowed_scf_gradient: float | None = None
7371
num_ionic_steps_to_avg_drift_over: int | None = None
7472

75-
def __post_init__(self):
76-
self.incar = self.task_doc["calcs_reversed"][0]["input"]["incar"]
77-
self.ionic_steps = self.task_doc["calcs_reversed"][0]["output"]["ionic_steps"]
73+
@field_validator("task_doc",mode="before")
74+
@classmethod
75+
def deserialize_task_doc(cls,val : Any) -> dict:
76+
if hasattr(val,"model_dump"):
77+
return val.model_dump()
78+
return val
79+
80+
@cached_property
81+
def incar(self) -> dict:
82+
return self.task_doc["calcs_reversed"][0]["input"]["incar"]
83+
84+
@cached_property
85+
def ionic_steps(self) -> list:
86+
return self.task_doc["calcs_reversed"][0]["output"]["ionic_steps"]
7887

7988
def _check_run_type(self) -> None:
8089
if f"{self.run_type}".upper() not in {"GGA", "GGA+U", "PBE", "PBE+U", "R2SCAN"}:
@@ -214,8 +223,6 @@ def _check_unused_elements(self) -> None:
214223
"which are not currently being accepted."
215224
)
216225

217-
218-
@dataclass
219226
class CheckVaspVersion(BaseValidator):
220227
"""
221228
Check for common errors related to the version of VASP used.
@@ -280,27 +287,33 @@ def _check_vasp_version(self) -> None:
280287
"but we only allow versions 5.4.4 and >=6.0.0 (as of July 2023)."
281288
)
282289

283-
284-
@dataclass
285290
class CheckStructureProperties(BaseValidator):
286291
"""Check structure for options that are not suitable for thermodynamic calculations."""
287292

288293
reasons: list[str]
289294
warnings: list[str]
290-
structures: list[dict | Structure | None] = None
295+
structures: list[Structure]
291296
task_type: TaskType = None
292297
name: str = "VASP POSCAR properties validator"
293298
site_properties_to_check: tuple[str, ...] = ("selective_dynamics", "velocities")
294299

295-
def __post_init__(self) -> None:
300+
@field_validator("structures",mode="before")
301+
@classmethod
302+
def serialize_structures(cls, val : list[Structure | dict | None]) -> list[Structure]:
296303
"""Extract required structure site properties."""
297-
298-
for idx, struct in enumerate(self.structures):
299-
if isinstance(struct, dict):
300-
self.structures[idx] = Structure.from_dict(struct)
301-
302-
self._site_props = {
303-
k: [struct.site_properties.get(k) for struct in self.structures if struct] # type: ignore[union-attr]
304+
305+
out_val = []
306+
for struct in val:
307+
if struct:
308+
if isinstance(struct, dict):
309+
struct = Structure.from_dict(struct)
310+
out_val.append(struct)
311+
return out_val
312+
313+
@cached_property
314+
def site_properties(self) -> dict[str, Any]:
315+
return {
316+
k: [struct.site_properties.get(k) for struct in self.structures]
304317
for k in self.site_properties_to_check
305318
}
306319

@@ -314,7 +327,7 @@ def _has_frozen_degrees_of_freedom(selective_dynamics_array: ArrayLike[bool] | N
314327
def _check_selective_dynamics(self) -> None:
315328
"""Check structure for inappropriate site properties."""
316329

317-
if (selec_dyn := self._site_props.get("selective_dynamics")) is not None and self.task_type in {
330+
if (selec_dyn := self.site_properties.get("selective_dynamics")) is not None and self.task_type in {
318331
TaskType.Structure_Optimization,
319332
TaskType.Deformation,
320333
}:
@@ -334,7 +347,7 @@ def _has_nonzero_velocities(velocities: ArrayLike | None, tol: float = 1.0e-8) -
334347
def _check_velocities(self) -> None:
335348
"""Check structure for non-zero velocities."""
336349

337-
if (velos := self._site_props.get("velocities")) is not None and self.task_type != TaskType.Molecular_Dynamics:
350+
if (velos := self.site_properties.get("velocities")) is not None and self.task_type != TaskType.Molecular_Dynamics:
338351
if any(self._has_nonzero_velocities(velo) for velo in velos):
339352
self.warnings.append(
340353
"At least one of the structures had non-zero velocities. "

pymatgen/io/validation/check_incar.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,21 @@
11
"""Validate VASP INCAR files."""
22

33
from __future__ import annotations
4+
from collections.abc import Sequence
45
import copy
5-
from dataclasses import dataclass
66
import numpy as np
77
from emmet.core.vasp.calc_types.enums import TaskType
88

9-
from pymatgen.io.validation.common import BaseValidator, BasicValidator
10-
from pymatgen.io.validation.vasp_defaults import InputCategory, VaspParam
9+
from pymatgen.core import Structure
10+
from pymatgen.io.vasp.sets import VaspInputSet
1111

12-
from typing import TYPE_CHECKING
12+
from pymatgen.io.validation.common import BaseValidator
13+
from pymatgen.io.validation.vasp_defaults import InputCategory, VaspParam
1314

14-
if TYPE_CHECKING:
15-
from typing import Any, Sequence
16-
from pymatgen.core import Structure
17-
from pymatgen.io.vasp.sets import VaspInputSet
15+
from typing import Any
1816

1917
# TODO: fix ISIF getting overwritten by MP input set.
2018

21-
22-
@dataclass
2319
class CheckIncar(BaseValidator):
2420
"""
2521
Check calculation parameters related to INCAR input tags.
@@ -107,25 +103,16 @@ def check(self) -> None:
107103
working_params.update_parameters_and_defaults()
108104

109105
# Validate each parameter in the set of working parameters
110-
simple_validator = BasicValidator()
111-
for key in working_params.defaults:
106+
for key, vasp_param in working_params.defaults.items():
112107
if self.fast and len(self.reasons) > 0:
113108
# fast check: stop checking whenever a single check fails
114109
break
115110

116-
simple_validator.check_parameter(
117-
reasons=self.reasons,
118-
warnings=self.warnings,
119-
input_tag=working_params.defaults[key]["alias"],
120-
current_values=working_params.parameters[key],
121-
reference_values=working_params.valid_values[key],
122-
operations=working_params.defaults[key]["operation"],
123-
tolerance=working_params.defaults[key]["tolerance"],
124-
append_comments=working_params.defaults[key]["comment"],
125-
severity=working_params.defaults[key]["severity"],
111+
vasp_param.check(
112+
working_params.parameters[key],
113+
working_params.valid_values[key]
126114
)
127-
128-
115+
129116
class UpdateParameterValues:
130117
"""
131118
Update a set of parameters according to supplied rules and defaults.

pymatgen/io/validation/check_kpoints_kspacing.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,14 @@
11
"""Validate VASP KPOINTS files or the KSPACING/KGAMMA INCAR settings."""
22

33
from __future__ import annotations
4-
from dataclasses import dataclass
54
import numpy as np
65
from pymatgen.io.vasp import Kpoints
76

8-
from pymatgen.io.validation.common import BaseValidator
9-
10-
from typing import TYPE_CHECKING
11-
12-
if TYPE_CHECKING:
13-
from pymatgen.core import Structure
14-
from pymatgen.io.vasp.sets import VaspInputSet
7+
from pymatgen.core import Structure
8+
from pymatgen.io.vasp.sets import VaspInputSet
159

10+
from pymatgen.io.validation.common import BaseValidator
1611

17-
@dataclass
1812
class CheckKpointsKspacing(BaseValidator):
1913
"""
2014
Check that k-point density is sufficiently high and is compatible with lattice symmetry.
@@ -54,7 +48,7 @@ class CheckKpointsKspacing(BaseValidator):
5448
warnings: list[str]
5549
name: str = "Check k-point density"
5650
valid_input_set: VaspInputSet = None
57-
kpoints: Kpoints | dict = None
51+
kpoints: Kpoints = None
5852
structure: Structure = None
5953
defaults: dict | None = None
6054
kpts_tolerance: float | None = None
@@ -85,13 +79,13 @@ def _get_valid_num_kpts(self) -> int:
8579

8680
def _check_user_shifted_mesh(self) -> None:
8781
# Check for user shifts
88-
if (not self.allow_kpoint_shifts) and any(shift_val != 0 for shift_val in self.kpoints["usershift"]):
82+
if (not self.allow_kpoint_shifts) and any(shift_val != 0 for shift_val in self.kpoints.kpts_shift):
8983
self.reasons.append("INPUT SETTINGS --> KPOINTS: shifting the kpoint mesh is not currently allowed.")
9084

9185
def _check_explicit_mesh_permitted(self) -> None:
9286
# Check for explicit kpoint meshes
9387

94-
if (not self.allow_explicit_kpoint_mesh) and len(self.kpoints["kpoints"]) > 1:
88+
if (not self.allow_explicit_kpoint_mesh) and len(self.kpoints.kpts) > 1:
9589
self.reasons.append(
9690
"INPUT SETTINGS --> KPOINTS: explicitly defining "
9791
"the k-point mesh is not currently allowed. "
@@ -106,13 +100,10 @@ def _check_kpoint_density(self) -> None:
106100
# Check number of kpoints used
107101
valid_num_kpts = self._get_valid_num_kpts()
108102

109-
if isinstance(self.kpoints, Kpoints):
110-
self.kpoints = self.kpoints.as_dict()
111-
112103
cur_num_kpts = max(
113-
self.kpoints.get("nkpoints", 0),
114-
np.prod(self.kpoints.get("kpoints")),
115-
len(self.kpoints.get("kpoints")),
104+
self.kpoints.num_kpts,
105+
np.prod(self.kpoints.kpts),
106+
len(self.kpoints.kpts),
116107
)
117108
if cur_num_kpts < valid_num_kpts:
118109
self.reasons.append(
@@ -123,17 +114,17 @@ def _check_kpoint_density(self) -> None:
123114
def _check_kpoint_mesh_symmetry(self) -> None:
124115
# check for valid kpoint mesh (which depends on symmetry of the structure)
125116

126-
cur_kpoint_style = self.kpoints.get("generation_style").lower()
117+
cur_kpoint_style = self.kpoints.style.name.lower()
127118
is_hexagonal = self.structure.lattice.is_hexagonal()
128119
is_face_centered = self.structure.get_space_group_info()[0][0] == "F"
129120
monkhorst_mesh_is_invalid = is_hexagonal or is_face_centered
130121
if (
131122
cur_kpoint_style == "monkhorst"
132123
and monkhorst_mesh_is_invalid
133-
and any(x % 2 == 0 for x in self.kpoints.get("kpoints")[0])
124+
and any(x % 2 == 0 for x in self.kpoints.kpts[0])
134125
):
135126
# only allow Monkhorst with all odd number of subdivisions per axis.
136-
kx, ky, kz = self.kpoints.get("kpoints")[0]
127+
kx, ky, kz = self.kpoints.kpts[0]
137128
self.reasons.append(
138129
f"INPUT SETTINGS --> KPOINTS or KGAMMA: ({kx}x{ky}x{kz}) "
139130
"Monkhorst-Pack kpoint mesh was used."

pymatgen/io/validation/check_potcar.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
11
"""Check POTCAR against known POTCARs in pymatgen, without setting up psp_resources."""
22

33
from __future__ import annotations
4-
from dataclasses import dataclass, field
4+
from pydantic import Field
55
from importlib.resources import files as import_resource_files
66
from monty.serialization import loadfn
77
import numpy as np
88

9-
from pymatgen.io.validation.common import BaseValidator
10-
11-
from typing import TYPE_CHECKING
9+
from pymatgen.core import Structure
10+
from pymatgen.io.vasp.sets import VaspInputSet
1211

13-
if TYPE_CHECKING:
14-
from pymatgen.core import Structure
15-
from pymatgen.io.vasp.sets import VaspInputSet
12+
from pymatgen.io.validation.common import BaseValidator
1613

1714
_potcar_summary_stats = loadfn(import_resource_files("pymatgen.io.vasp") / "potcar-summary-stats.json.bz2")
1815

19-
20-
@dataclass
2116
class CheckPotcar(BaseValidator):
2217
"""
2318
Check POTCAR against library of known valid POTCARs.
@@ -52,9 +47,9 @@ class CheckPotcar(BaseValidator):
5247
warnings: list[str]
5348
valid_input_set: VaspInputSet = None
5449
structure: Structure = None
55-
potcars: dict = None
50+
potcars: list[dict] = None
5651
name: str = "Check POTCARs"
57-
potcar_summary_stats: dict = field(default_factory=lambda: _potcar_summary_stats)
52+
potcar_summary_stats: dict = Field(default_factory=lambda: _potcar_summary_stats)
5853
data_match_tol: float = 1.0e-6
5954
fast: bool = False
6055

pymatgen/io/validation/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Common class constructor for validation checks."""
22

33
from __future__ import annotations
4-
from dataclasses import dataclass
54
from math import isclose
65
from typing import TYPE_CHECKING, Literal
76

7+
from pydantic import BaseModel
8+
89
if TYPE_CHECKING:
910
from typing import Any
1011

@@ -214,8 +215,7 @@ def check_parameter(
214215
)
215216

216217

217-
@dataclass
218-
class BaseValidator:
218+
class BaseValidator(BaseModel):
219219
"""
220220
Template for validation classes.
221221
@@ -261,5 +261,5 @@ def check(self) -> None:
261261
if self.fast and len(self.reasons) > 0:
262262
# fast check: stop checking whenever a single check fails
263263
break
264-
264+
265265
getattr(self, attr)()

0 commit comments

Comments
 (0)