Skip to content

Commit 78b5800

Browse files
refactor defaults; define class with base defaults
1 parent 45c622a commit 78b5800

File tree

7 files changed

+408
-546
lines changed

7 files changed

+408
-546
lines changed

pymatgen/io/validation/check_incar.py

Lines changed: 15 additions & 211 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from __future__ import annotations
44
import copy
55
from dataclasses import dataclass
6-
from math import isclose
76
import numpy as np
87
from emmet.core.vasp.calc_types.enums import TaskType
98

10-
from pymatgen.io.validation.common import BaseValidator
9+
from pymatgen.io.validation.common import BaseValidator, BasicValidator
10+
from pymatgen.io.validation.vasp_defaults import InputCategory
1111

12-
from typing import TYPE_CHECKING, Literal
12+
from typing import TYPE_CHECKING
1313

1414
if TYPE_CHECKING:
1515
from typing import Any, Sequence
@@ -77,9 +77,9 @@ def check(self) -> None:
7777
Check calculation parameters related to INCAR input tags.
7878
7979
This first updates any parameter with a specified update method.
80-
In practice, each INCAR tag in `vasp_defaults.yaml` has a "tag"
81-
attribute. If there is an update method
82-
`UpdateParameterValues.update_{tag.replace(" ","_")}_params`,
80+
In practice, each INCAR tag in `VASP` has a "tag" attribute.
81+
If there is an update method
82+
`UpdateParameterValues.update_{tag}_params`,
8383
all parameters with that tag will be updated.
8484
8585
Then after all missing values in the supplied parameters (padding
@@ -116,7 +116,7 @@ def check(self) -> None:
116116
simple_validator.check_parameter(
117117
reasons=self.reasons,
118118
warnings=self.warnings,
119-
input_tag=working_params.defaults[key].get("alias", key),
119+
input_tag=working_params.defaults[key].get("alias") or key,
120120
current_values=working_params.parameters[key],
121121
reference_values=working_params.valid_values[key],
122122
operations=working_params.defaults[key]["operation"],
@@ -125,6 +125,8 @@ def check(self) -> None:
125125
severity=working_params.defaults[key]["severity"],
126126
)
127127

128+
if key == "LCHIMAG":
129+
print(self.reasons)
128130

129131
class UpdateParameterValues:
130132
"""
@@ -141,9 +143,9 @@ class UpdateParameterValues:
141143
This class allows one to mimic the VASP NBANDS functionality for computing
142144
NBANDS dynamically, and update both the current and reference values for NBANDs.
143145
144-
To do this in a simple, automatic fashion, each parameter in `vasp_defaults.yaml` has
146+
To do this in a simple, automatic fashion, each parameter in `VASP_DEFAULTS` has
145147
a "tag" field. To update a set of parameters with a given tag, one then adds a function
146-
to `GetParams` called `update_{tag.replace(" ","_")}_params`. For example, the "dft plus u"
148+
to `GetParams` called `update_{tag}_params`. For example, the "dft plus u"
147149
tag has an update function called `update_dft_plus_u_params`. If no such update method
148150
exists, that tag is skipped.
149151
@@ -201,7 +203,7 @@ def __init__(
201203
"""
202204

203205
self.parameters = copy.deepcopy(parameters)
204-
self.defaults = copy.deepcopy(defaults)
206+
self.defaults = {k: v.__dict__() for k, v in defaults.items()}
205207
self.input_set = input_set
206208
self.vasp_version = vasp_version
207209
self.structure = structure
@@ -220,17 +222,14 @@ def __init__(
220222
def update_parameters_and_defaults(self) -> None:
221223
"""Update user parameters and defaults for tags with a specified update method."""
222224

223-
self.categories: dict[str, list[str]] = {}
225+
self.categories: dict[str, list[str]] = {tag: [] for tag in InputCategory.__members__}
224226
for key in self.defaults:
225-
if self.defaults[key]["tag"] not in self.categories:
226-
self.categories[self.defaults[key]["tag"]] = []
227227
self.categories[self.defaults[key]["tag"]].append(key)
228228

229-
tag_order = [key.replace(" ", "_") for key in self.categories if key != "post_init"] + ["post_init"]
230229
# add defaults to parameters from the incar as needed
231230
self.add_defaults_to_parameters(valid_values_source=self.input_set.incar)
232231
# collect list of tags in parameter defaults
233-
for tag in tag_order:
232+
for tag in InputCategory.__members__:
234233
# check to see if update method for that tag exists, and if so, run it
235234
update_method_str = f"update_{tag}_params"
236235
if hasattr(self, update_method_str):
@@ -264,7 +263,7 @@ def update_dft_plus_u_params(self) -> None:
264263
if not self.parameters["LDAU"]:
265264
return
266265

267-
for key in self.categories["dft plus u"]:
266+
for key in self.categories["dft_plus_u"]:
268267
valid_value = self.input_set.incar.get(key, self.defaults[key]["value"])
269268

270269
# TODO: ADK: is LDAUTYPE usually specified as a list??
@@ -785,198 +784,3 @@ def update_post_init_params(self):
785784
4.0 * self.defaults["NBANDS"]["value"]
786785
)
787786
self.defaults["EBREAK"]["operation"] = "auto fail"
788-
789-
790-
class BasicValidator:
791-
"""
792-
Compare test and reference values according to one or more operations.
793-
794-
Parameters
795-
-----------
796-
global_tolerance : float = 1.e-4
797-
Default tolerance for assessing approximate equality via math.isclose
798-
799-
Attrs
800-
-----------
801-
operations : set[str]
802-
List of acceptable operations, such as "==" for strict equality, or "in" to
803-
check if a Sequence contains an element
804-
"""
805-
806-
# avoiding dunder methods because these raise too many NotImplemented's
807-
operations: set[str | None] = {
808-
"==",
809-
">",
810-
">=",
811-
"<",
812-
"<=",
813-
"in",
814-
"approx",
815-
"auto fail",
816-
None,
817-
}
818-
819-
def __init__(self, global_tolerance=1.0e-4) -> None:
820-
"""Set math.isclose tolerance"""
821-
self.tolerance = global_tolerance
822-
823-
def _comparator(self, lhs: Any, operation: str, rhs: Any, **kwargs) -> bool:
824-
"""
825-
Compare different values using one of a set of supported operations in self.operations.
826-
827-
Parameters
828-
-----------
829-
lhs : Any
830-
Left-hand side of the operation.
831-
operation : str
832-
Operation acting on rhs from lhs. For example, if operation is ">",
833-
this returns (lhs > rhs).
834-
rhs : Any
835-
Right-hand of the operation.
836-
kwargs
837-
If needed, kwargs to pass to operation.
838-
"""
839-
if operation is None:
840-
c = True
841-
elif operation == "auto fail":
842-
c = False
843-
elif operation == "==":
844-
c = lhs == rhs
845-
elif operation == ">":
846-
c = lhs > rhs
847-
elif operation == ">=":
848-
c = lhs >= rhs
849-
elif operation == "<":
850-
c = lhs < rhs
851-
elif operation == "<=":
852-
c = lhs <= rhs
853-
elif operation == "in":
854-
c = lhs in rhs
855-
elif operation == "approx":
856-
c = isclose(lhs, rhs, **kwargs)
857-
return c
858-
859-
def _check_parameter(
860-
self,
861-
error_list: list[str],
862-
input_tag: str,
863-
current_value: Any,
864-
reference_value: Any,
865-
operation: str,
866-
tolerance: float | None = None,
867-
append_comments: str | None = None,
868-
) -> None:
869-
"""
870-
Determine validity of parameter subject to a single specified operation.
871-
872-
Parameters
873-
-----------
874-
error_list : list[str]
875-
A list of error/warning strings to update if a check fails.
876-
input_tag : str
877-
The name of the input tag which is being checked.
878-
current_value : Any
879-
The test value.
880-
reference_value : Any
881-
The value to compare the test value to.
882-
operation : str
883-
A valid operation in self.operations. For example, if operation = "<=",
884-
this checks `current_value <= reference_value` (note order of values).
885-
tolerance : float or None (default)
886-
If None and operation == "approx", default tolerance to self.tolerance.
887-
Otherwise, use the user-supplied tolerance.
888-
append_comments : str or None (default)
889-
Additional comments that may be helpful for the user to understand why
890-
a check failed.
891-
"""
892-
893-
append_comments = append_comments or ""
894-
895-
if isinstance(current_value, str):
896-
current_value = current_value.upper()
897-
898-
kwargs: dict[str, Any] = {}
899-
if operation == "approx" and isinstance(current_value, float):
900-
kwargs.update({"rel_tol": tolerance or self.tolerance, "abs_tol": 0.0})
901-
valid_value = self._comparator(current_value, operation, reference_value, **kwargs)
902-
903-
if not valid_value:
904-
error_list.append(
905-
f"INPUT SETTINGS --> {input_tag}: is {current_value}, but should be "
906-
f"{'' if operation == 'auto fail' else operation + ' '}{reference_value}."
907-
f"{' ' if len(append_comments) > 0 else ''}{append_comments}"
908-
)
909-
910-
def check_parameter(
911-
self,
912-
reasons: list[str],
913-
warnings: list[str],
914-
input_tag: str,
915-
current_values: Any,
916-
reference_values: Any,
917-
operations: str | list[str],
918-
tolerance: float = None,
919-
append_comments: str | None = None,
920-
severity: Literal["reason", "warning"] = "reason",
921-
) -> None:
922-
"""
923-
Determine validity of parameter according to one or more operations.
924-
925-
Parameters
926-
-----------
927-
reasons : list[str]
928-
A list of error strings to update if a check fails. These are higher
929-
severity and would deprecate a calculation.
930-
warnings : list[str]
931-
A list of warning strings to update if a check fails. These are lower
932-
severity and would flag a calculation for possible review.
933-
input_tag : str
934-
The name of the input tag which is being checked.
935-
current_values : Any
936-
The test value(s). If multiple operations are specified, must be a Sequence
937-
of test values.
938-
reference_values : Any
939-
The value(s) to compare the test value(s) to. If multiple operations are
940-
specified, must be a Sequence of reference values.
941-
operations : str
942-
One or more valid operations in self.operations. For example, if operations = "<=",
943-
this checks `current_values <= reference_values` (note order of values).
944-
Or, if operations == ["<=", ">"], this checks
945-
```
946-
(
947-
(current_values[0] <= reference_values[0])
948-
and (current_values[1] > reference_values[1])
949-
)
950-
```
951-
tolerance : float or None (default)
952-
Tolerance to use in math.isclose if any of operations is "approx". Defaults
953-
to self.tolerance.
954-
append_comments : str or None (default)
955-
Additional comments that may be helpful for the user to understand why
956-
a check failed.
957-
severity : Literal["reason", "warning"]
958-
If a calculation fails, the severity of failure. Directs output to
959-
either reasons or warnings.
960-
"""
961-
962-
severity_to_list = {"reason": reasons, "warning": warnings}
963-
964-
if not isinstance(operations, list):
965-
operations = [operations]
966-
current_values = [current_values]
967-
reference_values = [reference_values]
968-
969-
unknown_operations = {operation for operation in operations if operation not in self.operations}
970-
if len(unknown_operations) > 0:
971-
raise ValueError("Unknown operations:\n " + ", ".join([f"{uo}" for uo in unknown_operations]))
972-
973-
for iop in range(len(operations)):
974-
self._check_parameter(
975-
error_list=severity_to_list[severity],
976-
input_tag=input_tag,
977-
current_value=current_values[iop],
978-
reference_value=reference_values[iop],
979-
operation=operations[iop],
980-
tolerance=tolerance,
981-
append_comments=append_comments,
982-
)

0 commit comments

Comments
 (0)