|
| 1 | +"""This module implements testing utilities for materials science codes. |
| 2 | +
|
| 3 | +While the primary use is within pymatgen, the functionality is meant to |
| 4 | +be useful for external materials science codes as well. For instance, obtaining |
| 5 | +example crystal structures to perform tests, specialized assert methods for |
| 6 | +materials science, etc. |
| 7 | +""" |
| 8 | + |
| 9 | +from __future__ import annotations |
| 10 | + |
| 11 | +import json |
| 12 | +import pickle # use pickle over cPickle to get traceback in case of errors |
| 13 | +import string |
| 14 | +from pathlib import Path |
| 15 | +from typing import TYPE_CHECKING |
| 16 | +from unittest import TestCase |
| 17 | + |
| 18 | +import pytest |
| 19 | +from monty.json import MontyDecoder, MontyEncoder, MSONable |
| 20 | +from monty.serialization import loadfn |
| 21 | + |
| 22 | +from pymatgen.core import ROOT, SETTINGS |
| 23 | + |
| 24 | +if TYPE_CHECKING: |
| 25 | + from collections.abc import Sequence |
| 26 | + from typing import Any, ClassVar |
| 27 | + |
| 28 | + from pymatgen.core import Structure |
| 29 | + from pymatgen.util.typing import PathLike |
| 30 | + |
# Directory that contains this module; the bundled example structures sit in
# its "structures" sub-directory.
_MODULE_DIR: Path = Path(__file__).absolute().parent
STRUCTURES_DIR: Path = _MODULE_DIR.joinpath("structures")

# Root of the test files: taken from the PMG_TEST_FILES_DIR setting when it is
# present, otherwise a path relative to the package ROOT.
TEST_FILES_DIR: Path = Path(
    SETTINGS.get("PMG_TEST_FILES_DIR", f"{ROOT}/../tests/files"),
)

# The VASP locations are plain strings (not Path objects) built from TEST_FILES_DIR.
VASP_IN_DIR: str = f"{TEST_FILES_DIR}/io/vasp/inputs"
VASP_OUT_DIR: str = f"{TEST_FILES_DIR}/io/vasp/outputs"

