Skip to content

Commit e8c69da

Browse files
expand VaspParam schema, use in INCAR updating
1 parent 78b5800 commit e8c69da

File tree

6 files changed

+183
-147
lines changed

6 files changed

+183
-147
lines changed

pymatgen/io/validation/check_incar.py

Lines changed: 36 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from emmet.core.vasp.calc_types.enums import TaskType
88

99
from pymatgen.io.validation.common import BaseValidator, BasicValidator
10-
from pymatgen.io.validation.vasp_defaults import InputCategory
10+
from pymatgen.io.validation.vasp_defaults import InputCategory, VaspParam
1111

1212
from typing import TYPE_CHECKING
1313

@@ -116,7 +116,7 @@ def check(self) -> None:
116116
simple_validator.check_parameter(
117117
reasons=self.reasons,
118118
warnings=self.warnings,
119-
input_tag=working_params.defaults[key].get("alias") or key,
119+
input_tag=working_params.defaults[key]["alias"],
120120
current_values=working_params.parameters[key],
121121
reference_values=working_params.valid_values[key],
122122
operations=working_params.defaults[key]["operation"],
@@ -125,9 +125,6 @@ def check(self) -> None:
125125
severity=working_params.defaults[key]["severity"],
126126
)
127127

128-
if key == "LCHIMAG":
129-
print(self.reasons)
130-
131128
class UpdateParameterValues:
132129
"""
133130
Update a set of parameters according to supplied rules and defaults.
@@ -148,24 +145,8 @@ class UpdateParameterValues:
148145
to `GetParams` called `update_{tag}_params`. For example, the "dft plus u"
149146
tag has an update function called `update_dft_plus_u_params`. If no such update method
150147
exists, that tag is skipped.
151-
152-
Attrs
153-
---------
154-
_default_schema : dict[str,Any]
155-
The schema of an entry in the dict of default values (`self.defaults`).
156-
This pads any missing entries in the set of parameters defaults with
157-
sensible default values.
158148
"""
159149

