Skip to content

Commit aab86a9

Browse files
committed
Make sure serialization is backwards compatible
1 parent a4d8db5 commit aab86a9

File tree

1 file changed

+35
-45
lines changed

1 file changed

+35
-45
lines changed

pysages/serialization.py

Lines changed: 35 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,20 @@
1919

2020
import dill as pickle
2121

22+
from pysages.backends.snapshot import Box, Snapshot
2223
from pysages.methods import Metadynamics
2324
from pysages.methods.core import GriddedSamplingMethod, Result
2425
from pysages.typing import Callable
2526
from pysages.utils import dispatch, identity
2627

2728

29+
class CompatUnpickler(pickle.Unpickler):
30+
def find_class(self, module, name):
31+
if module.endswith("snapshot") and name == "Snapshot":
32+
return _recreate_snapshot
33+
return super().find_class(module, name)
34+
35+
2836
def load(filename) -> Result:
2937
"""
3038
Loads the state of an previously run `pysages` simulation from a file.
@@ -40,52 +48,21 @@ def load(filename) -> Result:
4048
4149
**Notes:**
4250
43-
This function attempts to recover from deserialization errors related to missing
44-
`ncalls` attributes.
51+
This function attempts to maintain backwards compatibility with serialized data
52+
structures that have changed in different `pysages` versions.
4553
"""
46-
with open(filename, "rb") as io:
47-
bytestring = io.read()
48-
49-
try:
50-
return pickle.loads(bytestring)
51-
52-
except TypeError as e: # pylint: disable=W0718
53-
if "ncalls" not in getattr(e, "message", repr(e)):
54-
raise e
55-
56-
# We know that states preceed callbacks so we try to find all tuples of values
57-
# corresponding to each state.
58-
j = bytestring.find(b"\x8c\x06states\x94")
59-
k = bytestring.find(b"\x8c\tcallbacks\x94")
60-
boundary = b"t\x94\x81\x94"
61-
62-
marks = []
63-
while True:
64-
i = j
65-
j = bytestring.find(boundary, i + 1, k)
66-
if j == -1:
67-
marks.append((i, len(bytestring)))
68-
break
69-
marks.append((i, j))
70-
71-
# We set `ncalls` as zero and adjust it later
72-
first = marks[0]
73-
last = marks.pop()
74-
slices = [
75-
bytestring[: first[0]],
76-
*(bytestring[i:j] + b"K\x00" for (i, j) in marks),
77-
bytestring[last[0] :], # noqa: E203
78-
]
79-
bytestring = b"".join(slices)
80-
81-
# Try to deserialize again
82-
result = pickle.loads(bytestring)
83-
84-
# Update results with `ncalls` estimates for each state
85-
update = _ncalls_estimator(result.method)
86-
result.states = [update(state) for state in result.states]
87-
88-
return result
54+
with open(filename, "rb") as f:
55+
unpickler = CompatUnpickler(f)
56+
result = unpickler.load()
57+
58+
if not isinstance(result, Result):
59+
raise TypeError("Only loading of `Result` objects is supported.")
60+
61+
# Update results with `ncalls` estimates for each state
62+
update_ncalls = _ncalls_estimator(result.method)
63+
result.states = [update_ncalls(state) for state in result.states]
64+
65+
return result
8966

9067

9168
def save(result: Result, filename) -> None:
@@ -130,3 +107,16 @@ def update(state):
130107
return state._replace(ncalls=ncalls)
131108

132109
return update
110+
111+
112+
@dispatch
113+
def _recreate_snapshot(*args, **kwargs):
114+
# Fallback case: just pass the arguments to the constructor.
115+
return Snapshot(*args, **kwargs)
116+
117+
118+
@dispatch
119+
def _recreate_snapshot(positions, vel_mass, forces, ids, images, box: Box, dt):
120+
# Older form: `images` argument was required and preceded `box`.
121+
_extras = () if images is None else (dict(images=images),)
122+
return Snapshot(positions, vel_mass, forces, ids, box, dt, *_extras)

0 commit comments

Comments
 (0)