Skip to content

Commit 1cfc267

Browse files
committed
improve PymatgenTest.serialize_with_pickle: use self.tmp_path, add type hints and better err msg if test_eq fails
clean up tests using serialize_with_pickle()
1 parent 6600998 commit 1cfc267

File tree

8 files changed

+51
-54
lines changed

8 files changed

+51
-54
lines changed

pymatgen/util/testing.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,19 @@
99

1010
import json
1111
import string
12-
import tempfile
1312
import unittest
1413
from pathlib import Path
15-
from typing import ClassVar
14+
from typing import TYPE_CHECKING, Any, ClassVar
1615

1716
import pytest
1817
from monty.json import MontyDecoder, MSONable
1918
from monty.serialization import loadfn
2019

2120
from pymatgen.core import SETTINGS, Structure
2221

22+
if TYPE_CHECKING:
23+
from git import Sequence
24+
2325
MODULE_DIR = Path(__file__).absolute().parent
2426

2527
TEST_FILES_DIR = Path(SETTINGS.get("PMG_TEST_FILES_DIR", MODULE_DIR / ".." / ".." / "tests" / "files"))
@@ -51,6 +53,7 @@ def get_structure(cls, name: str) -> Structure:
5153
Structure
5254
"""
5355
struct = cls.TEST_STRUCTURES.get(name) or loadfn(f"{cls.STRUCTURES_DIR}/{name}.json")
56+
cls.TEST_STRUCTURES[name] = struct
5457
return struct.copy()
5558

5659
@staticmethod
@@ -59,7 +62,7 @@ def assert_str_content_equal(actual, expected):
5962
strip_whitespace = {ord(c): None for c in string.whitespace}
6063
return actual.translate(strip_whitespace) == expected.translate(strip_whitespace)
6164

62-
def serialize_with_pickle(self, objects, protocols=None, test_eq=True):
65+
def serialize_with_pickle(self, objects: Any, protocols: Sequence[int] = None, test_eq: bool = True):
6366
"""Test whether the object(s) can be serialized and deserialized with
6467
pickle. This method tries to serialize the objects with pickle and the
6568
protocols specified in input. Then it deserializes the pickle format
@@ -77,7 +80,7 @@ def serialize_with_pickle(self, objects, protocols=None, test_eq=True):
7780
Nested list with the objects deserialized with the specified
7881
protocols.
7982
"""
80-
# Use the python version so that we get the traceback in case of errors
83+
# use pickle, not cPickle so that we get the traceback in case of errors
8184
import pickle
8285

8386
# Build a list even when we receive a single object.
@@ -86,39 +89,38 @@ def serialize_with_pickle(self, objects, protocols=None, test_eq=True):
8689
got_single_object = True
8790
objects = [objects]
8891

89-
if protocols is None:
90-
protocols = [pickle.HIGHEST_PROTOCOL]
92+
protocols = protocols or [pickle.HIGHEST_PROTOCOL]
9193

92-
# This list will contains the object deserialized with the different
93-
# protocols.
94+
# This list will contain the objects deserialized with the different protocols.
9495
objects_by_protocol, errors = [], []
9596

9697
for protocol in protocols:
9798
# Serialize and deserialize the object.
98-
mode = "wb"
99-
fd, tmpfile = tempfile.mkstemp(text="b" not in mode)
99+
tmpfile = self.tmp_path / f"tempfile_{protocol}.pkl"
100100

101101
try:
102-
with open(tmpfile, mode) as fh:
102+
with open(tmpfile, "wb") as fh:
103103
pickle.dump(objects, fh, protocol=protocol)
104104
except Exception as exc:
105105
errors.append(f"pickle.dump with {protocol=} raised:\n{exc}")
106106
continue
107107

108108
try:
109109
with open(tmpfile, "rb") as fh:
110-
new_objects = pickle.load(fh)
110+
unpickled_objs = pickle.load(fh)
111111
except Exception as exc:
112112
errors.append(f"pickle.load with {protocol=} raised:\n{exc}")
113113
continue
114114

115115
# Test for equality
116116
if test_eq:
117-
for old_obj, new_obj in zip(objects, new_objects):
118-
assert old_obj == new_obj
117+
for orig, unpickled in zip(objects, unpickled_objs):
118+
assert (
119+
orig == unpickled
120+
), f"Unpickled and original objects are unequal for {protocol=}\n{orig=}\n{unpickled=}"
119121

120122
# Save the deserialized objects and test for equality.
121-
objects_by_protocol.append(new_objects)
123+
objects_by_protocol.append(unpickled_objs)
122124

123125
if errors:
124126
raise ValueError("\n".join(errors))

tests/core/test_composition.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,9 +378,9 @@ def test_as_dict(self):
378378
assert d["O"] == correct_dict["O"]
379379

380380
def test_pickle(self):
381-
for c in self.comps:
382-
self.serialize_with_pickle(c, test_eq=True)
383-
self.serialize_with_pickle(c.to_data_dict, test_eq=True)
381+
for comp in self.comps:
382+
self.serialize_with_pickle(comp)
383+
self.serialize_with_pickle(comp.to_data_dict)
384384

385385
def test_to_data_dict(self):
386386
comp = Composition("Fe0.00009Ni0.99991")

tests/core/test_libxcfunc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_libxcfunc_api(self):
1919
assert xc.kind in LibxcFunc.all_kinds()
2020

2121
# Test if object can be serialized with Pickle.
22-
self.serialize_with_pickle(xc, test_eq=True)
22+
self.serialize_with_pickle(xc)
2323

2424
# Test if object supports MSONable
2525
self.assert_msonable(xc, test_is_subclass=False)

tests/core/test_periodic_table.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -336,13 +336,12 @@ def test_sort(self):
336336
assert sorted(els) == [Element.C, Element.Se]
337337

338338
def test_pickle(self):
339-
el1 = Element.Fe
340-
o = pickle.dumps(el1)
341-
assert el1 == pickle.loads(o)
339+
pickled = pickle.dumps(Element.Fe)
340+
assert Element.Fe == pickle.loads(pickled)
342341

343-
# Test all elements up to Uranium
344-
for i in range(1, 93):
345-
self.serialize_with_pickle(Element.from_Z(i), test_eq=True)
342+
# Test 5 random elements
343+
for idx in np.random.randint(1, 104, size=5):
344+
self.serialize_with_pickle(Element.from_Z(idx))
346345

347346
def test_print_periodic_table(self):
348347
Element.print_periodic_table()
@@ -385,24 +384,24 @@ def test_attr(self):
385384
assert self.specie4.spin == 5
386385

387386
def test_deepcopy(self):
388-
el1 = Species("Fe", 4)
389-
el2 = Species("Na", 1)
390-
ellist = [el1, el2]
391-
assert ellist == deepcopy(ellist), "Deepcopy operation doesn't produce exact copy."
387+
el1 = Species("Fe4+")
388+
el2 = Species("Na1+")
389+
elem_list = [el1, el2]
390+
assert elem_list == deepcopy(elem_list), "Deepcopy operation doesn't produce exact copy."
392391

393392
def test_pickle(self):
394393
assert self.specie1 == pickle.loads(pickle.dumps(self.specie1))
395-
for i in range(1, 5):
396-
self.serialize_with_pickle(getattr(self, f"specie{i}"), test_eq=True)
397-
cs = Species("Cs", 1)
398-
cl = Species("Cl", 1)
394+
for idx in range(1, 5):
395+
self.serialize_with_pickle(getattr(self, f"specie{idx}"))
396+
cs = Species("Cs1+")
397+
cl = Species("Cl1+")
399398

400-
with open("cscl.pickle", "wb") as f:
401-
pickle.dump((cs, cl), f)
399+
with open("cscl.pickle", "wb") as file:
400+
pickle.dump((cs, cl), file)
402401

403-
with open("cscl.pickle", "rb") as f:
404-
d = pickle.load(f)
405-
assert d == (cs, cl)
402+
with open("cscl.pickle", "rb") as file:
403+
tup = pickle.load(file)
404+
assert tup == (cs, cl)
406405
os.remove("cscl.pickle")
407406

408407
def test_get_crystal_field_spin(self):

tests/core/test_units.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,4 @@ def test_pickle(self):
268268
b = cls(10, "N bohr")
269269
objects = [a, b]
270270

271-
new_objects_from_protocol = self.serialize_with_pickle(objects)
272-
273-
for new_objects in new_objects_from_protocol:
274-
for old_item, new_item in zip(objects, new_objects):
275-
assert str(old_item) == str(new_item)
271+
self.serialize_with_pickle(objects)

tests/core/test_xcfunc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_xcfunc_api(self):
4545
assert ixc_11 in d
4646

4747
# Test if object can be serialized with Pickle.
48-
self.serialize_with_pickle(ixc_11, test_eq=True)
48+
self.serialize_with_pickle(ixc_11)
4949

5050
# Test if object supports MSONable
5151
# TODO

tests/io/abinit/test_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def test_api(self):
122122
assert inp.structure == new_structure
123123

124124
# Compatible with Pickle and MSONable?
125-
self.serialize_with_pickle(inp, test_eq=False)
125+
self.serialize_with_pickle(inp)
126126

127127
def test_input_errors(self):
128128
"""Testing typical BasicAbinitInput Error."""

tests/io/abinit/test_pseudos.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_nc_pseudos(self):
5555
assert pseudo.nlcc_radius >= 0.0
5656

5757
# Test pickle
58-
self.serialize_with_pickle(pseudo, test_eq=False)
58+
self.serialize_with_pickle(pseudo)
5959

6060
# Test MSONable
6161
self.assert_msonable(pseudo)
@@ -117,16 +117,16 @@ def test_pawxml_pseudos(self):
117117
assert oxygen.paw_radius == approx(1.4146523028)
118118

119119
# Test pickle
120-
new_objs = self.serialize_with_pickle(oxygen, test_eq=False)
120+
new_objs = self.serialize_with_pickle(oxygen)
121121
# Test MSONable
122122
self.assert_msonable(oxygen)
123123

124-
for o in new_objs:
125-
assert o.ispaw
126-
assert o.symbol == "O"
127-
assert (o.Z, o.core, o.valence) == (8, 2, 6), o.Z_val == 6
124+
for obj in new_objs:
125+
assert obj.ispaw
126+
assert obj.symbol == "O"
127+
assert (obj.Z, obj.core, obj.valence) == (8, 2, 6), obj.Z_val == 6
128128

129-
assert o.paw_radius == approx(1.4146523028)
129+
assert obj.paw_radius == approx(1.4146523028)
130130

131131
def test_oncvpsp_pseudo_sr(self):
132132
"""Test the ONCVPSP Ge pseudo (scalar relativistic version)."""
@@ -147,7 +147,7 @@ def test_oncvpsp_pseudo_sr(self):
147147
assert not ger.supports_soc
148148

149149
# Data persistence
150-
self.serialize_with_pickle(ger, test_eq=False)
150+
self.serialize_with_pickle(ger)
151151
self.assert_msonable(ger)
152152

153153
def test_oncvpsp_pseudo_fr(self):
@@ -157,7 +157,7 @@ def test_oncvpsp_pseudo_fr(self):
157157
str(pb)
158158

159159
# Data persistence
160-
self.serialize_with_pickle(pb, test_eq=False)
160+
self.serialize_with_pickle(pb)
161161
self.assert_msonable(pb)
162162

163163
assert pb.symbol == "Pb"

0 commit comments

Comments
 (0)