5
5
unit tests or regression tests.
6
6
"""
7
7
8
- import os
9
8
import pickle
10
9
import sys
11
- import tempfile
12
- import zipfile
13
- from collections import Counter
14
- from contextlib import closing
15
- from io import BytesIO
16
- from pickle import HIGHEST_PROTOCOL
17
-
18
- import numpy as np
19
10
20
11
import pytensor
21
12
22
13
23
- try :
24
- from pickle import DEFAULT_PROTOCOL
25
- except ImportError :
26
- DEFAULT_PROTOCOL = HIGHEST_PROTOCOL
27
-
28
- from pytensor .compile .sharedvalue import SharedVariable
29
-
30
-
31
14
__docformat__ = "restructuredtext en"
32
15
__authors__ = "Pascal Lamblin " "PyMC Developers " "PyTensor Developers "
33
16
__copyright__ = "Copyright 2013, Universite de Montreal"
@@ -49,16 +32,18 @@ class StripPickler(Pickler):
49
32
50
33
..code-block:: python
51
34
52
- fn_args = dict(inputs=inputs,
53
- outputs=outputs,
54
- updates=updates)
55
- dest_pkl = 'my_test.pkl'
56
- with open(dest_pkl, 'wb') as f:
35
+ fn_args = {
36
+ "inputs": inputs,
37
+ "outputs": outputs,
38
+ "updates": updates,
39
+ }
40
+ dest_pkl = "my_test.pkl"
41
+ with Path(dest_pkl).open("wb") as f:
57
42
strip_pickler = StripPickler(f, protocol=-1)
58
43
strip_pickler.dump(fn_args)
59
44
"""
60
45
61
- def __init__ (self , file , protocol = 0 , extra_tag_to_remove = None ):
46
+ def __init__ (self , file , protocol : int = 0 , extra_tag_to_remove : str | None = None ):
62
47
# Can't use super as Pickler isn't a new style class
63
48
super ().__init__ (file , protocol )
64
49
self .tag_to_remove = ["trace" , "test_value" ]
@@ -77,226 +62,3 @@ def save(self, obj):
77
62
del obj .__dict__ ["__doc__" ]
78
63
79
64
return Pickler .save (self , obj )
80
-
81
-
82
- class PersistentNdarrayID :
83
- """Persist ndarrays in an object by saving them to a zip file.
84
-
85
- :param zip_file: A zip file handle that the NumPy arrays will be saved to.
86
- :type zip_file: :class:`zipfile.ZipFile`
87
-
88
-
89
- .. note:
90
- The convention for persistent ids given by this class and its derived
91
- classes is that the name should take the form `type.name` where `type`
92
- can be used by the persistent loader to determine how to load the
93
- object, while `name` is human-readable and as descriptive as possible.
94
-
95
- """
96
-
97
- def __init__ (self , zip_file ):
98
- self .zip_file = zip_file
99
- self .count = 0
100
- self .seen = {}
101
-
102
- def _resolve_name (self , obj ):
103
- """Determine the name the object should be saved under."""
104
- name = f"array_{ self .count } "
105
- self .count += 1
106
- return name
107
-
108
- def __call__ (self , obj ):
109
- if isinstance (obj , np .ndarray ):
110
- if id (obj ) not in self .seen :
111
-
112
- def write_array (f ):
113
- np .lib .format .write_array (f , obj )
114
-
115
- name = self ._resolve_name (obj )
116
- zipadd (write_array , self .zip_file , name )
117
- self .seen [id (obj )] = f"ndarray.{ name } "
118
- return self .seen [id (obj )]
119
-
120
-
121
- class PersistentSharedVariableID (PersistentNdarrayID ):
122
- """Uses shared variable names when persisting to zip file.
123
-
124
- If a shared variable has a name, this name is used as the name of the
125
- NPY file inside of the zip file. NumPy arrays that aren't matched to a
126
- shared variable are persisted as usual (i.e. `array_0`, `array_1`,
127
- etc.)
128
-
129
- :param allow_unnamed: Allow shared variables without a name to be
130
- persisted. Defaults to ``True``.
131
- :type allow_unnamed: bool, optional
132
-
133
- :param allow_duplicates: Allow multiple shared variables to have the same
134
- name, in which case they will be numbered e.g. `x`, `x_2`, `x_3`, etc.
135
- Defaults to ``True``.
136
- :type allow_duplicates: bool, optional
137
-
138
- :raises ValueError
139
- If an unnamed shared variable is encountered and `allow_unnamed` is
140
- ``False``, or if two shared variables have the same name, and
141
- `allow_duplicates` is ``False``.
142
-
143
- """
144
-
145
- def __init__ (self , zip_file , allow_unnamed = True , allow_duplicates = True ):
146
- super ().__init__ (zip_file )
147
- self .name_counter = Counter ()
148
- self .ndarray_names = {}
149
- self .allow_unnamed = allow_unnamed
150
- self .allow_duplicates = allow_duplicates
151
-
152
- def _resolve_name (self , obj ):
153
- if id (obj ) in self .ndarray_names :
154
- name = self .ndarray_names [id (obj )]
155
- count = self .name_counter [name ]
156
- self .name_counter [name ] += 1
157
- if count :
158
- if not self .allow_duplicates :
159
- raise ValueError (
160
- f"multiple shared variables with the name `{ name } ` found"
161
- )
162
- name = f"{ name } _{ count + 1 } "
163
- return name
164
- return super ()._resolve_name (obj )
165
-
166
- def __call__ (self , obj ):
167
- if isinstance (obj , SharedVariable ):
168
- if obj .name :
169
- if obj .name == "pkl" :
170
- ValueError ("can't pickle shared variable with name `pkl`" )
171
- self .ndarray_names [id (obj .container .storage [0 ])] = obj .name
172
- elif not self .allow_unnamed :
173
- raise ValueError (f"unnamed shared variable, { obj } " )
174
- return super ().__call__ (obj )
175
-
176
-
177
- class PersistentNdarrayLoad :
178
- """Load NumPy arrays that were persisted to a zip file when pickling.
179
-
180
- :param zip_file: The zip file handle in which the NumPy arrays are saved.
181
- :type zip_file: :class:`zipfile.ZipFile`
182
-
183
- """
184
-
185
- def __init__ (self , zip_file ):
186
- self .zip_file = zip_file
187
- self .cache = {}
188
-
189
- def __call__ (self , persid ):
190
- array_type , name = persid .split ("." )
191
- del array_type
192
- # array_type was used for switching gpu/cpu arrays
193
- # it is better to put these into sublclasses properly
194
- # this is more work but better logic
195
- if name in self .cache :
196
- return self .cache [name ]
197
- ret = None
198
- with self .zip_file .open (name ) as f :
199
- ret = np .lib .format .read_array (f )
200
- self .cache [name ] = ret
201
- return ret
202
-
203
-
204
- def dump (
205
- obj ,
206
- file_handler ,
207
- protocol = DEFAULT_PROTOCOL ,
208
- persistent_id = PersistentSharedVariableID ,
209
- ):
210
- """Pickles an object to a zip file using external persistence.
211
-
212
- :param obj: The object to pickle.
213
- :type obj: object
214
-
215
- :param file_handler: The file handle to save the object to.
216
- :type file_handler: file
217
-
218
- :param protocol: The pickling protocol to use. Unlike Python's built-in
219
- pickle, the default is set to `2` instead of 0 for Python 2. The
220
- Python 3 default (level 3) is maintained.
221
- :type protocol: int, optional
222
-
223
- :param persistent_id: The callable that persists certain objects in the
224
- object hierarchy to separate files inside of the zip file. For example,
225
- :class:`PersistentNdarrayID` saves any :class:`numpy.ndarray` to a
226
- separate NPY file inside of the zip file.
227
- :type persistent_id: callable
228
-
229
- .. versionadded:: 0.8
230
-
231
- .. note::
232
- The final file is simply a zipped file containing at least one file,
233
- `pkl`, which contains the pickled object. It can contain any other
234
- number of external objects. Note that the zip files are compatible with
235
- NumPy's :func:`numpy.load` function.
236
-
237
- >>> import pytensor
238
- >>> foo_1 = pytensor.shared(0, name='foo')
239
- >>> foo_2 = pytensor.shared(1, name='foo')
240
- >>> with open('model.zip', 'wb') as f:
241
- ... dump((foo_1, foo_2, np.array(2)), f)
242
- >>> list(np.load('model.zip').keys())
243
- ['foo', 'foo_2', 'array_0', 'pkl']
244
- >>> np.load('model.zip')['foo']
245
- array(0)
246
- >>> with open('model.zip', 'rb') as f:
247
- ... foo_1, foo_2, array = load(f)
248
- >>> array
249
- array(2)
250
-
251
- """
252
- with closing (
253
- zipfile .ZipFile (file_handler , "w" , zipfile .ZIP_DEFLATED , allowZip64 = True )
254
- ) as zip_file :
255
-
256
- def func (f ):
257
- p = pickle .Pickler (f , protocol = protocol )
258
- p .persistent_id = persistent_id (zip_file )
259
- p .dump (obj )
260
-
261
- zipadd (func , zip_file , "pkl" )
262
-
263
-
264
- def load (f , persistent_load = PersistentNdarrayLoad ):
265
- """Load a file that was dumped to a zip file.
266
-
267
- :param f: The file handle to the zip file to load the object from.
268
- :type f: file
269
-
270
- :param persistent_load: The persistent loading function to use for
271
- unpickling. This must be compatible with the `persistent_id` function
272
- used when pickling.
273
- :type persistent_load: callable, optional
274
-
275
- .. versionadded:: 0.8
276
- """
277
- with closing (zipfile .ZipFile (f , "r" )) as zip_file :
278
- p = pickle .Unpickler (BytesIO (zip_file .open ("pkl" ).read ()))
279
- p .persistent_load = persistent_load (zip_file )
280
- return p .load ()
281
-
282
-
283
- def zipadd (func , zip_file , name ):
284
- """Calls a function with a file object, saving it to a zip file.
285
-
286
- :param func: The function to call.
287
- :type func: callable
288
-
289
- :param zip_file: The zip file that `func` should write its data to.
290
- :type zip_file: :class:`zipfile.ZipFile`
291
-
292
- :param name: The name of the file inside of the zipped archive that `func`
293
- should save its data to.
294
- :type name: str
295
-
296
- """
297
- with tempfile .NamedTemporaryFile ("wb" , delete = False ) as temp_file :
298
- func (temp_file )
299
- temp_file .close ()
300
- zip_file .write (temp_file .name , arcname = name )
301
- if os .path .isfile (temp_file .name ):
302
- os .remove (temp_file .name )
0 commit comments