Skip to content

Commit 816fd32

Browse files
committed
fix(torch_compat): Fix loading regular files with tensorizer_loading
1 parent e623c04 commit 816fd32

File tree

2 files changed

+124
-22
lines changed

2 files changed

+124
-22
lines changed

tensorizer/torch_compat.py

Lines changed: 99 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
import contextlib
3535
import functools
36+
import inspect
3637
import io
3738
import logging
3839
import os
@@ -216,29 +217,40 @@ def __init_subclass__(cls, **kwargs):
216217

217218
class _TensorizerUnpickler(pickle.Unpickler):
218219
__filename: Optional[_tensorizer_file_obj_type]
219-
__tensors: list
220+
__has_tensors: bool
221+
__tensors: Optional[list]
222+
__cached_super_load: Optional[callable]
220223

221224
def __init__(self, *args, **kwargs):
222225
super().__init__(*args, **kwargs)
223226
self.__filename = _tensorizer_loading_filename.get()
224-
self.__tensors = []
227+
self.__has_tensors = self.__filename is not None
228+
self.__tensors = None
229+
self.__cached_super_load = None
225230

226231
def load(self):
227-
if self.__filename is not None:
228-
kwargs = _tensorizer_deserializer_kwargs.get()
229-
if kwargs is None:
230-
kwargs = {}
231-
if (load_func := _load_wrapper_load_func.get()) is None:
232-
with TensorDeserializer(
233-
self.__filename, **kwargs
234-
) as deserializer:
235-
self.__tensors = deserializer.tree()
236-
else:
237-
self.__tensors = list(load_func(self.__filename, kwargs))
238232
try:
239233
return super().load()
240234
finally:
241-
self.__tensors.clear()
235+
if self.__tensors is not None:
236+
self.__tensors.clear()
237+
self.__tensors = None
238+
239+
def __load_tensors(self) -> None:
240+
# Load and cache tensors from a sidecar file
241+
if self.__tensors is not None:
242+
return
243+
elif not self.__has_tensors:
244+
raise RuntimeError("Tried to load tensors without a path")
245+
kwargs = _tensorizer_deserializer_kwargs.get()
246+
if kwargs is None:
247+
kwargs = {}
248+
if (load_func := _load_wrapper_load_func.get()) is None:
249+
with TensorDeserializer(self.__filename, **kwargs) as deserializer:
250+
self.__tensors = deserializer.tree()
251+
else:
252+
self.__tensors = list(load_func(self.__filename, kwargs))
253+
assert self.__tensors is not None
242254

