9
9
10
10
import json
11
11
import string
12
- import tempfile
13
12
import unittest
14
13
from pathlib import Path
15
- from typing import ClassVar
14
+ from typing import TYPE_CHECKING , Any , ClassVar
16
15
17
16
import pytest
18
17
from monty .json import MontyDecoder , MSONable
19
18
from monty .serialization import loadfn
20
19
21
20
from pymatgen .core import SETTINGS , Structure
22
21
22
+ if TYPE_CHECKING :
23
+ from git import Sequence
24
+
23
25
MODULE_DIR = Path (__file__ ).absolute ().parent
24
26
25
27
TEST_FILES_DIR = Path (SETTINGS .get ("PMG_TEST_FILES_DIR" , MODULE_DIR / ".." / ".." / "tests" / "files" ))
@@ -51,6 +53,7 @@ def get_structure(cls, name: str) -> Structure:
51
53
Structure
52
54
"""
53
55
struct = cls .TEST_STRUCTURES .get (name ) or loadfn (f"{ cls .STRUCTURES_DIR } /{ name } .json" )
56
+ cls .TEST_STRUCTURES [name ] = struct
54
57
return struct .copy ()
55
58
56
59
@staticmethod
@@ -59,7 +62,7 @@ def assert_str_content_equal(actual, expected):
59
62
strip_whitespace = {ord (c ): None for c in string .whitespace }
60
63
return actual .translate (strip_whitespace ) == expected .translate (strip_whitespace )
61
64
62
- def serialize_with_pickle (self , objects , protocols = None , test_eq = True ):
65
+ def serialize_with_pickle (self , objects : Any , protocols : Sequence [ int ] = None , test_eq : bool = True ):
63
66
"""Test whether the object(s) can be serialized and deserialized with
64
67
pickle. This method tries to serialize the objects with pickle and the
65
68
protocols specified in input. Then it deserializes the pickle format
@@ -77,7 +80,7 @@ def serialize_with_pickle(self, objects, protocols=None, test_eq=True):
77
80
Nested list with the objects deserialized with the specified
78
81
protocols.
79
82
"""
80
- # Use the python version so that we get the traceback in case of errors
83
+ # use pickle, not cPickle so that we get the traceback in case of errors
81
84
import pickle
82
85
83
86
# Build a list even when we receive a single object.
@@ -86,39 +89,38 @@ def serialize_with_pickle(self, objects, protocols=None, test_eq=True):
86
89
got_single_object = True
87
90
objects = [objects ]
88
91
89
- if protocols is None :
90
- protocols = [pickle .HIGHEST_PROTOCOL ]
92
+ protocols = protocols or [pickle .HIGHEST_PROTOCOL ]
91
93
92
- # This list will contains the object deserialized with the different
93
- # protocols.
94
+ # This list will contain the objects deserialized with the different protocols.
94
95
objects_by_protocol , errors = [], []
95
96
96
97
for protocol in protocols :
97
98
# Serialize and deserialize the object.
98
- mode = "wb"
99
- fd , tmpfile = tempfile .mkstemp (text = "b" not in mode )
99
+ tmpfile = self .tmp_path / f"tempfile_{ protocol } .pkl"
100
100
101
101
try :
102
- with open (tmpfile , mode ) as fh :
102
+ with open (tmpfile , "wb" ) as fh :
103
103
pickle .dump (objects , fh , protocol = protocol )
104
104
except Exception as exc :
105
105
errors .append (f"pickle.dump with { protocol = } raised:\n { exc } " )
106
106
continue
107
107
108
108
try :
109
109
with open (tmpfile , "rb" ) as fh :
110
- new_objects = pickle .load (fh )
110
+ unpickled_objs = pickle .load (fh )
111
111
except Exception as exc :
112
112
errors .append (f"pickle.load with { protocol = } raised:\n { exc } " )
113
113
continue
114
114
115
115
# Test for equality
116
116
if test_eq :
117
- for old_obj , new_obj in zip (objects , new_objects ):
118
- assert old_obj == new_obj
117
+ for orig , unpickled in zip (objects , unpickled_objs ):
118
+ assert (
119
+ orig == unpickled
120
+ ), f"Unpickled and original objects are unequal for { protocol = } \n { orig = } \n { unpickled = } "
119
121
120
122
# Save the deserialized objects and test for equality.
121
- objects_by_protocol .append (new_objects )
123
+ objects_by_protocol .append (unpickled_objs )
122
124
123
125
if errors :
124
126
raise ValueError ("\n " .join (errors ))
0 commit comments