1919
2020import dill as pickle
2121
22+ from pysages .backends .snapshot import Box , Snapshot
2223from pysages .methods import Metadynamics
2324from pysages .methods .core import GriddedSamplingMethod , Result
2425from pysages .typing import Callable
2526from pysages .utils import dispatch , identity
2627
2728
29+ class CompatUnpickler (pickle .Unpickler ):
30+ def find_class (self , module , name ):
31+ if module .endswith ("snapshot" ) and name == "Snapshot" :
32+ return _recreate_snapshot
33+ return super ().find_class (module , name )
34+
35+
2836def load (filename ) -> Result :
2937 """
3038 Loads the state of an previously run `pysages` simulation from a file.
@@ -40,52 +48,21 @@ def load(filename) -> Result:
4048
4149 **Notes:**
4250
43- This function attempts to recover from deserialization errors related to missing
44- `ncalls` attributes .
51+ This function attempts to maintain backwards compatibility with serialized data
52+ structures that have changed in different `pysages` versions .
4553 """
46- with open (filename , "rb" ) as io :
47- bytestring = io .read ()
48-
49- try :
50- return pickle .loads (bytestring )
51-
52- except TypeError as e : # pylint: disable=W0718
53- if "ncalls" not in getattr (e , "message" , repr (e )):
54- raise e
55-
56- # We know that states preceed callbacks so we try to find all tuples of values
57- # corresponding to each state.
58- j = bytestring .find (b"\x8c \x06 states\x94 " )
59- k = bytestring .find (b"\x8c \t callbacks\x94 " )
60- boundary = b"t\x94 \x81 \x94 "
61-
62- marks = []
63- while True :
64- i = j
65- j = bytestring .find (boundary , i + 1 , k )
66- if j == - 1 :
67- marks .append ((i , len (bytestring )))
68- break
69- marks .append ((i , j ))
70-
71- # We set `ncalls` as zero and adjust it later
72- first = marks [0 ]
73- last = marks .pop ()
74- slices = [
75- bytestring [: first [0 ]],
76- * (bytestring [i :j ] + b"K\x00 " for (i , j ) in marks ),
77- bytestring [last [0 ] :], # noqa: E203
78- ]
79- bytestring = b"" .join (slices )
80-
81- # Try to deserialize again
82- result = pickle .loads (bytestring )
83-
84- # Update results with `ncalls` estimates for each state
85- update = _ncalls_estimator (result .method )
86- result .states = [update (state ) for state in result .states ]
87-
88- return result
54+ with open (filename , "rb" ) as f :
55+ unpickler = CompatUnpickler (f )
56+ result = unpickler .load ()
57+
58+ if not isinstance (result , Result ):
59+ raise TypeError ("Only loading of `Result` objects is supported." )
60+
61+ # Update results with `ncalls` estimates for each state
62+ update_ncalls = _ncalls_estimator (result .method )
63+ result .states = [update_ncalls (state ) for state in result .states ]
64+
65+ return result
8966
9067
9168def save (result : Result , filename ) -> None :
@@ -130,3 +107,16 @@ def update(state):
130107 return state ._replace (ncalls = ncalls )
131108
132109 return update
110+
111+
112+ @dispatch
113+ def _recreate_snapshot (* args , ** kwargs ):
114+ # Fallback case: just pass the arguments to the constructor.
115+ return Snapshot (* args , ** kwargs )
116+
117+
118+ @dispatch
119+ def _recreate_snapshot (positions , vel_mass , forces , ids , images , box : Box , dt ):
120+ # Older form: `images` argument was required and preceded `box`.
121+ _extras = () if images is None else (dict (images = images ),)
122+ return Snapshot (positions , vel_mass , forces , ids , box , dt , * _extras )
0 commit comments