243255
@staticmethod
244256
def __tensor_to_storage(
@@ -254,9 +266,48 @@ def __tensor_to_storage(
254266
wrap_storage=tensor.untyped_storage(), dtype=dtype, _internal=True
255267
)
256268

269+
def __get_storage(self, idx: int, dtype: Optional[torch.dtype]):
270+
# This will load all tensors the first time a "TensorizerPickler"
271+
# persistent_id is encountered, indicating that this was a file
272+
# created by a _TensorizerPickler. Deferring it to this point
273+
# will avoid trying to engage the load logic on .pt files
274+
# that were NOT created by a _TensorizerPickler, where there
275+
# is probably no corresponding .tensors file anyway, and where trying
276+
# to load one would fail.
277+
if self.__tensors is None:
278+
self.__load_tensors()
279+
tensor_view = self.__tensors[idx]
280+
return self.__tensor_to_storage(tensor_view, dtype)
281+
282+
@property
283+
def __super_load(self) -> Callable[[Any], Any]:
284+
if self.__cached_super_load is not None:
285+
return self.__cached_super_load
286+
super_load = super().persistent_load
287+
super_load_func = getattr(super_load, "__func__", super_load)
288+
# Evil Python behaviour can make the super method equal this method
289+
# prior to Python 3.13, so check for that to avoid accidental recursion.
290+
# _is_load_wrapper is set on dynamically-created wrappers
291+
# that ultimately recurse back to this function; avoid those too.
292+
if super_load_func == _TensorizerUnpickler.persistent_load or getattr(
293+
super_load_func, "_is_load_wrapper", False
294+
):
295+
# To avoid recursing forever, just raise the
296+
# default error from pickle.Unpickler instead
297+
self.__cached_super_load = self.__fallback_super_load
298+
else:
299+
# Will probably just throw an error,
300+
# but could redirect to a sibling class
301+
self.__cached_super_load = super_load
302+
return self.__cached_super_load
303+
304+
@staticmethod
305+
def __fallback_super_load(_pid):
306+
raise pickle.UnpicklingError("unsupported persistent id encountered")
307+
257308
def persistent_load(self, pid):
258309
if (
259-
self.__filename is not None
310+
self.__has_tensors
260311
and isinstance(pid, tuple)
261312
and pid[0] == "TensorizerPickler"
262313
and len(pid) >= 3
@@ -269,32 +320,58 @@ def persistent_load(self, pid):
269320
object_type = pid[2]
270321
if object_type == "storage":
271322
idx, dtype = pid[3:]
272-
tensor_view = self.__tensors[idx]
273-
return self.__tensor_to_storage(tensor_view, dtype)
323+
return self.__get_storage(idx, dtype)
274324
else:
275325
raise pickle.UnpicklingError(
276326
f"Unsupported TensorizerPickler object type ({object_type})"
277327
)
278328
else:
279-
# Will probably just throw an error
280-
return super().persistent_load(pid)
329+
return self.__super_load(pid)
281330

282331
@staticmethod
283332
def __wrap_persistent_load(persistent_load_func: callable):
284333

285334
@functools.wraps(persistent_load_func)
286335
def _persistent_load(self, pid):
287336
try:
288-
return super(self.__class__, self).persistent_load(pid)
337+
if self.__class__ is _TensorizerUnpickler:
338+
# For instances of this class, call this class's method
339+
return self.__class__.persistent_load(self, pid)
340+
else:
341+
# For subclasses, defer to the super method
342+
return super(self.__class__, self).persistent_load(pid)
289343
except pickle.UnpicklingError:
290344
pass
291-
return persistent_load_func(self, pid)
345+
# This is being set on an instance, not the class,
346+
# so this wouldn't expect to be passed self as well,
347+
# as it is not an unbound method here
348+
return persistent_load_func(pid)
292349

293350
return _persistent_load
294351

295352
def __setattr__(self, key, value):
296353
if key == "persistent_load":
297-
value = self.__wrap_persistent_load(value)
354+
# If this method is being overridden dynamically, modify it
355+
# to defer to the persistent_load method from this class first
356+
wrapped_func = self.__wrap_persistent_load(value)
357+
# Mark this as a wrapper for recursion detection later on
358+
wrapped_func._is_load_wrapper = True
359+
value = types.MethodType(wrapped_func, self)
360+
# Necessary witchcraft prior to Python 3.13:
361+
# pickle.Unpickler may internally cache persistent_load functions,
362+
# and it would normally update the cached value using a PyGetSetDef
363+
# descriptor, but having a class in the inheritance hierarchy
364+
# that defines persistent_load as a non-descriptor prevents
365+
# attribute updates from reaching that descriptor's set method,
366+
# so the cached value that the unpickler actually uses isn't
367+
# properly updated, even though the Python object shows it as being
368+
# updated. We can force this update to propagate to that descriptor
369+
# by manipulating it directly.
370+
if (
371+
pickle.Unpickler in self.__class__.__mro__
372+
and inspect.isgetsetdescriptor(pickle.Unpickler.persistent_load)
373+
):
374+
pickle.Unpickler.persistent_load.__set__(self, value)
298375
super().__setattr__(key, value)
299376

300377
def __init_subclass__(cls, **kwargs):

tests/test_torch_compat.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,3 +621,28 @@ def check_saved_alt() -> None:
621621
pass
622622
check_sd(torch.load(self.pt_path))
623623
cleanup()
624+
625+
def test_save_load_without_tensors(self):
626+
original = [1, "2", 3.0, torch.device("meta")]
627+
628+
with tensorizer_saving():
629+
torch.save(original, self.pt_path)
630+
631+
self.assertTrue(self.pt_path.is_file())
632+
self.assertFalse(self.tensors_path.exists())
633+
634+
with tensorizer_loading():
635+
loaded = torch.load(self.pt_path)
636+
637+
self.assertListEqual(original, loaded)
638+
639+
def test_load_with_regular_file(self):
640+
torch.save(self.model, self.pt_path)
641+
642+
self.assertTrue(self.pt_path.is_file())
643+
self.assertFalse(self.tensors_path.exists())
644+
645+
with tensorizer_loading():
646+
loaded_model = torch.load(self.pt_path)
647+
648+
self.check_model(loaded_model)

0 commit comments

Comments (0)