Skip to content

Commit 084fbf1

Browse files
committed
Remove most of misc/pkl_utils.py because unused
1 parent 857cda0 commit 084fbf1

File tree

3 files changed

+14
-282
lines changed

3 files changed

+14
-282
lines changed

pytensor/compile/function/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import re
33
import traceback as tb
4+
from pathlib import Path
45

56
from pytensor.compile.function.pfunc import pfunc
67
from pytensor.compile.function.types import orig_function
@@ -13,7 +14,7 @@
1314

1415

1516
def function_dump(
16-
filename,
17+
filename: str | Path,
1718
inputs,
1819
outputs=None,
1920
mode=None,
@@ -26,7 +27,7 @@ def function_dump(
2627
allow_input_downcast=None,
2728
profile=None,
2829
on_unused_input=None,
29-
extra_tag_to_remove=None,
30+
extra_tag_to_remove: str | None = None,
3031
):
3132
"""
3233
This is helpful to make a reproducible case for problems during PyTensor
@@ -59,7 +60,7 @@ def function_dump(
5960
`['annotations', 'replacement_of', 'aggregation_scheme', 'roles']`
6061
6162
"""
62-
assert isinstance(filename, str)
63+
filename = Path(filename)
6364
d = dict(
6465
inputs=inputs,
6566
outputs=outputs,
@@ -74,7 +75,7 @@ def function_dump(
7475
profile=profile,
7576
on_unused_input=on_unused_input,
7677
)
77-
with open(filename, "wb") as f:
78+
with filename.open("wb") as f:
7879
import pytensor.misc.pkl_utils
7980

8081
pickler = pytensor.misc.pkl_utils.StripPickler(

pytensor/misc/pkl_utils.py

Lines changed: 8 additions & 246 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,12 @@
55
unit tests or regression tests.
66
"""
77

8-
import os
98
import pickle
109
import sys
11-
import tempfile
12-
import zipfile
13-
from collections import Counter
14-
from contextlib import closing
15-
from io import BytesIO
16-
from pickle import HIGHEST_PROTOCOL
17-
18-
import numpy as np
1910

2011
import pytensor
2112

2213

23-
try:
24-
from pickle import DEFAULT_PROTOCOL
25-
except ImportError:
26-
DEFAULT_PROTOCOL = HIGHEST_PROTOCOL
27-
28-
from pytensor.compile.sharedvalue import SharedVariable
29-
30-
3114
__docformat__ = "restructuredtext en"
3215
__authors__ = "Pascal Lamblin " "PyMC Developers " "PyTensor Developers "
3316
__copyright__ = "Copyright 2013, Universite de Montreal"
@@ -49,16 +32,18 @@ class StripPickler(Pickler):
4932
5033
..code-block:: python
5134
52-
fn_args = dict(inputs=inputs,
53-
outputs=outputs,
54-
updates=updates)
55-
dest_pkl = 'my_test.pkl'
56-
with open(dest_pkl, 'wb') as f:
35+
fn_args = {
36+
"inputs": inputs,
37+
"outputs": outputs,
38+
"updates": updates,
39+
}
40+
dest_pkl = "my_test.pkl"
41+
with Path(dest_pkl).open("wb") as f:
5742
strip_pickler = StripPickler(f, protocol=-1)
5843
strip_pickler.dump(fn_args)
5944
"""
6045

61-
def __init__(self, file, protocol=0, extra_tag_to_remove=None):
46+
def __init__(self, file, protocol: int = 0, extra_tag_to_remove: str | None = None):
6247
# Can't use super as Pickler isn't a new style class
6348
super().__init__(file, protocol)
6449
self.tag_to_remove = ["trace", "test_value"]
@@ -77,226 +62,3 @@ def save(self, obj):
7762
del obj.__dict__["__doc__"]
7863

7964
return Pickler.save(self, obj)
80-
81-
82-
class PersistentNdarrayID:
83-
"""Persist ndarrays in an object by saving them to a zip file.
84-
85-
:param zip_file: A zip file handle that the NumPy arrays will be saved to.
86-
:type zip_file: :class:`zipfile.ZipFile`
87-
88-
89-
.. note:
90-
The convention for persistent ids given by this class and its derived
91-
classes is that the name should take the form `type.name` where `type`
92-
can be used by the persistent loader to determine how to load the
93-
object, while `name` is human-readable and as descriptive as possible.
94-
95-
"""
96-
97-
def __init__(self, zip_file):
98-
self.zip_file = zip_file
99-
self.count = 0
100-
self.seen = {}
101-
102-
def _resolve_name(self, obj):
103-
"""Determine the name the object should be saved under."""
104-
name = f"array_{self.count}"
105-
self.count += 1
106-
return name
107-
108-
def __call__(self, obj):
109-
if isinstance(obj, np.ndarray):
110-
if id(obj) not in self.seen:
111-
112-
def write_array(f):
113-
np.lib.format.write_array(f, obj)
114-
115-
name = self._resolve_name(obj)
116-
zipadd(write_array, self.zip_file, name)
117-
self.seen[id(obj)] = f"ndarray.{name}"
118-
return self.seen[id(obj)]
119-
120-
121-
class PersistentSharedVariableID(PersistentNdarrayID):
122-
"""Uses shared variable names when persisting to zip file.
123-
124-
If a shared variable has a name, this name is used as the name of the
125-
NPY file inside of the zip file. NumPy arrays that aren't matched to a
126-
shared variable are persisted as usual (i.e. `array_0`, `array_1`,
127-
etc.)
128-
129-
:param allow_unnamed: Allow shared variables without a name to be
130-
persisted. Defaults to ``True``.
131-
:type allow_unnamed: bool, optional
132-
133-
:param allow_duplicates: Allow multiple shared variables to have the same
134-
name, in which case they will be numbered e.g. `x`, `x_2`, `x_3`, etc.
135-
Defaults to ``True``.
136-
:type allow_duplicates: bool, optional
137-
138-
:raises ValueError
139-
If an unnamed shared variable is encountered and `allow_unnamed` is
140-
``False``, or if two shared variables have the same name, and
141-
`allow_duplicates` is ``False``.
142-
143-
"""
144-
145-
def __init__(self, zip_file, allow_unnamed=True, allow_duplicates=True):
146-
super().__init__(zip_file)
147-
self.name_counter = Counter()
148-
self.ndarray_names = {}
149-
self.allow_unnamed = allow_unnamed
150-
self.allow_duplicates = allow_duplicates
151-
152-
def _resolve_name(self, obj):
153-
if id(obj) in self.ndarray_names:
154-
name = self.ndarray_names[id(obj)]
155-
count = self.name_counter[name]
156-
self.name_counter[name] += 1
157-
if count:
158-
if not self.allow_duplicates:
159-
raise ValueError(
160-
f"multiple shared variables with the name `{name}` found"
161-
)
162-
name = f"{name}_{count + 1}"
163-
return name
164-
return super()._resolve_name(obj)
165-
166-
def __call__(self, obj):
167-
if isinstance(obj, SharedVariable):
168-
if obj.name:
169-
if obj.name == "pkl":
170-
ValueError("can't pickle shared variable with name `pkl`")
171-
self.ndarray_names[id(obj.container.storage[0])] = obj.name
172-
elif not self.allow_unnamed:
173-
raise ValueError(f"unnamed shared variable, {obj}")
174-
return super().__call__(obj)
175-
176-
177-
class PersistentNdarrayLoad:
178-
"""Load NumPy arrays that were persisted to a zip file when pickling.
179-
180-
:param zip_file: The zip file handle in which the NumPy arrays are saved.
181-
:type zip_file: :class:`zipfile.ZipFile`
182-
183-
"""
184-
185-
def __init__(self, zip_file):
186-
self.zip_file = zip_file
187-
self.cache = {}
188-
189-
def __call__(self, persid):
190-
array_type, name = persid.split(".")
191-
del array_type
192-
# array_type was used for switching gpu/cpu arrays
193-
# it is better to put these into sublclasses properly
194-
# this is more work but better logic
195-
if name in self.cache:
196-
return self.cache[name]
197-
ret = None
198-
with self.zip_file.open(name) as f:
199-
ret = np.lib.format.read_array(f)
200-
self.cache[name] = ret
201-
return ret
202-
203-
204-
def dump(
205-
obj,
206-
file_handler,
207-
protocol=DEFAULT_PROTOCOL,
208-
persistent_id=PersistentSharedVariableID,
209-
):
210-
"""Pickles an object to a zip file using external persistence.
211-
212-
:param obj: The object to pickle.
213-
:type obj: object
214-
215-
:param file_handler: The file handle to save the object to.
216-
:type file_handler: file
217-
218-
:param protocol: The pickling protocol to use. Unlike Python's built-in
219-
pickle, the default is set to `2` instead of 0 for Python 2. The
220-
Python 3 default (level 3) is maintained.
221-
:type protocol: int, optional
222-
223-
:param persistent_id: The callable that persists certain objects in the
224-
object hierarchy to separate files inside of the zip file. For example,
225-
:class:`PersistentNdarrayID` saves any :class:`numpy.ndarray` to a
226-
separate NPY file inside of the zip file.
227-
:type persistent_id: callable
228-
229-
.. versionadded:: 0.8
230-
231-
.. note::
232-
The final file is simply a zipped file containing at least one file,
233-
`pkl`, which contains the pickled object. It can contain any other
234-
number of external objects. Note that the zip files are compatible with
235-
NumPy's :func:`numpy.load` function.
236-
237-
>>> import pytensor
238-
>>> foo_1 = pytensor.shared(0, name='foo')
239-
>>> foo_2 = pytensor.shared(1, name='foo')
240-
>>> with open('model.zip', 'wb') as f:
241-
... dump((foo_1, foo_2, np.array(2)), f)
242-
>>> list(np.load('model.zip').keys())
243-
['foo', 'foo_2', 'array_0', 'pkl']
244-
>>> np.load('model.zip')['foo']
245-
array(0)
246-
>>> with open('model.zip', 'rb') as f:
247-
... foo_1, foo_2, array = load(f)
248-
>>> array
249-
array(2)
250-
251-
"""
252-
with closing(
253-
zipfile.ZipFile(file_handler, "w", zipfile.ZIP_DEFLATED, allowZip64=True)
254-
) as zip_file:
255-
256-
def func(f):
257-
p = pickle.Pickler(f, protocol=protocol)
258-
p.persistent_id = persistent_id(zip_file)
259-
p.dump(obj)
260-
261-
zipadd(func, zip_file, "pkl")
262-
263-
264-
def load(f, persistent_load=PersistentNdarrayLoad):
265-
"""Load a file that was dumped to a zip file.
266-
267-
:param f: The file handle to the zip file to load the object from.
268-
:type f: file
269-
270-
:param persistent_load: The persistent loading function to use for
271-
unpickling. This must be compatible with the `persistent_id` function
272-
used when pickling.
273-
:type persistent_load: callable, optional
274-
275-
.. versionadded:: 0.8
276-
"""
277-
with closing(zipfile.ZipFile(f, "r")) as zip_file:
278-
p = pickle.Unpickler(BytesIO(zip_file.open("pkl").read()))
279-
p.persistent_load = persistent_load(zip_file)
280-
return p.load()
281-
282-
283-
def zipadd(func, zip_file, name):
284-
"""Calls a function with a file object, saving it to a zip file.
285-
286-
:param func: The function to call.
287-
:type func: callable
288-
289-
:param zip_file: The zip file that `func` should write its data to.
290-
:type zip_file: :class:`zipfile.ZipFile`
291-
292-
:param name: The name of the file inside of the zipped archive that `func`
293-
should save its data to.
294-
:type name: str
295-
296-
"""
297-
with tempfile.NamedTemporaryFile("wb", delete=False) as temp_file:
298-
func(temp_file)
299-
temp_file.close()
300-
zip_file.write(temp_file.name, arcname=name)
301-
if os.path.isfile(temp_file.name):
302-
os.remove(temp_file.name)

tests/misc/test_pkl_utils.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,10 @@
22
import shutil
33
from tempfile import mkdtemp
44

5-
import numpy as np
6-
7-
import pytensor
8-
from pytensor.misc.pkl_utils import StripPickler, dump, load
5+
from pytensor.misc.pkl_utils import StripPickler
96
from pytensor.tensor.type import matrix
107

118

12-
class TestDumpLoad:
13-
def setup_method(self):
14-
# Work in a temporary directory to avoid cluttering the repository
15-
self.origdir = os.getcwd()
16-
self.tmpdir = mkdtemp()
17-
os.chdir(self.tmpdir)
18-
19-
def teardown_method(self):
20-
# Get back to the original dir, and delete the temporary one
21-
os.chdir(self.origdir)
22-
if self.tmpdir is not None:
23-
shutil.rmtree(self.tmpdir)
24-
25-
def test_dump_zip_names(self):
26-
foo_1 = pytensor.shared(0, name="foo")
27-
foo_2 = pytensor.shared(1, name="foo")
28-
foo_3 = pytensor.shared(2, name="foo")
29-
with open("model.zip", "wb") as f:
30-
dump((foo_1, foo_2, foo_3, np.array(3)), f)
31-
keys = list(np.load("model.zip").keys())
32-
assert keys == ["foo", "foo_2", "foo_3", "array_0", "pkl"]
33-
foo_3 = np.load("model.zip")["foo_3"]
34-
assert foo_3 == np.array(2)
35-
with open("model.zip", "rb") as f:
36-
foo_1, foo_2, foo_3, array = load(f)
37-
assert array == np.array(3)
38-
39-
409
class TestStripPickler:
4110
def setup_method(self):
4211
# Work in a temporary directory to avoid cluttering the repository

0 commit comments

Comments
 (0)