Skip to content

Commit c1062e6

Browse files
add rudimentary hashing to prevent user manipulation
1 parent 190b154 commit c1062e6

File tree

5 files changed

+98
-16
lines changed

5 files changed

+98
-16
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ from pymatgen.io.validation import VaspValidator
2121
validation_doc = VaspValidator.from_directory(path_to_vasp_calculation_directory)
2222
```
2323

24-
In the above case, whether a calculation passes the validator can be accessed via `validation_doc.is_valid`. Moreover, reasons for an invalidated calculation can be accessed via `validation_doc.reasons` (this will be empty for valid calculations). Last but not least, warnings for potential issues (sometimes minor, sometimes major) can be accessed via `validation_doc.warnings`.
24+
In the above case, whether a calculation passes the validator can be accessed via `validation_doc.valid`. Moreover, reasons for an invalidated calculation can be accessed via `validation_doc.reasons` (this will be empty for valid calculations). Last but not least, warnings for potential issues (sometimes minor, sometimes major) can be accessed via `validation_doc.warnings`.
2525

2626
Contributors
2727
=====

examples/using_validation_docs.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@
105105
"outputs": [],
106106
"source": [
107107
"mp_compliant_doc = check_calc(\"MP_compliant\")\n",
108-
"print(mp_compliant_doc.is_valid)"
108+
"print(mp_compliant_doc.valid)"
109109
]
110110
},
111111
{
@@ -127,7 +127,7 @@
127127
"outputs": [],
128128
"source": [
129129
"mp_non_compliant_doc = check_calc(\"MP_non_compliant\")\n",
130-
"print(mp_non_compliant_doc.is_valid)\n",
130+
"print(mp_non_compliant_doc.valid)\n",
131131
"for reason in mp_non_compliant_doc.reasons:\n",
132132
" print(reason)"
133133
]

pymatgen/io/validation/common.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
from __future__ import annotations
44

55
from functools import cached_property
6+
import hashlib
67
from importlib import import_module
78
import os
9+
import numpy as np
810
from pathlib import Path
911
from pydantic import BaseModel, Field, model_serializer, PrivateAttr
1012
from typing import TYPE_CHECKING, Any, Optional
@@ -115,6 +117,20 @@ def from_vasprun(cls, vasprun: Vasprun) -> Self:
115117
bandgap=vasprun.get_band_structure(efermi="smart").get_band_gap()["energy"],
116118
)
117119

120+
@model_serializer
def deserialize_objects(self) -> dict[str, Any]:
    """Serialize the model to a dict of plain, JSON-friendly data.

    Despite the historical name, this is the pydantic *serializer* for the
    model: `final_structure` and `kpoints` are converted with `as_dict()`,
    any `Structure` found under an ionic step's "structure" key is converted
    the same way, and numpy arrays under "forces" / "stress" become lists.

    The ionic steps are rebuilt as new dicts rather than edited in place, so
    dumping the model never alters the objects stored on the instance.

    Returns
    -----------
    dict of str to Any
        One entry per model field, with the conversions above applied.
    """
    model_dumped = {k: getattr(self, k) for k in self.__class__.model_fields}
    for k in ("final_structure", "kpoints"):
        model_dumped[k] = model_dumped[k].as_dict()

    # `model_dumped["ionic_steps"]` is the very list held by the model, so
    # writing into its dicts would silently replace Structure / ndarray
    # values on `self`. Work on shallow copies of each step instead.
    serialized_steps = []
    for istep in model_dumped["ionic_steps"]:
        step = dict(istep)
        if (istruct := step.get("structure")) and isinstance(istruct, Structure):
            step["structure"] = istruct.as_dict()
        for k in ("forces", "stress"):
            if isinstance(val := step.get(k), np.ndarray):
                step[k] = val.tolist()
        serialized_steps.append(step)
    model_dumped["ionic_steps"] = serialized_steps
    return model_dumped
133+
118134

119135
class VaspInputSafe(BaseModel):
120136
"""Stricter VaspInputSet with no POTCAR info."""
@@ -183,6 +199,11 @@ class VaspFiles(BaseModel):
183199
outcar: Optional[LightOutcar] = None
184200
vasprun: Optional[LightVasprun] = None
185201

202+
@property
def md5(self) -> str:
    """Return the hexadecimal MD5 digest of this model's JSON dump.

    Used by validation checks to tell whether the VASP I/O has changed.
    """
    serialized = self.model_dump_json().encode()
    digest = hashlib.md5(serialized)
    return digest.hexdigest()
206+
186207
@property
187208
def actual_kpoints(self) -> Kpoints | None:
188209
"""The actual KPOINTS / IBZKPT used in the calculation, if applicable."""

pymatgen/io/validation/validation.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44
from pathlib import Path
5-
from pydantic import BaseModel, Field
5+
from pydantic import BaseModel, Field, PrivateAttr
66
from typing import TYPE_CHECKING
77

88
from monty.os.path import zpath
@@ -18,6 +18,7 @@
1818
import os
1919
from typing_extensions import Self
2020

21+
2122
DEFAULT_CHECKS = [CheckStructureProperties, CheckPotcar, CheckCommonErrors, CheckKpointsKspacing, CheckIncar]
2223

2324
# TODO: check for surface/slab calculations. Especially necessary for external calcs.
@@ -27,20 +28,67 @@
2728
class VaspValidator(BaseModel):
2829
"""Validate a VASP calculation."""
2930

31+
vasp_files: VaspFiles = Field(description="The VASP I/O.")
3032
reasons: list[str] = Field([], description="List of deprecation tags detailing why this task isn't valid")
3133
warnings: list[str] = Field([], description="List of warnings about this calculation")
32-
vasp_files: VaspFiles = Field(description="The VASP I/O.")
34+
35+
_validated_md5: str | None = PrivateAttr(None)
3336

3437
@property
35-
def is_valid(self) -> bool:
36-
"""Determine if the calculation is valid."""
38+
def valid(self) -> bool:
    """Report whether the calculation is valid.

    Calls `recheck()` first, so the answer reflects the current inputs
    rather than only the state at construction time.
    """
    self.recheck()
    # Valid means no failure reasons were recorded.
    return not self.reasons
3842

3943
@property
def has_warnings(self) -> bool:
    """Report whether at least one warning was recorded."""
    return bool(self.warnings)
4347

48+
def recheck(self) -> None:
    """Re-run validation in fast mode if the VASP I/O hash is missing or stale.

    Nothing happens when the stored hash matches the current
    `vasp_files.md5`. Otherwise `reasons` and `warnings` are replaced with
    the result of `run_checks(..., fast=True)` and the stored hash is
    refreshed. The POTCAR check is only included when
    `vasp_files.user_input.potcar` is truthy.
    """
    current_md5 = None
    if self._validated_md5 is not None:
        current_md5 = self.vasp_files.md5
        if current_md5 == self._validated_md5:
            # Inputs are unchanged since the last validation.
            return

    check_list = (
        DEFAULT_CHECKS
        if self.vasp_files.user_input.potcar
        else [c for c in DEFAULT_CHECKS if c.__name__ != "CheckPotcar"]
    )
    self.reasons, self.warnings = self.run_checks(self.vasp_files, check_list=check_list, fast=True)
    # When no hash was stored before, compute it only now, after the checks ran.
    self._validated_md5 = current_md5 or self.vasp_files.md5
59+
60+
@staticmethod
def run_checks(
    vasp_files: VaspFiles,
    check_list: list | tuple = DEFAULT_CHECKS,
    fast: bool = False,
) -> tuple[list[str], list[str]]:
    """Run each check in order and collect failure reasons and warnings.

    Parameters
    -----------
    vasp_files : VaspFiles
        The VASP I/O to validate.
    check_list : list or tuple of BaseValidator.
        The checks to run, in order. Defaults to `DEFAULT_CHECKS`.
    fast : bool (default = False)
        If True, stop after the first check that records a failure reason.
        If False, run every check and gather all failure reasons.

    Returns
    -----------
    tuple of list of str
        The first list holds every reason for validation failure,
        the second list holds every warning.
    """
    failure_reasons: list[str] = []
    collected_warnings: list[str] = []
    for checker_cls in check_list:
        checker = checker_cls(fast=fast)
        checker.check(vasp_files, failure_reasons, collected_warnings)  # type: ignore[arg-type]
        if fast and failure_reasons:
            break
    return failure_reasons, collected_warnings
91+
4492
@classmethod
4593
def from_vasp_input(
4694
cls,
@@ -87,15 +135,14 @@ def from_vasp_input(
87135
}
88136

89137
if check_potcar:
90-
checkers = DEFAULT_CHECKS
138+
check_list = DEFAULT_CHECKS
91139
else:
92-
checkers = [c for c in DEFAULT_CHECKS if c.__name__ != "CheckPotcar"]
140+
check_list = [c for c in DEFAULT_CHECKS if c.__name__ != "CheckPotcar"]
93141

94-
for check in checkers:
95-
check(fast=fast).check(vf, config["reasons"], config["warnings"]) # type: ignore[arg-type]
96-
if fast and len(config["reasons"]) > 0:
97-
break
98-
return cls(**config, vasp_files=vf, **kwargs)
142+
config["reasons"], config["warnings"] = cls.run_checks(vf, check_list=check_list, fast=fast)
143+
validated = cls(**config, vasp_files=vf, **kwargs)
144+
validated._validated_md5 = vf.md5
145+
return validated
99146

100147
@classmethod
101148
def from_directory(cls, dir_name: str | Path, **kwargs) -> Self:

tests/test_validation.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,20 @@ def test_validation_from_files(test_dir):
5454
for k in ("incar", "structure", "kpoints")
5555
)
5656

57+
# Ensure that user modification to inputs after submitting valid
58+
# input leads to subsequent validation failures.
59+
# Re-instantiate VaspValidator to ensure pointers don't get messed up
60+
validated = VaspValidator(**validator_from_paths.model_dump())
61+
og_md5 = validated.vasp_files.md5
62+
assert validated.valid
63+
assert validated._validated_md5 == og_md5
64+
65+
validated.vasp_files.user_input.incar["ENCUT"] = 1.0
66+
new_md5 = validated.vasp_files.md5
67+
assert new_md5 != og_md5
68+
assert not validated.valid
69+
assert validated._validated_md5 == new_md5
70+
5771

5872
@pytest.mark.parametrize(
5973
"object_name",
@@ -470,7 +484,7 @@ def test_fast_mode():
470484
validated = VaspValidator.from_vasp_input(vasp_files=vf, check_potcar=False)
471485

472486
# Without POTCAR check, this doc is valid
473-
assert validated.is_valid
487+
assert validated.valid
474488

475489
# Now introduce sequence of changes to test how fast validation works
476490
# Check order:
@@ -532,7 +546,7 @@ def test_site_properties(test_dir):
532546
vf = VaspFiles(**loadfn(test_dir / "vasp" / "mp-1245223_site_props_check.json.gz"))
533547
vd = VaspValidator.from_vasp_input(vasp_files=vf)
534548

535-
assert not vd.is_valid
549+
assert not vd.valid
536550
assert any("selective dynamics" in reason.lower() for reason in vd.reasons)
537551

538552
# map non-zero velocities to input structure and re-check

0 commit comments

Comments
 (0)