160-
_default_schema: dict[str, Any] = {
161-
"value": None,
162-
"tag": None,
163-
"operation": None,
164-
"comment": None,
165-
"tolerance": 1.0e-4,
166-
"severity": "reason",
167-
}
168-
169150
def __init__(
170151
self,
171152
parameters: dict[str, Any],
@@ -203,7 +184,7 @@ def __init__(
203184
"""
204185

205186
self.parameters = copy.deepcopy(parameters)
206-
self.defaults = {k: v.__dict__() for k, v in defaults.items()}
187+
self.defaults = copy.deepcopy(defaults)
207188
self.input_set = input_set
208189
self.vasp_version = vasp_version
209190
self.structure = structure
@@ -238,9 +219,9 @@ def update_parameters_and_defaults(self) -> None:
238219
# add defaults to parameters from the defaults as needed
239220
self.add_defaults_to_parameters()
240221

241-
for key in self.defaults:
242-
for attr in self._default_schema:
243-
self.defaults[key][attr] = self.defaults[key].get(attr, self._default_schema[attr])
222+
for key, v in self.defaults.items():
223+
if isinstance(v,dict):
224+
self.defaults[key] = VaspParam(**{"name":key,**v})
244225

245226
def add_defaults_to_parameters(self, valid_values_source: dict | None = None) -> None:
246227
"""
@@ -351,11 +332,12 @@ def update_precision_params(self) -> None:
351332
"HIGH": -4e-4,
352333
}
353334
self.parameters["ROPT"] = [abs(value) for value in self.parameters.get("ROPT", [ropt_default[cur_prec]])]
354-
self.defaults["ROPT"] = {
355-
"value": [abs(ropt_default[cur_prec]) for _ in self.parameters["ROPT"]],
356-
"tag": "startup",
357-
"operation": ["<=" for _ in self.parameters["ROPT"]],
358-
}
335+
self.defaults["ROPT"] = VaspParam(
336+
name = "ROPT",
337+
value = [abs(ropt_default[cur_prec]) for _ in self.parameters["ROPT"]],
338+
tag = "startup",
339+
operation = ["<=" for _ in self.parameters["ROPT"]],
340+
)
359341

360342
def update_misc_special_params(self) -> None:
361343
"""Update miscellaneous parameters that do not fall into another category."""
@@ -383,11 +365,7 @@ def update_misc_special_params(self) -> None:
383365

384366
# LCORR.
385367
if self.parameters["IALGO"] != 58:
386-
self.defaults["LCORR"].update(
387-
{
388-
"operation": "==",
389-
}
390-
)
368+
self.defaults["LCORR"]["operation"] = "=="
391369

392370
if (
393371
self.parameters["ISPIN"] == 2
@@ -470,15 +448,16 @@ def update_fft_params(self) -> None:
470448
for key in grid_keys:
471449
self.valid_values[key] = int(self.valid_values[key] * self._fft_grid_tolerance)
472450

473-
self.defaults[key] = {
474-
"value": self.valid_values[key],
475-
"tag": "fft",
476-
"operation": ">=",
477-
"comment": (
451+
self.defaults[key] = VaspParam(
452+
name = key,
453+
value = self.valid_values[key],
454+
tag = "fft",
455+
operation = ">=",
456+
comment=(
478457
"This likely means the number FFT grid points was modified by the user. "
479458
"If not, please create a GitHub issue."
480459
),
481-
}
460+
)
482461

483462
def update_density_mixing_params(self) -> None:
484463
"""
@@ -586,16 +565,17 @@ def update_smearing_params(self, bandgap_tol=1.0e-4) -> None:
586565
self.parameters["ELECTRONIC ENTROPY"] = round(self.parameters["ELECTRONIC ENTROPY"] * convert_eV_to_meV, 3)
587566
self.valid_values["ELECTRONIC ENTROPY"] = 0.001 * convert_eV_to_meV
588567

589-
self.defaults["ELECTRONIC ENTROPY"] = {
590-
"value": 0.0,
591-
"tag": "smearing",
592-
"comment": (
568+
self.defaults["ELECTRONIC ENTROPY"] = VaspParam(
569+
name = "ELECTRONIC ENTROPY",
570+
value = 0.0,
571+
tag = "smearing",
572+
comment=(
593573
"The entropy term (T*S) in the energy is suggested to be less than "
594574
f"{round(self.valid_values['ELECTRONIC ENTROPY'], 1)} meV/atom "
595575
f"in the VASP wiki. Thus, SIGMA should be decreased."
596576
),
597-
"operation": "<=",
598-
}
577+
operation = "<=",
578+
)
599579

600580
def _get_default_nbands(self):
601581
"""
@@ -673,16 +653,17 @@ def update_electronic_params(self):
673653

674654
# NBANDS.
675655
min_nbands = int(np.ceil(self._NELECT / 2) + 1)
676-
self.defaults["NBANDS"] = {
677-
"value": self._get_default_nbands(),
678-
"operation": [">=", "<="],
679-
"tag": "electronic",
680-
"comment": (
656+
self.defaults["NBANDS"] = VaspParam(
657+
name = "NBANDS",
658+
value = self._get_default_nbands(),
659+
tag = "electronic",
660+
operation = [">=", "<="],
661+
comment = (
681662
"Too many or too few bands can lead to unphysical electronic structure "
682663
"(see https://github.com/materialsproject/custodian/issues/224 "
683664
"for more context.)"
684-
),
685-
}
665+
)
666+
)
686667
self.valid_values["NBANDS"] = [min_nbands, 4 * self.defaults["NBANDS"]["value"]]
687668
self.parameters["NBANDS"] = [self.parameters["NBANDS"] for _ in range(2)]
688669

@@ -734,10 +715,7 @@ def update_ionic_params(self):
734715
# every MP-compliant input set, but often have comparable or even better results) will also be accepted
735716
# I am **NOT** confident that this should be the final check. Perhaps I need convincing (or perhaps it does indeed need to be changed...)
736717
# TODO: -somehow identify if a material is a vdW structure, in which case force-convergence should maybe be more strict?
737-
self.defaults["EDIFFG"] = {
738-
"value": 10 * self.valid_values["EDIFF"],
739-
"category": "ionic",
740-
}
718+
self.defaults["EDIFFG"] = VaspParam(name = "EDIFFG", value = 10 * self.valid_values["EDIFF"], tag = "ionic")
741719

742720
self.valid_values["EDIFFG"] = self.input_set.incar.get("EDIFFG", self.defaults["EDIFFG"]["value"])
743721
self.defaults["EDIFFG"][

pymatgen/io/validation/common.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Common class constructor for validation checks."""
2+
23
from __future__ import annotations
34
from dataclasses import dataclass
45
from math import isclose
56
from typing import TYPE_CHECKING, Literal
7+
68
if TYPE_CHECKING:
79
from typing import Any
810

@@ -18,12 +20,15 @@
1820
None,
1921
}
2022

23+
2124
class InvalidOperation(Exception):
2225
"""Define custom exception when checking valid operations."""
23-
def __init__(self, operation : str ) -> None:
26+
27+
def __init__(self, operation: str) -> None:
2428
msg = f"Unknown operation type {operation}; valid values are: {VALID_OPERATIONS}"
2529
super().__init__(msg)
2630

31+
2732
class BasicValidator:
2833
"""
2934
Compare test and reference values according to one or more operations.
@@ -36,7 +41,7 @@ class BasicValidator:
3641

3742
# avoiding dunder methods because these raise too many NotImplemented's
3843

39-
def __init__(self, global_tolerance : float =1.0e-4) -> None:
44+
def __init__(self, global_tolerance: float = 1.0e-4) -> None:
4045
"""Set math.isclose tolerance"""
4146
self.tolerance = global_tolerance
4247

@@ -79,7 +84,6 @@ def _comparator(lhs: Any, operation: str, rhs: Any, **kwargs) -> bool:
7984
raise InvalidOperation(operation)
8085
return c
8186

82-
8387
def _check_parameter(
8488
self,
8589
error_list: list[str],
@@ -164,10 +168,10 @@ def check_parameter(
164168
specified, must be a Sequence of reference values.
165169
operations : str
166170
One or more valid operations in VALID_OPERATIONS.
167-
For example, if operations = "<=", this checks
171+
For example, if operations = "<=", this checks
168172
`current_values <= reference_values`
169173
(note the order of values).
170-
174+
171175
Or, if operations == ["<=", ">"], this checks
172176
```
173177
(
@@ -204,6 +208,7 @@ def check_parameter(
204208
append_comments=append_comments,
205209
)
206210

211+
207212
@dataclass
208213
class BaseValidator:
209214
"""

pymatgen/io/validation/settings.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
Settings for pymatgen-io-validation. Used to be part of EmmetSettings.
55
"""
66

7-
from importlib.resources import files as import_resource_files
87
import json
98
from pathlib import Path
109
from typing import Dict, Type, TypeVar, Union

pymatgen/io/validation/validation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from pydantic import Field
77
from pydantic.types import ImportString # replacement for PyObject
88
from pathlib import Path
9-
from monty.serialization import loadfn
109

1110
from pymatgen.io.vasp.sets import VaspInputSet
1211

@@ -107,7 +106,7 @@ def from_task_doc(cls, task_doc: TaskDoc | TaskDocument, **kwargs) -> Validation
107106
"""
108107

109108
if isinstance(task_doc, TaskDocument):
110-
task_doc = TaskDoc(**{k : v for k, v in task_doc.model_dump().items() if k != "run_stats"})
109+
task_doc = TaskDoc(**{k: v for k, v in task_doc.model_dump().items() if k != "run_stats"})
111110

112111
return cls.from_dict(jsanitize(task_doc), **kwargs)
113112

0 commit comments

Comments
 (0)