# Fake POTCARs keep the original header information, so properties such as the
# number of electrons, nuclear charge, core radii, etc. are unchanged (important
# for testing), while the pseudopotential values and kinetic energy corrections
# are scrambled to avoid VASP copyright infringement.
FAKE_POTCAR_DIR: str = f"{VASP_IN_DIR}/fake_potcars"
| 43 | + |
| 44 | + |
class PymatgenTest(TestCase):
    """unittest.TestCase subclass bundling helpers that are handy in tests:
    - assert_msonable: Check the MSONable round trip of an object and return its JSON string.
    - assert_str_content_equal: Compare two strings while ignoring whitespace.
    - get_structure: Load one of the bundled example structures by name.
    - serialize_with_pickle: Round-trip object(s) through `pickle`.
    """

    # Every entry found in STRUCTURES_DIR mapped to None at class creation;
    # get_structure() additionally stores each loaded structure under its name.
    TEST_STRUCTURES: ClassVar[dict[PathLike, Structure | None]] = dict.fromkeys(STRUCTURES_DIR.glob("*"))

    @pytest.fixture(autouse=True)
    def _tmp_dir(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
        """Switch the working directory to pytest's temporary directory for each
        test and expose that directory as self.tmp_path.

        References:
            https://docs.pytest.org/en/stable/how-to/tmp_path.html
        """
        monkeypatch.chdir(tmp_path)  # the working directory is now the temporary one
        self.tmp_path = tmp_path

    @staticmethod
    def assert_msonable(obj: Any, test_is_subclass: bool = True) -> str:
        """Check that an object honours the MSONable contract and return its
        JSON serialization.

        Unless `test_is_subclass` is False, the object must be an instance
        of MSONable.

        Args:
            obj (Any): The object to be checked.
            test_is_subclass (bool): Whether to require that the object is an
                instance of MSONable (or of one of its subclasses).

        Returns:
            str: Serialized object.

        Raises:
            TypeError: If the instance check is enabled and fails, or if the
                object decoded from JSON is not of (a subclass of) the
                original object's class.
            ValueError: If rebuilding the object from its dict representation
                yields a different dict representation.
        """
        obj_name = obj.__class__.__name__

        # Optional type check against MSONable
        if test_is_subclass:
            if not isinstance(obj, MSONable):
                raise TypeError(f"{obj_name} object is not MSONable")

        # dict -> object -> dict must reproduce the original dict
        original_dict = obj.as_dict()
        rebuilt_dict = type(obj).from_dict(obj.as_dict()).as_dict()
        if original_dict != rebuilt_dict:
            raise ValueError(f"{obj_name} object could not be reconstructed accurately from its dict representation.")

        # Decoding the JSON string must give back the same class or a subclass of it
        json_str = json.dumps(obj.as_dict(), cls=MontyEncoder)
        round_trip = json.loads(json_str, cls=MontyDecoder)
        if issubclass(type(round_trip), type(obj)):
            return json_str
        raise TypeError(f"The reconstructed {round_trip.__class__.__name__} object is not a subclass of {obj_name}")

    @staticmethod
    def assert_str_content_equal(actual: str, expected: str) -> None:
        """Assert that two strings match once every whitespace character is removed.

        Args:
            actual (str): The string to be checked.
            expected (str): The reference string.

        Raises:
            AssertionError: When two strings are not equal.
        """
        # Translation table that deletes all characters in string.whitespace
        drop_whitespace = str.maketrans("", "", string.whitespace)
        if actual.translate(drop_whitespace) == expected.translate(drop_whitespace):
            return

        actual_banner = " Actual ".center(50, "=")
        expected_banner = " Expected ".center(50, "=")
        raise AssertionError(
            "Strings are not equal (whitespaces ignored):\n"
            f"{actual_banner}\n"
            f"{actual}\n"
            f"{expected_banner}\n"
            f"{expected}\n"
        )

    @classmethod
    def get_structure(cls, name: str) -> Structure:
        """
        Return a copy of a structure stored in `pymatgen.util.structures`.

        The structure is read from "<name>.json" in STRUCTURES_DIR unless a
        truthy entry for `name` is already present in TEST_STRUCTURES; the
        result is stored in TEST_STRUCTURES under `name`.

        Args:
            name (str): Name of the structure file, for example "LiFePO4".

        Returns:
            Structure

        Raises:
            FileNotFoundError: If loading the JSON file fails because it does not exist.
        """
        cached = cls.TEST_STRUCTURES.get(name)
        if cached:
            struct = cached
        else:
            try:
                struct = loadfn(f"{STRUCTURES_DIR}/{name}.json")
            except FileNotFoundError as exc:
                raise FileNotFoundError(f"structure for {name} doesn't exist") from exc

        cls.TEST_STRUCTURES[name] = struct

        # Hand out a copy so that callers cannot modify the stored structure
        return struct.copy()

    def serialize_with_pickle(
        self,
        objects: Any,
        protocols: Sequence[int] | None = None,
        test_eq: bool = True,
    ) -> list:
        """Check that the object(s) survive a `pickle` round trip.

        For each requested protocol the objects are dumped to a file inside
        self.tmp_path, loaded back and, if `test_eq` is set, compared with
        the originals using `!=`.

        Args:
            objects (Any): Object or list of objects.
            protocols (Sequence[int]): List of pickle protocols to test.
                If protocols is None, HIGHEST_PROTOCOL is tested.
            test_eq (bool): If True, the deserialized object is compared
                with the original object.

        Returns:
            list[Any]: Objects deserialized with the specified protocols.

        Raises:
            ValueError: If an unpickled object differs from its original, or,
                after all protocols were tried, if any dump/load step raised.
        """
        # Work on a sequence even when a single object was passed in.
        got_single_object = not isinstance(objects, (list, tuple))
        payload = [objects] if got_single_object else objects

        chosen_protocols = protocols if protocols else [pickle.HIGHEST_PROTOCOL]

        restored_by_protocol: list = []  # one entry per successfully tested protocol
        failures: list[str] = []  # dump/load problems, reported together at the end

        for proto in chosen_protocols:
            pkl_path = self.tmp_path / f"tempfile_{proto}.pkl"

            # Serialize
            try:
                with open(pkl_path, "wb") as sink:
                    pickle.dump(payload, sink, protocol=proto)
            except Exception as exc:
                failures.append(f"pickle.dump with protocol={proto!r} raised:\n{exc}")
                continue

            # Deserialize
            try:
                with open(pkl_path, "rb") as source:
                    restored = pickle.load(source)  # noqa: S301
            except Exception as exc:
                failures.append(f"pickle.load with protocol={proto!r} raised:\n{exc}")
                continue

            # Pairwise comparison; the first mismatch aborts immediately
            if test_eq:
                for before, after in zip(payload, restored, strict=True):
                    if before != after:
                        raise ValueError(
                            "Unpickled and original objects are unequal for "
                            f"protocol={proto!r}\norig={before!r}\nunpickled={after!r}"
                        )

            restored_by_protocol.append(restored)

        if failures:
            raise ValueError("\n".join(failures))

        # Unwrap again if the caller passed a single object, so that client
        # code can run further checks on the deserialized results.
        if got_single_object:
            return [batch[0] for batch in restored_by_protocol]
        return restored_by_protocol
0 commit comments