|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 | from pathlib import Path |
5 | | -from pydantic import BaseModel, Field |
| 5 | +from pydantic import BaseModel, Field, PrivateAttr |
6 | 6 | from typing import TYPE_CHECKING |
7 | 7 |
|
8 | 8 | from monty.os.path import zpath |
|
18 | 18 | import os |
19 | 19 | from typing_extensions import Self |
20 | 20 |
|
| 21 | + |
21 | 22 | DEFAULT_CHECKS = [CheckStructureProperties, CheckPotcar, CheckCommonErrors, CheckKpointsKspacing, CheckIncar] |
22 | 23 |
|
23 | 24 | # TODO: check for surface/slab calculations. Especially necessary for external calcs. |
|
27 | 28 | class VaspValidator(BaseModel): |
28 | 29 | """Validate a VASP calculation.""" |
29 | 30 |
|
| 31 | + vasp_files: VaspFiles = Field(description="The VASP I/O.") |
30 | 32 | reasons: list[str] = Field([], description="List of deprecation tags detailing why this task isn't valid") |
31 | 33 | warnings: list[str] = Field([], description="List of warnings about this calculation") |
32 | | - vasp_files: VaspFiles = Field(description="The VASP I/O.") |
| 34 | + |
| 35 | + _validated_md5: str | None = PrivateAttr(None) |
33 | 36 |
|
34 | 37 | @property |
35 | | - def is_valid(self) -> bool: |
36 | | - """Determine if the calculation is valid.""" |
| 38 | + def valid(self) -> bool: |
| 39 | + """Determine if the calculation is valid after ensuring inputs have not changed.""" |
| 40 | + self.recheck() |
37 | 41 | return len(self.reasons) == 0 |
38 | 42 |
|
39 | 43 | @property |
40 | 44 | def has_warnings(self) -> bool: |
41 | 45 | """Determine if any warnings were incurred.""" |
42 | 46 | return len(self.warnings) > 0 |
43 | 47 |
|
| 48 | + def recheck(self) -> None: |
| 49 | + """Rerun validation, prioritizing speed.""" |
| 50 | + new_md5 = None |
| 51 | + if self._validated_md5 is None or (new_md5 := self.vasp_files.md5) != self._validated_md5: |
| 52 | + |
| 53 | + if self.vasp_files.user_input.potcar: |
| 54 | + check_list = DEFAULT_CHECKS |
| 55 | + else: |
| 56 | + check_list = [c for c in DEFAULT_CHECKS if c.__name__ != "CheckPotcar"] |
| 57 | + self.reasons, self.warnings = self.run_checks(self.vasp_files, check_list=check_list, fast=True) |
| 58 | + self._validated_md5 = new_md5 or self.vasp_files.md5 |
| 59 | + |
| 60 | + @staticmethod |
| 61 | + def run_checks( |
| 62 | + vasp_files: VaspFiles, |
| 63 | + check_list: list | tuple = DEFAULT_CHECKS, |
| 64 | + fast: bool = False, |
| 65 | + ) -> tuple[list[str], list[str]]: |
| 66 | + """Perform validation. |
| 67 | +
|
| 68 | + Parameters |
| 69 | + ----------- |
| 70 | + vasp_files : VaspFiles |
| 71 | + The VASP I/O to validate. |
| 72 | + check_list : list or tuple of BaseValidator. |
| 73 | + The list of checks to perform. Defaults to `DEFAULT_CHECKS`. |
| 74 | + fast : bool (default = False) |
| 75 | + Whether to stop validation at the first validation failure (True) |
| 76 | + or compile a list of all failure reasons. |
| 77 | +
|
| 78 | + Returns |
| 79 | + ----------- |
| 80 | + tuple of list of str |
| 81 | + The first list are all reasons for validation failure, |
| 82 | + the second list contains all warnings. |
| 83 | + """ |
| 84 | + reasons: list[str] = [] |
| 85 | + warnings: list[str] = [] |
| 86 | + for check in check_list: |
| 87 | + check(fast=fast).check(vasp_files, reasons, warnings) # type: ignore[arg-type] |
| 88 | + if fast and len(reasons) > 0: |
| 89 | + break |
| 90 | + return reasons, warnings |
| 91 | + |
44 | 92 | @classmethod |
45 | 93 | def from_vasp_input( |
46 | 94 | cls, |
@@ -87,15 +135,14 @@ def from_vasp_input( |
87 | 135 | } |
88 | 136 |
|
89 | 137 | if check_potcar: |
90 | | - checkers = DEFAULT_CHECKS |
| 138 | + check_list = DEFAULT_CHECKS |
91 | 139 | else: |
92 | | - checkers = [c for c in DEFAULT_CHECKS if c.__name__ != "CheckPotcar"] |
| 140 | + check_list = [c for c in DEFAULT_CHECKS if c.__name__ != "CheckPotcar"] |
93 | 141 |
|
94 | | - for check in checkers: |
95 | | - check(fast=fast).check(vf, config["reasons"], config["warnings"]) # type: ignore[arg-type] |
96 | | - if fast and len(config["reasons"]) > 0: |
97 | | - break |
98 | | - return cls(**config, vasp_files=vf, **kwargs) |
| 142 | + config["reasons"], config["warnings"] = cls.run_checks(vf, check_list=check_list, fast=fast) |
| 143 | + validated = cls(**config, vasp_files=vf, **kwargs) |
| 144 | + validated._validated_md5 = vf.md5 |
| 145 | + return validated |
99 | 146 |
|
100 | 147 | @classmethod |
101 | 148 | def from_directory(cls, dir_name: str | Path, **kwargs) -> Self: |
|
0 commit comments