3333
3434import contextlib
3535import functools
36+ import inspect
3637import io
3738import logging
3839import os
@@ -216,29 +217,40 @@ def __init_subclass__(cls, **kwargs):
216217
217218class _TensorizerUnpickler (pickle .Unpickler ):
218219 __filename : Optional [_tensorizer_file_obj_type ]
219- __tensors : list
220+ __has_tensors : bool
221+ __tensors : Optional [list ]
222+ __cached_super_load : Optional [callable ]
220223
221224 def __init__ (self , * args , ** kwargs ):
222225 super ().__init__ (* args , ** kwargs )
223226 self .__filename = _tensorizer_loading_filename .get ()
224- self .__tensors = []
227+ self .__has_tensors = self .__filename is not None
228+ self .__tensors = None
229+ self .__cached_super_load = None
225230
226231 def load (self ):
227- if self .__filename is not None :
228- kwargs = _tensorizer_deserializer_kwargs .get ()
229- if kwargs is None :
230- kwargs = {}
231- if (load_func := _load_wrapper_load_func .get ()) is None :
232- with TensorDeserializer (
233- self .__filename , ** kwargs
234- ) as deserializer :
235- self .__tensors = deserializer .tree ()
236- else :
237- self .__tensors = list (load_func (self .__filename , kwargs ))
238232 try :
239233 return super ().load ()
240234 finally :
241- self .__tensors .clear ()
235+ if self .__tensors is not None :
236+ self .__tensors .clear ()
237+ self .__tensors = None
238+
239+ def __load_tensors (self ) -> None :
240+ # Load and cache tensors from a sidecar file
241+ if self .__tensors is not None :
242+ return
243+ elif not self .__has_tensors :
244+ raise RuntimeError ("Tried to load tensors without a path" )
245+ kwargs = _tensorizer_deserializer_kwargs .get ()
246+ if kwargs is None :
247+ kwargs = {}
248+ if (load_func := _load_wrapper_load_func .get ()) is None :
249+ with TensorDeserializer (self .__filename , ** kwargs ) as deserializer :
250+ self .__tensors = deserializer .tree ()
251+ else :
252+ self .__tensors = list (load_func (self .__filename , kwargs ))
253+ assert self .__tensors is not None
242254
243255 @staticmethod
244256 def __tensor_to_storage (
@@ -254,9 +266,48 @@ def __tensor_to_storage(
254266 wrap_storage = tensor .untyped_storage (), dtype = dtype , _internal = True
255267 )
256268
269+ def __get_storage (self , idx : int , dtype : Optional [torch .dtype ]):
270+ # This will load all tensors the first time a "TensorizerPickler"
271+ # persistent_id is encountered, indicating that this was a file
272+ # created by a _TensorizerPickler. Deferring it to this point
273+ # will avoid trying to engage the load logic on .pt files
274+ # that were NOT created by a _TensorizerPickler, where there
275+ # is probably no corresponding .tensors file anyway, where trying
276+ # to load that would fail.
277+ if self .__tensors is None :
278+ self .__load_tensors ()
279+ tensor_view = self .__tensors [idx ]
280+ return self .__tensor_to_storage (tensor_view , dtype )
281+
282+ @property
283+ def __super_load (self ) -> Callable [[Any ], Any ]:
284+ if self .__cached_super_load is not None :
285+ return self .__cached_super_load
286+ super_load = super ().persistent_load
287+ super_load_func = getattr (super_load , "__func__" , super_load )
288+ # Evil Python behaviour can make the super method equal this method
289+ # prior to Python 3.13, so check for that to avoid accidental recursion.
290+ # _is_load_wrapper is set on dynamically-created wrappers
291+ # that ultimately recurse back to this function; avoid those too.
292+ if super_load_func == _TensorizerUnpickler .persistent_load or getattr (
293+ super_load_func , "_is_load_wrapper" , False
294+ ):
295+ # To avoid recursing forever, just raise the
296+ # default error from pickle.Unpickler instead
297+ self .__cached_super_load = self .__fallback_super_load
298+ else :
299+ # Will probably just throw an error,
300+ # but could redirect to a sibling class
301+ self .__cached_super_load = super_load
302+ return self .__cached_super_load
303+
304+ @staticmethod
305+ def __fallback_super_load (_pid ):
306+ raise pickle .UnpicklingError ("unsupported persistent id encountered" )
307+
257308 def persistent_load (self , pid ):
258309 if (
259- self .__filename is not None
310+ self .__has_tensors
260311 and isinstance (pid , tuple )
261312 and pid [0 ] == "TensorizerPickler"
262313 and len (pid ) >= 3
@@ -269,32 +320,58 @@ def persistent_load(self, pid):
269320 object_type = pid [2 ]
270321 if object_type == "storage" :
271322 idx , dtype = pid [3 :]
272- tensor_view = self .__tensors [idx ]
273- return self .__tensor_to_storage (tensor_view , dtype )
323+ return self .__get_storage (idx , dtype )
274324 else :
275325 raise pickle .UnpicklingError (
276326 f"Unsupported TensorizerPickler object type ({ object_type } )"
277327 )
278328 else :
279- # Will probably just throw an error
280- return super ().persistent_load (pid )
329+ return self .__super_load (pid )
281330
282331 @staticmethod
283332 def __wrap_persistent_load (persistent_load_func : callable ):
284333
285334 @functools .wraps (persistent_load_func )
286335 def _persistent_load (self , pid ):
287336 try :
288- return super (self .__class__ , self ).persistent_load (pid )
337+ if self .__class__ is _TensorizerUnpickler :
338+ # For instances of this class, call this class's method
339+ return self .__class__ .persistent_load (self , pid )
340+ else :
341+ # For subclasses, defer to the super method
342+ return super (self .__class__ , self ).persistent_load (pid )
289343 except pickle .UnpicklingError :
290344 pass
291- return persistent_load_func (self , pid )
345+ # This is being set on an instance, not the class,
346+ # so this wouldn't expect to be passed self as well,
347+ # as it is not an unbound method here
348+ return persistent_load_func (pid )
292349
293350 return _persistent_load
294351
295352 def __setattr__ (self , key , value ):
296353 if key == "persistent_load" :
297- value = self .__wrap_persistent_load (value )
354+ # If this method is being overridden dynamically, modify it
355+ # to defer to the persistent_load method from this class first
356+ wrapped_func = self .__wrap_persistent_load (value )
357+ # Mark this as a wrapper for recursion detection later on
358+ wrapped_func ._is_load_wrapper = True
359+ value = types .MethodType (wrapped_func , self )
360+ # Necessary witchcraft prior to Python 3.13:
361+ # pickle.Unpickler may internally cache persistent_load functions,
362+ # and it would normally update the cached value using a PyGetSetDef
363+ # descriptor, but having a class in the inheritance hierarchy
364+ # that defines persistent_load as a non-descriptor prevents
365+ # attribute updates from reaching that descriptor's set method,
366+ # so the cached value that the unpickler actually uses isn't
367+ # properly updated, even though the Python object shows it as being
368+ # updated. We can force this update to propagate to that descriptor
369+ # by manipulating it directly.
370+ if (
371+ pickle .Unpickler in self .__class__ .__mro__
372+ and inspect .isgetsetdescriptor (pickle .Unpickler .persistent_load )
373+ ):
374+ pickle .Unpickler .persistent_load .__set__ (self , value )
298375 super ().__setattr__ (key , value )
299376
300377 def __init_subclass__ (cls , ** kwargs ):
0 commit comments