diff --git a/CHANGELOG.md b/CHANGELOG.md index f9b15c3..c1db8ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- `tensorizer.torch_compat` is a new module for using `tensorizer` as a backend + for handling tensor data during standard `torch.save` and `torch.load` calls + - To use `tensorizer` as a backend for `torch.save`, + wrap the call in the `tensorizer_saving` context manager + - The file created must then be loaded using `tensorizer_loading` + - To use `tensorizer` as a backend for `torch.load`, + wrap the call in the `tensorizer_loading` context manager + - The file to load must have been created using `tensorizer_saving` + ## [2.10.1] - 2025-06-27 ### Fixed @@ -472,6 +485,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `get_gpu_name` - `no_init_or_tensor` +[Unreleased]: https://github.com/coreweave/tensorizer/compare/v2.10.1...HEAD [2.10.1]: https://github.com/coreweave/tensorizer/compare/v2.10.0...v2.10.1 [2.10.0]: https://github.com/coreweave/tensorizer/compare/v2.9.3...v2.10.0 [2.9.3]: https://github.com/coreweave/tensorizer/compare/v2.9.2...v2.9.3 diff --git a/README.md b/README.md index f650ae8..f4ddde5 100644 --- a/README.md +++ b/README.md @@ -281,6 +281,197 @@ An example command line tool to add or remove encryption from existing serialized models is also available as [examples/encryption.py](examples/encrypt_existing.py). +## PyTorch Compatibility + +`tensorizer`'s `TensorSerializer` and `TensorDeserializer` classes are designed +to be able to replace the use of `torch.save` and `torch.load` in model saving +and loading pipelines, however, they are not drop-in replacements. The API for +serialization and deserialization with `tensorizer` offer more parameters to +control performance, resource usage, and additional features like encryption, +so they are invoked differently. +For drop-in replacements, see the next section. + +The examples below show example usages of +`torch.save` and `torch.load`, and how they may be replaced with `tensorizer` +serialization. + +```py +from tensorizer import TensorDeserializer, TensorSerializer +import torch + +model: torch.nn.Module = ... + +# Saving with torch.save +state_dict = model.state_dict() +torch.save(state_dict, "model.pt") + +# Loading with torch.load +state_dict = torch.load("model.pt", map_location="cuda:0") +model.load_state_dict(state_dict) + +# Saving with tensorizer.TensorSerializer +state_dict = model.state_dict() +serializer = TensorSerializer("model.tensors") +serializer.write_state_dict(state_dict) +serializer.close() + +# Loading with tensorizer.TensorDeserializer +with TensorDeserializer("model.tensors", device="cuda:0") as state_dict: + model.load_state_dict(state_dict) +``` + +> [!NOTE] +> +> `TensorDeserializer` is a context manager because it supports lazy-loading, +> where the context controls how long its source file will remain open to read +> more tensors. This behaviour is optional and can be engaged by using +> `TensorDeserializer(..., lazy_load=True)`. 
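+
+For instance, a lazy-loading variant of the deserialization example above
+might look like the following sketch, where tensors are read on demand and
+the source file stays open only while the `with` block is active:
+
+```py
+from tensorizer import TensorDeserializer
+import torch
+
+model: torch.nn.Module = ...
+
+# Tensors are materialized as they are accessed during load_state_dict
+with TensorDeserializer(
+    "model.tensors", device="cuda:0", lazy_load=True
+) as state_dict:
+    model.load_state_dict(state_dict)
+# The source file is closed once the context exits
+```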
+ +### Drop-In PyTorch Compatibility Layer, `tensorizer.torch_compat` + +Note that, as `tensorizer` only serializes tensors and not other Python types, +it is more similar to `safetensors` than to `torch`'s own saving, as `torch` +bases its serialization on the `pickle` module, which allows serialization of +arbitrary Python objects. + +The `tensorizer.torch_compat` module exists to address this and another common +integration challenge: +- Use case 1: You need to serialize Python objects other than tensors, + like `torch.save` does. +- Use case 2: You need to adapt existing code that uses `torch.save` internally + where it is not easy to swap out, like in an external framework or library. + +**`tensorizer.torch_compat` enables calls to `torch.save` and `torch.load` +to use `tensorizer` as a backend for the serialization and deserialization +of tensor data, separate from other data being serialized.** + +The interface to using `tensorizer.torch_compat` is through its two context +managers, `tensorizer_saving` and `tensorizer_loading`. These take similar +arguments to the `TensorSerializer` and `TensorDeserializer` classes, +respectively, and temporarily swap out the `torch.save` and `torch.load` +functions to ones with special behaviour while their context is active. +Saving this way produces two files, one for tensors, and one for all other data. + +```py +import torch +from tensorizer.torch_compat import tensorizer_loading, tensorizer_saving + +model: torch.nn.Module = ... + +state_dict = model.state_dict() + +# Saving with torch.save, internally using tensorizer.TensorSerializer +with tensorizer_saving("model.pt.tensors"): + torch.save(state_dict, "model.pt") + +# Loading with torch.load, internally using tensorizer.TensorDeserializer +with tensorizer_loading("model.pt.tensors", device="cuda:0"): + state_dict = torch.load("model.pt") +model.load_state_dict(state_dict) +``` + +For existing code that uses `torch.save` or `torch.load` internally, the +recommended usage pattern is to wrap the relevant section of code in one of +these context managers so that it can use `tensorizer` automatically. + +For instance, with a `transformers.Trainer` object, part of adapting it to +use `tensorizer` may be: + +```py +from tensorizer.torch_compat import tensorizer_saving + +with tensorizer_saving(): + # In case this module saves references to torch.save at import time + import transformers + +trainer: transformers.Trainer = ... + +with tensorizer_saving(): + # This method may call torch.save internally at some point, + # so activating this context around it will intercept it when it does + trainer.train() +``` + +#### `torch_compat` Usage Considerations + +If the filename to use is difficult to determine in advance, the first +`file_obj` argument to `tensorizer_loading` and `tensorizer_saving` is allowed +to be a callback that receives the path passed to `torch.save` and returns +a place to output the sidecar `.tensors` file. + +The `.tensors` path can be anything supported normally in `tensorizer`, +including pre-opened file-like objects and `s3://` URIs. +The default `file_obj` callback simply appends `.tensors` to the path. + +```py +import torch +from tensorizer.torch_compat import tensorizer_loading, tensorizer_saving + + +def tensors_path(f: torch.types.FileLike) -> str | None: + if isinstance(f, str): + return f.replace(".pt", "-tensor-data.tensors", 1) + else: + # Returning None will save normally, without using tensorizer. 
+ # This is useful for file-like objects like io.BytesIO, + # where sidecar files don't make sense. + return None + + +model: torch.nn.Module = ... +state_dict = model.state_dict() + +with tensorizer_saving(tensors_path): + # Will save to model.pt and model-tensor-data.tensors + torch.save(state_dict, "model.pt") + +with tensorizer_loading(tensors_path, device="cuda:0"): + # Will load from model.pt and model-tensor-data.tensors + state_dict = torch.load("model.pt") +model.load_state_dict(state_dict) +``` + +The `tensorizer_saving` and `tensorizer_loading` contexts are also thread-safe +and async-safe, in that their effects are local to one thread and coroutine. +They may also be activated at the same time as each other, or even nested +to temporarily change the arguments one is using. + +> [!WARNING] +> +> Even though `tensorizer` itself only handles data and does not execute +> arbitrary code, `torch.load` still uses the `pickle` module internally. +> Loading untrusted `pickle` files **can** execute arbitrary code, so take +> appropriate precautions when using these wrappers. +> +> Additionally, for technical reasons, `torch.load(..., weights_only=True)` +> is incompatible with these wrappers. `weights_only` can be forced to `False` +> by using `tensorizer_loading(..., suppress_weights_only=True)`, +> but this disables some safety checks in `torch`, so this is opt-in only. + +Finally, since the `tensorizer_saving` and `tensorizer_loading` contexts +temporarily swap out the `torch.save` and `torch.load` functions, note that they +will not affect already-saved references to those functions, e.g.: + +```py +from tensorizer.torch_compat import tensorizer_saving +from torch import save as original_torch_save + +with tensorizer_saving(): + # This won't work, but torch.save(..., "model.pt") would work + original_torch_save(..., "model.pt") +``` + +This can sometimes be worked around by wrapping import blocks +in `tensorizer_saving` and/or `tensorizer_loading` as well. +The wrappers will behave the same as the default `torch.save` and `torch.load` +functions unless their respective contexts are active, so this will usually +have no side effects. + +For additional parameters, caveats, and advanced usage information, +refer to the docstrings for `tensorizer_saving` and `tensorizer_loading` in +the file [tensorizer/torch_compat.py](/tensorizer/torch_compat.py), +or view their function documentation inline in an IDE. + ## Benchmarks You can run your own benchmarks on CoreWeave or your own Kubernetes cluster diff --git a/tensorizer/__init__.py b/tensorizer/__init__.py index f29473b..f8fd24b 100644 --- a/tensorizer/__init__.py +++ b/tensorizer/__init__.py @@ -1,10 +1,11 @@ -from . import serialization, stream_io, utils +from . import serialization, stream_io, torch_compat, utils from ._version import __version__ from .serialization import * __all__ = [ *serialization.__all__, "stream_io", + "torch_compat", "utils", "protobuf", "tensors_pb2", diff --git a/tensorizer/_version.py b/tensorizer/_version.py index 565443f..fbe28c9 100644 --- a/tensorizer/_version.py +++ b/tensorizer/_version.py @@ -1 +1 @@ -__version__ = "2.10.1" +__version__ = "2.11.0a0" diff --git a/tensorizer/torch_compat.py b/tensorizer/torch_compat.py new file mode 100644 index 0000000..8345dc4 --- /dev/null +++ b/tensorizer/torch_compat.py @@ -0,0 +1,666 @@ +""" +Compatibility layer for using ``torch.save`` and ``torch.load`` with tensorizer +as a backend for the serialization of tensors and tensor storages. 
+ +Author: Eta Syra + +Example: + An instance of ``torch.nn.Module`` can be serialized as follows:: + + import os + import torch + from tensorizer.torch_compat import ( + tensorizer_saving, tensorizer_loading + ) + + module: torch.nn.Module = ... + + with tensorizer_saving(): + torch.save(module, "module.pt") + + assert os.path.exists("module.pt") + assert os.path.exists("module.pt.tensors") + + with tensorizer_loading(device="cuda", num_readers=4): + deserialized_module = torch.load("module.pt") + + Both `tensorizer_saving` and `tensorizer_loading` can be passed keyword + arguments to be forwarded to a `TensorSerializer` and `TensorDeserializer` + object, respectively. They can also be given a ``file_obj`` argument + to control where they save the sidecar ``.tensors`` file containing + tensor data. +""" + +import contextlib +import functools +import inspect +import io +import logging +import os +import pickle +import threading +import types +import typing +from contextvars import ContextVar +from typing import Any, Callable, Final, Iterable, List, Optional, Tuple, Union + +import torch + +from .serialization import TensorDeserializer, TensorSerializer + +__all__ = ( + "tensorizer_saving", + "tensorizer_loading", +) + +logger = logging.getLogger(__name__) + +_tensorizer_file_obj_type: "typing.TypeAlias" = Union[ + io.BufferedIOBase, + io.RawIOBase, + typing.BinaryIO, + str, + bytes, + os.PathLike, + int, +] + +_wrapper_file_obj_type: "typing.TypeAlias" = Union[ + _tensorizer_file_obj_type, + Callable[[torch.types.FileLike], _tensorizer_file_obj_type], +] + +_save_func_type: "typing.TypeAlias" = Callable[ + [_tensorizer_file_obj_type, Iterable[torch.Tensor], dict], + Any, +] + +_load_func_type: "typing.TypeAlias" = Callable[ + [_tensorizer_file_obj_type, dict], Iterable[torch.Tensor] +] + +_storage_type: "typing.TypeAlias" = Union[ + torch.UntypedStorage, torch.TypedStorage +] + +_tensorizer_loading_filename: ContextVar[Optional[_wrapper_file_obj_type]] = ( + ContextVar("_tensorizer_loading_filename", default=None) +) +_tensorizer_saving_filename: ContextVar[Optional[_wrapper_file_obj_type]] = ( + ContextVar("_tensorizer_saving_filename", default=None) +) + +_tensorizer_deserializer_kwargs: ContextVar[Optional[dict]] = ContextVar( + "_tensorizer_deserializer_kwargs", default=None +) + +_tensorizer_serializer_kwargs: ContextVar[Optional[dict]] = ContextVar( + "_tensorizer_serializer_kwargs", default=None +) + + +def _storage_device(storage: _storage_type) -> torch.device: + if isinstance(storage, torch.TypedStorage): + return getattr(storage, "_untyped_storage", storage).device + else: + return storage.device + + +def _has_data(storage: _storage_type) -> bool: + maybe_untyped = getattr(storage, "_untyped_storage", storage) + return maybe_untyped.device.type != "meta" and maybe_untyped.data_ptr() != 0 + + +class _TensorizerPickler(pickle.Pickler): + __filename: Optional[_tensorizer_file_obj_type] + __tensors: List[torch.Tensor] + __tensor_ids: typing.Dict[Tuple[typing.Hashable, ...], int] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__filename = _tensorizer_saving_filename.get() + self.__tensors = [] + self.__tensor_ids = {} + + @staticmethod + def __tensor_key(tensor: torch.Tensor) -> Tuple[typing.Hashable, ...]: + return ( + tensor.data_ptr(), + tensor.dtype, + tensor.shape, + tensor.layout, + tensor.stride(), + ) + + def __register_tensor(self, tensor: torch.Tensor) -> int: + tensors = self.__tensors + new: int = len(tensors) + idx: int = 
self.__tensor_ids.setdefault(self.__tensor_key(tensor), new) + if idx is new: + tensors.append(tensor) + else: + tensors[idx] = tensor + return idx + + def dump(self, obj): + super().dump(obj) + if self.__tensors: + if self.__filename is None: + self.__tensors.clear() + self.__tensor_ids.clear() + return + kwargs = _tensorizer_serializer_kwargs.get() + if kwargs is None: + kwargs = {} + try: + if (save_func := _save_wrapper_save_func.get()) is None: + serializer = TensorSerializer(self.__filename, **kwargs) + serializer.write_state_dict(self.__tensors) + serializer.close() + else: + save_func(self.__filename, self.__tensors, kwargs) + finally: + # Don't call .clear() on self.__tensors in case it was saved + # somewhere by save_func + self.__tensors = [] + self.__tensor_ids.clear() + + @staticmethod + def __storage_to_tensor(storage: _storage_type) -> torch.Tensor: + # Convert a storage into an equivalent tensor + # for compatibility with a TensorSerializer + if not isinstance(storage, torch.UntypedStorage): + untyped = getattr(storage, "_untyped_storage", None) + if untyped is None: + untyped = storage.untyped() + storage = untyped + tensor = torch.empty( + (0,), dtype=torch.uint8, device=storage.device, requires_grad=False + ) + tensor.set_(storage) + return tensor + + def persistent_id(self, obj): + if ( + self.__filename is not None + and torch.is_storage(obj) + and _has_data(obj) + ): + tensor_view = self.__storage_to_tensor(obj) + dtype = getattr(obj, "dtype", None) + idx: int = self.__register_tensor(tensor_view) + return "TensorizerPickler", 0, "storage", idx, dtype + return None + + @staticmethod + def __wrap_persistent_id(persistent_id_func: callable): + @functools.wraps(persistent_id_func) + def _persistent_id(self, obj): + super_id = super(self.__class__, self).persistent_id(obj) + if super_id is not None: + return super_id + else: + return persistent_id_func(self, obj) + + return _persistent_id + + def __setattr__(self, key, value): + if key == "persistent_id": + value = self.__wrap_persistent_id(value) + super().__setattr__(key, value) + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if "persistent_id" in cls.__dict__: + cls.persistent_id = cls.__wrap_persistent_id(cls.persistent_id) + + +class _TensorizerUnpickler(pickle.Unpickler): + __filename: Optional[_tensorizer_file_obj_type] + __has_tensors: bool + __tensors: Optional[list] + __cached_super_load: Optional[callable] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__filename = _tensorizer_loading_filename.get() + self.__has_tensors = self.__filename is not None + self.__tensors = None + self.__cached_super_load = None + + def load(self): + try: + return super().load() + finally: + if self.__tensors is not None: + self.__tensors.clear() + self.__tensors = None + + def __load_tensors(self) -> None: + # Load and cache tensors from a sidecar file + if self.__tensors is not None: + return + elif not self.__has_tensors: + raise RuntimeError("Tried to load tensors without a path") + kwargs = _tensorizer_deserializer_kwargs.get() + if kwargs is None: + kwargs = {} + if (load_func := _load_wrapper_load_func.get()) is None: + with TensorDeserializer(self.__filename, **kwargs) as deserializer: + self.__tensors = deserializer.tree() + else: + self.__tensors = list(load_func(self.__filename, kwargs)) + assert self.__tensors is not None + + @staticmethod + def __tensor_to_storage( + tensor: torch.Tensor, dtype: Optional[torch.dtype] + ) -> _storage_type: + # Convert 
a tensor into an equivalent storage + # for compatibility with a TensorDeserializer + if dtype is None: + # Oddly, PyTorch expects a storage serialized as an UntypedStorage + # to be loaded as a TypedStorage with the torch.uint8 type. + dtype = torch.uint8 + return torch.TypedStorage( + wrap_storage=tensor.untyped_storage(), dtype=dtype, _internal=True + ) + + def __get_storage(self, idx: int, dtype: Optional[torch.dtype]): + # This will load all tensors the first time a "TensorizerPickler" + # persistent_id is encountered, indicating that this was a file + # created by a _TensorizerPickler. Deferring it to this point + # will avoid trying to engage the load logic on .pt files + # that were NOT created by a _TensorizerPickler, where there + # is probably no corresponding .tensors file anyway, where trying + # to load that would fail. + if self.__tensors is None: + self.__load_tensors() + tensor_view = self.__tensors[idx] + return self.__tensor_to_storage(tensor_view, dtype) + + @property + def __super_load(self) -> Callable[[Any], Any]: + if self.__cached_super_load is not None: + return self.__cached_super_load + super_load = super().persistent_load + super_load_func = getattr(super_load, "__func__", super_load) + # Evil Python behaviour can make the super method equal this method + # prior to Python 3.13, so check for that to avoid accidental recursion. + # _is_load_wrapper is set on dynamically-created wrappers + # that ultimately recurse back to this function; avoid those too. + if super_load_func == _TensorizerUnpickler.persistent_load or getattr( + super_load_func, "_is_load_wrapper", False + ): + # To avoid recursing forever, just raise the + # default error from pickle.Unpickler instead + self.__cached_super_load = self.__fallback_super_load + else: + # Will probably just throw an error, + # but could redirect to a sibling class + self.__cached_super_load = super_load + return self.__cached_super_load + + @staticmethod + def __fallback_super_load(_pid): + raise pickle.UnpicklingError("unsupported persistent id encountered") + + def persistent_load(self, pid): + if ( + self.__has_tensors + and isinstance(pid, tuple) + and pid[0] == "TensorizerPickler" + and len(pid) >= 3 + ): + version = pid[1] + if version > 0: + raise pickle.UnpicklingError( + f"Unsupported TensorizerPickler data version ({version:d})" + ) + object_type = pid[2] + if object_type == "storage": + idx, dtype = pid[3:] + return self.__get_storage(idx, dtype) + else: + raise pickle.UnpicklingError( + f"Unsupported TensorizerPickler object type ({object_type})" + ) + else: + return self.__super_load(pid) + + @staticmethod + def __wrap_persistent_load(persistent_load_func: callable): + + @functools.wraps(persistent_load_func) + def _persistent_load(self, pid): + try: + if self.__class__ is _TensorizerUnpickler: + # For instances of this class, call this class's method + return self.__class__.persistent_load(self, pid) + else: + # For subclasses, defer to the super method + return super(self.__class__, self).persistent_load(pid) + except pickle.UnpicklingError: + pass + # This is being set on an instance, not the class, + # so this wouldn't expect to be passed self as well, + # as it is not an unbound method here + return persistent_load_func(pid) + + return _persistent_load + + def __setattr__(self, key, value): + if key == "persistent_load": + # If this method is being overridden dynamically, modify it + # to defer to the persistent_load method from this class first + wrapped_func = self.__wrap_persistent_load(value) + 
# Mark this as a wrapper for recursion detection later on + wrapped_func._is_load_wrapper = True + value = types.MethodType(wrapped_func, self) + # Necessary witchcraft prior to Python 3.13: + # pickle.Unpickler may internally cache persistent_load functions, + # and it would normally update the cached value using a PyGetSetDef + # descriptor, but having a class in the inheritance hierarchy + # that defines persistent_load as a non-descriptor prevents + # attribute updates from reaching that descriptor's set method, + # so the cached value that the unpickler actually uses isn't + # properly updated, even though the Python object shows it as being + # updated. We can force this update to propagate to that descriptor + # by manipulating it directly. + if ( + pickle.Unpickler in self.__class__.__mro__ + and inspect.isgetsetdescriptor(pickle.Unpickler.persistent_load) + ): + pickle.Unpickler.persistent_load.__set__(self, value) + super().__setattr__(key, value) + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + if "persistent_load" in cls.__dict__: + cls.persistent_load = cls.__wrap_persistent_load( + cls.persistent_load + ) + + +def _pickle_attr(name): + return getattr(pickle, name) + + +_tensorizer_pickle = types.ModuleType("tensorizer_pickle") +_tensorizer_pickle.__getattr__ = _pickle_attr +_tensorizer_pickle.Pickler = _TensorizerPickler +_tensorizer_pickle.Unpickler = _TensorizerUnpickler + + +_ORIG_TORCH_SAVE: Final[callable] = torch.save +_ORIG_TORCH_LOAD: Final[callable] = torch.load + + +def _infer_tensor_ext_name(f: torch.types.FileLike): + if isinstance(f, io.BytesIO): + logger.warning( + "Cannot infer .tensors location from io.BytesIO;" + " not using tensorizer backend" + " (set the file_obj parameter to choose a location instead)" + ) + return None + filename: str + try: + filename = os.fsdecode(f) + except TypeError: + if hasattr(f, "name"): + filename = os.fsdecode(f.name) + else: + raise + return filename + ".tensors" + + +@contextlib.contextmanager +def _contextual_torch_filename( + f: torch.types.FileLike, + filename_ctx_var: ContextVar[Optional[_wrapper_file_obj_type]], +): + if filename_ctx_var.get() is None: + token = filename_ctx_var.set(_infer_tensor_ext_name(f)) + elif callable(filename_callback := filename_ctx_var.get()): + token = filename_ctx_var.set(filename_callback(f)) + else: + token = None + try: + yield + finally: + if token is not None: + filename_ctx_var.reset(token) + + +_save_wrapper_active: ContextVar[bool] = ContextVar( + "_save_wrapper_active", default=False +) +_save_wrapper_active_count: int = 0 +_save_wrapper_active_mutex: threading.Lock = threading.Lock() +_save_wrapper_wrapped: Optional[callable] = None +_save_wrapper_save_func: ContextVar[Optional[_save_func_type]] = ContextVar( + "_save_wrapper_save_func", default=None +) + +_load_wrapper_active: ContextVar[bool] = ContextVar( + "_load_wrapper_active", default=False +) +_load_wrapper_active_count: int = 0 +_load_wrapper_active_mutex: threading.Lock = threading.Lock() +_load_wrapper_wrapped: Optional[callable] = None +_load_wrapper_load_func: ContextVar[Optional[_load_func_type]] = ContextVar( + "_load_wrapper_load_func", default=None +) + +_suppress_weights_only: ContextVar[bool] = ContextVar( + "_suppress_weights_only", default=False +) + + +@functools.wraps(_ORIG_TORCH_SAVE) +def _save_wrapper( + obj: object, + f: torch.types.FileLike, + pickle_module: Any = pickle, + *args, + **kwargs, +): + if not _save_wrapper_active.get(): + return _ORIG_TORCH_SAVE(obj, f, 
pickle_module, *args, **kwargs) + if pickle_module is not None and pickle_module is not pickle: + raise ValueError( + "Tensorizer-based torch serialization is incompatible with" + " using a pickle_module other than the default" + ) + with _contextual_torch_filename(f, _tensorizer_saving_filename): + return _ORIG_TORCH_SAVE( + obj, f, *args, pickle_module=_tensorizer_pickle, **kwargs + ) + + +# This signature quietly changed in torch 1.13.0 to default to None, +# but the documentation wasn't updated to reflect that. +_LOAD_WRAPPER_DEFAULT_MODULE: Any = ( + pickle if torch.__version__ < (1, 13, 0) else None +) + + +@functools.wraps(_ORIG_TORCH_LOAD) +def _load_wrapper( + f: torch.types.FileLike, + map_location: torch.serialization.MAP_LOCATION = None, + pickle_module: Any = _LOAD_WRAPPER_DEFAULT_MODULE, + *args, + weights_only: Optional[bool] = None, + **kwargs, +): + if not _load_wrapper_active.get(): + return _ORIG_TORCH_LOAD( + f, + map_location, + pickle_module, + *args, + weights_only=weights_only, + **kwargs, + ) + if pickle_module is not None and pickle_module is not pickle: + raise ValueError( + "Tensorizer-based torch serialization is incompatible with" + " using a pickle_module other than the default" + ) + with _contextual_torch_filename(f, _tensorizer_loading_filename): + if _suppress_weights_only.get(): + weights_only = False + return _ORIG_TORCH_LOAD( + f, + map_location, + pickle_module=_tensorizer_pickle, + *args, + weights_only=weights_only, + **kwargs, + ) + + +@contextlib.contextmanager +def tensorizer_saving( + file_obj: Optional[_wrapper_file_obj_type] = None, + *, + save_func: Optional[_save_func_type] = None, + **kwargs, +): + """ + Context manager that modifies calls to ``torch.save`` to use tensorizer + as a backend for the serialization of tensors and tensor storages. + + Tensors are saved in a sidecar file separate from the ``.pt`` file created + by ``torch.save``. To load them again, use the `tensorizer_loading` + context manager paired with ``torch.load``. + + Notes: + This context manager is thread-safe and async-safe. Other threads or + coroutines executing concurrently while this context is active will not + be modified. + + Args: + file_obj: The file or file-like object in which to save tensor data, + separate from the one passed to ``torch.save`` for saving metadata. + This can be any type accepted by a `TensorSerializer`, or a callable + that dynamically generates the file path or file object based on + the file path or file-like object ``f`` passed to the ``torch.save`` + call. When using a callable, it should take a single argument of + the type ``torch.types.FileLike``, and output a type accepted + by a `TensorSerializer`. The default behaviour is to use a callable + that appends ``".tensors"`` to any filename passed as ``f``. + If a provided callable returns ``None``, tensorizer deserialization + is not used. + save_func: An optional callable with the signature + ``save_func(file_obj, tensors: Iterable[Tensor], kwargs: dict)`` + that may be used to override the default saving logic for tensors. + `file_obj` and `kwargs` correspond to the ones passed to this + function. This may be used, for instance, to make serialization + asynchronous by writing a `save_func` that serializes in + a background thread or process. + kwargs: Further keyword arguments to pass to the `TensorSerializer` + object used to save tensor data. 
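+
+    Example:
+        As an illustrative sketch (``threaded_save`` here is a hypothetical
+        helper), a custom ``save_func`` could offload tensor serialization to
+        a background thread while mirroring the default saving logic of
+        ``TensorSerializer(file_obj, **kwargs)`` followed by
+        ``write_state_dict``::
+
+            import threading
+            import torch
+            from tensorizer import TensorSerializer
+            from tensorizer.torch_compat import tensorizer_saving
+
+            def threaded_save(file_obj, tensors, kwargs):
+                tensors = list(tensors)
+
+                def _write():
+                    serializer = TensorSerializer(file_obj, **kwargs)
+                    serializer.write_state_dict(tensors)
+                    serializer.close()
+
+                # The caller is responsible for ensuring this thread
+                # finishes before the sidecar file is read back
+                threading.Thread(target=_write).start()
+
+            module: torch.nn.Module = ...
+
+            with tensorizer_saving(save_func=threaded_save):
+                torch.save(module, "module.pt")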
+ """ + global _save_wrapper_active_count, _save_wrapper_wrapped + active_token = _save_wrapper_active.set(True) + kwargs_token = _tensorizer_serializer_kwargs.set(kwargs) + filename_token = _tensorizer_saving_filename.set(file_obj) + save_func_token = _save_wrapper_save_func.set(save_func) + with _save_wrapper_active_mutex: + _save_wrapper_active_count += 1 + if _save_wrapper_active_count == 1: + assert _save_wrapper_wrapped is None + torch.save, _save_wrapper_wrapped = _save_wrapper, torch.save + try: + yield + finally: + with _save_wrapper_active_mutex: + _save_wrapper_active_count -= 1 + if _save_wrapper_active_count == 0: + assert _save_wrapper_wrapped is not None + torch.save = _save_wrapper_wrapped + _save_wrapper_wrapped = None + _save_wrapper_save_func.reset(save_func_token) + _tensorizer_saving_filename.reset(filename_token) + _tensorizer_serializer_kwargs.reset(kwargs_token) + _save_wrapper_active.reset(active_token) + + +@contextlib.contextmanager +def tensorizer_loading( + file_obj: Optional[_wrapper_file_obj_type] = None, + *, + load_func: Optional[_load_func_type] = None, + suppress_weights_only: bool = False, + **kwargs, +): + """ + Context manager that modifies calls to ``torch.load`` to use tensorizer + as a backend for the deserialization of tensors and tensor storages. + This is only valid to use when deserializing files that were serialized + using the corresponding `tensorizer_saving` context manager paired with + ``torch.save``. + + Tensors are saved in a sidecar file separate from the ``.pt`` file created + by ``torch.save``. Both must be present at deserialization time. + + Notes: + This context manager is thread-safe and async-safe. Other threads or + coroutines executing concurrently while this context is active will not + be modified. + + Args: + file_obj: The file or file-like object from which to load tensor data, + separate from the one passed to ``torch.load`` for loading metadata. + This can be any type accepted by a `TensorDeserializer`, or a + callable that dynamically generates the file path or file object + based on the file path or file-like object `f` passed to the + ``torch.load`` call. When using a callable, it should take a single + argument of the type ``torch.types.FileLike``, and output a type + accepted by a `TensorDeserializer`. The default behaviour is to use + a callable that appends ``".tensors"`` to any filename passed as + ``f``. If a provided callable returns ``None``, tensorizer + serialization is not used. + load_func: An optional callable with the signature + ``load_func(file_obj, kwargs: dict) -> Iterable[Tensor]`` + that may be used to override the default loading logic for tensors. + `file_obj` and `kwargs` correspond to the ones passed to this + function. + suppress_weights_only: If set to ``True``, replace ``weights_only=True`` + with ``weights_only=False`` in calls to ``torch.load`` within this + context. Using ``torch.load`` with tensorizer as a backend is + incompatible with ``weights_only=True`` because ``torch`` counts it + using a custom ``pickle_module`` as being a non-weights-only load, + even though tensorizer only loads weights in practice. + kwargs: Further keyword arguments to pass to the `TensorDeserializer` + object used to load tensor data. 
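+
+    Example:
+        As an illustrative sketch (``cpu_load`` here is a hypothetical
+        helper), a custom ``load_func`` could force all tensors onto the CPU
+        while otherwise mirroring the default loading logic of
+        ``TensorDeserializer(file_obj, **kwargs)``::
+
+            import torch
+            from tensorizer import TensorDeserializer
+            from tensorizer.torch_compat import tensorizer_loading
+
+            def cpu_load(file_obj, kwargs):
+                kwargs = dict(kwargs, device="cpu")
+                with TensorDeserializer(file_obj, **kwargs) as deserializer:
+                    # .tree() is how the default implementation collects
+                    # the deserialized tensors in index order
+                    return deserializer.tree()
+
+            with tensorizer_loading(load_func=cpu_load):
+                module = torch.load("module.pt")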
+ """ + global _load_wrapper_active_count, _load_wrapper_wrapped + active_token = _load_wrapper_active.set(True) + weights_token = _suppress_weights_only.set(suppress_weights_only) + kwargs_token = _tensorizer_deserializer_kwargs.set(kwargs) + filename_token = _tensorizer_loading_filename.set(file_obj) + load_func_token = _load_wrapper_load_func.set(load_func) + with _load_wrapper_active_mutex: + _load_wrapper_active_count += 1 + if _load_wrapper_active_count == 1: + assert _load_wrapper_wrapped is None + torch.load, _load_wrapper_wrapped = _load_wrapper, torch.load + try: + yield + finally: + with _load_wrapper_active_mutex: + _load_wrapper_active_count -= 1 + if _load_wrapper_active_count == 0: + assert _load_wrapper_wrapped is not None + torch.load = _load_wrapper_wrapped + _load_wrapper_wrapped = None + _load_wrapper_load_func.reset(load_func_token) + _tensorizer_loading_filename.reset(filename_token) + _tensorizer_deserializer_kwargs.reset(kwargs_token) + _suppress_weights_only.reset(weights_token) + _load_wrapper_active.reset(active_token) diff --git a/tests/test_torch_compat.py b/tests/test_torch_compat.py new file mode 100644 index 0000000..62d343b --- /dev/null +++ b/tests/test_torch_compat.py @@ -0,0 +1,648 @@ +import concurrent.futures +import contextlib +import inspect +import io +import itertools +import pickle +import tempfile +import threading +import typing +import unittest +from functools import partial +from pathlib import Path +from typing import ClassVar, Final, Optional, Sequence, Tuple + +import torch +import transformers + +import tensorizer +import tensorizer.torch_compat as torch_compat +from tensorizer.serialization import DecryptionParams, EncryptionParams +from tensorizer.torch_compat import tensorizer_loading, tensorizer_saving + +_ORIG_TORCH_SAVE: Final[callable] = torch.save +_ORIG_TORCH_LOAD: Final[callable] = torch.load + +fastest_device: Final[torch.device] = ( + torch.device("cuda", 0) + if torch.cuda.is_available() + else torch.device("cpu") +) + + +class TestTorchCompat(unittest.TestCase): + MODEL_REF: ClassVar[str] = "EleutherAI/gpt-neo-125M" + model: ClassVar[torch.nn.Module] + orig_tensors: ClassVar[Sequence[Tuple[str, torch.Tensor]]] + tmp_dir: ClassVar[tempfile.TemporaryDirectory] + tmp_dir_path: ClassVar[Path] + pt_path: ClassVar[Path] + tensors_path: ClassVar[Path] + reference_save_path: ClassVar[Path] + reference_save_size: ClassVar[int] + + @classmethod + def load_reference_model(cls, dtype, device=torch.device("cpu")): + with device: + return transformers.AutoModelForCausalLM.from_pretrained( + cls.MODEL_REF + ).to(dtype) + + @classmethod + def setUpClass(cls): + cls.model = cls.load_reference_model(dtype=torch.float16) + cls.addClassCleanup(delattr, cls, "model") + cls.model.eval() + cls.orig_tensors: Sequence[Tuple[str, torch.Tensor]] = ( + cls.extract_tensors(cls.model) + ) + cls.addClassCleanup(delattr, cls, "orig_tensors") + + cls.tmp_dir = tempfile.TemporaryDirectory(prefix="test_torch_compat") + cls.tmp_dir.__enter__() + cls.addClassCleanup(cls.tmp_dir.__exit__, None, None, None) + + cls.tmp_dir_path = Path(cls.tmp_dir.name) + cls.pt_path = cls.tmp_dir_path / "test.pt" + cls.tensors_path = cls.tmp_dir_path / "test.pt.tensors" + + # For use as a reference + cls.reference_save_path = cls.tmp_dir_path / "reference.pt" + torch.save(cls.model, cls.reference_save_path) + cls.reference_save_size = cls.reference_save_path.stat().st_size + + @staticmethod + def extract_tensors( + model: torch.nn.Module, + ) -> Sequence[Tuple[str, torch.Tensor]]: 
+ return list(model.state_dict().items()) + + def setUp(self): + self.assertFalse(self.pt_path.exists()) + self.assertFalse(self.tensors_path.exists()) + + def tearDown(self): + self.pt_path.unlink(missing_ok=True) + self.tensors_path.unlink(missing_ok=True) + + def check_model( + self, + loaded_model: torch.nn.Module, + expect_device: Optional[torch.device] = None, + reference: Optional[Sequence[Tuple[str, torch.Tensor]]] = None, + ): + loaded_model.eval() + loaded_tensors = self.extract_tensors(loaded_model) + orig_tensors = self.orig_tensors if reference is None else reference + orig_keys = [k for k, _ in orig_tensors] + loaded_keys = [k for k, _ in loaded_tensors] + self.assertListEqual(orig_keys, loaded_keys) + has_tensors: bool = False + _i = 0 + for (name, tensor), (loaded_name, loaded_tensor) in zip( + orig_tensors, loaded_tensors + ): + has_tensors = True + _i += 1 + self.assertEqual(name, loaded_name) + self.assertEqual(tensor.size(), loaded_tensor.size()) + self.assertEqual(tensor.stride(), loaded_tensor.stride()) + self.assertEqual(tensor.dtype, loaded_tensor.dtype) + if expect_device is not None: + self.assertEqual(loaded_tensor.device, expect_device) + if loaded_tensor.device != tensor.device: + loaded_tensor = loaded_tensor.to(tensor.device) + self.assertTrue(torch.equal(tensor, loaded_tensor)) + self.assertTrue(has_tensors) + + def check_save_load_signatures( + self, save_func: callable, load_func: callable + ): + # Ensure that the function signatures of torch.save and torch.load + # match what the wrapper code expects them to be. + empty = inspect.Parameter.empty + expected_save_signature = ( + ("obj", empty), + ("f", empty), + ("pickle_module", pickle), + ) + expected_load_signature = ( + ("f", empty), + ("map_location", None), + ("pickle_module", torch_compat._LOAD_WRAPPER_DEFAULT_MODULE), + ) + for func, expected in ( + (save_func, expected_save_signature), + (load_func, expected_load_signature), + ): + params = inspect.signature( + func, follow_wrapped=False + ).parameters.values() + for param, (name, default) in zip(params, expected): + self.assertEqual(param.name, name) + self.assertEqual(param.default, default) + + def test_signatures(self): + with self.subTest("Testing torch signatures"): + self.assertIs(torch.save, _ORIG_TORCH_SAVE) + self.assertIs(torch.load, _ORIG_TORCH_LOAD) + self.check_save_load_signatures(torch.save, torch.load) + + with self.subTest( + "Testing wrapper signatures" + ), tensorizer_saving(), tensorizer_loading(): + self.assertIsNot(torch.save, _ORIG_TORCH_SAVE) + self.assertIsNot(torch.load, _ORIG_TORCH_LOAD) + self.check_save_load_signatures(torch.save, torch.load) + + def test_torch_load(self): + # Sanity check + with torch.device("cpu"): + self.assertIs(torch.load, _ORIG_TORCH_LOAD) + loaded_model: torch.nn.Module = torch.load( + self.reference_save_path, weights_only=False + ) + self.assertFalse( + self.reference_save_path.with_suffix(".pt.tensors").exists() + ) + self.check_model(loaded_model, torch.device("cpu")) + + def test_save_load(self): + with tensorizer_saving(): + torch.save(self.model, self.pt_path) + self.assertTrue(self.pt_path.is_file()) + self.assertTrue(self.tensors_path.is_file()) + + with self.subTest("Testing file sizes"): + pt_size: int = self.pt_path.stat().st_size + tensors_size: int = self.tensors_path.stat().st_size + self.assertLess(pt_size, self.reference_save_size) + self.assertLess(pt_size, 1 << 20) # Should be less than 1 MiB + self.assertGreater(tensors_size, pt_size) + self.assertGreater(tensors_size, 20 << 
20) # More than 20 MiB + self.assertLess(tensors_size, int(self.reference_save_size * 1.8)) + + with self.subTest("Testing loading"): + with tensorizer_loading(device=fastest_device): + loaded_model = torch.load(self.pt_path, weights_only=False) + loaded_model.eval() + self.check_model(loaded_model, torch.device(fastest_device)) + # Check that it can process a forward pass + loaded_model(torch.tensor((1000,), device=fastest_device)) + + # def test_save_load_s3(self): + # pass + + def test_save_load_args(self): + encryption = EncryptionParams.random() + decryption = DecryptionParams.from_key(encryption.key) + tensors_path = self.tensors_path.with_suffix(".tensors.test") + self.addCleanup(tensors_path.unlink, missing_ok=True) + with tensorizer_saving(tensors_path, encryption=encryption): + torch.save(self.model, self.pt_path) + self.assertTrue(self.pt_path.is_file()) + self.assertTrue(tensors_path.is_file()) + + with self.subTest("Testing loading"): + with tensorizer_loading( + tensors_path, device=fastest_device, encryption=decryption + ): + loaded_model = torch.load(self.pt_path, weights_only=False) + loaded_model.eval() + self.check_model(loaded_model, fastest_device) + del loaded_model + + with self.subTest("Testing invalid loading"), self.assertRaises( + tensorizer.CryptographyError + ), tensorizer_loading(tensors_path, device=fastest_device): + torch.load(self.pt_path, weights_only=False) + + def test_save_load_fp8_torch(self): + dtype = torch.float8_e4m3fn + model = self.load_reference_model(dtype=dtype, device=fastest_device) + with tensorizer_saving(): + torch.save(model, self.pt_path) + self.assertTrue(self.pt_path.is_file()) + self.assertTrue(self.tensors_path.is_file()) + + with self.subTest("Testing loading"): + with tensorizer_loading(device=fastest_device): + loaded_model = torch.load(self.pt_path, weights_only=False) + loaded_model.eval() + dtypes = { + tensor.dtype for _, tensor in self.extract_tensors(loaded_model) + } + self.assertIn(dtype, dtypes) + self.assertNotIn(torch.float16, dtypes) + self.check_model( + loaded_model, + fastest_device, + reference=self.extract_tensors(model), + ) + + def test_thread_safety(self): + start: threading.Barrier = threading.Barrier(parties=2) + finish: threading.Barrier = threading.Barrier(parties=2) + model_1: torch.nn.Module = self.model + model_2: torch.nn.Module = self.load_reference_model(torch.float16) + + def _save_load_tensorizer( + model: torch.nn.Module, + pt_path: Path, + save_kwargs: dict, + load_kwargs: dict, + ) -> torch.nn.Module: + with tensorizer_saving(**save_kwargs): + start.wait(timeout=10) + start.reset() + torch.save(model, pt_path) + finish.wait(timeout=10) + finish.reset() + + with tensorizer_loading(**load_kwargs): + start.wait(timeout=10) + start.reset() + try: + return torch.load(pt_path) + finally: + finish.wait(timeout=10) + finish.reset() + + def _save_load_torch( + model: torch.nn.Module, pt_path: Path + ) -> torch.nn.Module: + start.wait(timeout=10) + torch.save(model, pt_path) + finish.wait(timeout=10) + + start.wait(timeout=10) + try: + return torch.load(pt_path, weights_only=False) + finally: + finish.wait(timeout=10) + + pt_path_1 = self.pt_path + tensors_path_1 = self.tensors_path + pt_path_2 = self.tmp_dir_path / "test-2.pt" + tensors_path_2 = self.tmp_dir_path / "test-2.pt.tensors" + paths = (pt_path_1, tensors_path_1, pt_path_2, tensors_path_2) + for path in paths: + self.addCleanup(path.unlink, missing_ok=True) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + with 
self.subTest("One thread active, one thread inactive"): + f1 = pool.submit( + _save_load_tensorizer, model_1, pt_path_1, {}, {} + ) + f2 = pool.submit(_save_load_torch, model_2, pt_path_2) + m1 = f1.result() + m2 = f2.result() + self.check_model(m1) + self.check_model(m2) + self.assertTrue(pt_path_1.is_file()) + self.assertTrue(tensors_path_1.is_file()) + self.assertTrue(pt_path_2.is_file()) + self.assertFalse(tensors_path_2.exists()) + for path in paths: + path.unlink(missing_ok=True) + + start.reset() + finish.reset() + + with self.subTest("Two threads with different active contexts"): + encryption_1 = EncryptionParams.random() + decryption_1 = DecryptionParams.from_key(encryption_1.key) + encryption_2 = EncryptionParams.random() + decryption_2 = DecryptionParams.from_key(encryption_2.key) + self.assertNotEqual(encryption_1.key, encryption_2.key) + f1 = pool.submit( + _save_load_tensorizer, + model_1, + pt_path_1, + {"encryption": encryption_1}, + {"encryption": decryption_1}, + ) + f2 = pool.submit( + _save_load_tensorizer, + model_2, + pt_path_2, + {"encryption": encryption_2}, + {"encryption": decryption_2}, + ) + m1 = f1.result() + m2 = f2.result() + self.check_model(m1) + self.check_model(m2) + for path in paths: + self.assertTrue(path.is_file()) + + def test_shared_storage(self): + a = torch.tensor((1, 2, 3), dtype=torch.long) + b = torch.tensor((4, 5, 6), dtype=torch.long) + c = torch.tensor((), dtype=torch.long) + # noinspection PyTypeChecker + c.set_(a) + self.assertTrue(c.is_set_to(a)) + self.assertEqual(a.data_ptr(), c.data_ptr()) + self.assertTrue(torch.equal(a, c)) + d = c.view(dtype=torch.float64) + self.assertTrue(d.is_set_to(a)) + tensors = [a, b, c, d] + with tensorizer_saving(): + torch.save(tensors, self.pt_path) + self.assertTrue(self.pt_path.is_file()) + self.assertTrue(self.tensors_path.is_file()) + + with tensorizer_loading(device="cpu"): + _a, _b, _c, _d = torch.load(self.pt_path, weights_only=False) + for orig, loaded in ((a, _a), (b, _b), (c, _c), (d, _d)): + self.assertTrue(torch.equal(orig, loaded)) + self.assertEqual(orig.dtype, loaded.dtype) + self.assertTrue(torch.equal(_a, _c)) + self.assertEqual(_a.data_ptr(), _c.data_ptr()) + self.assertTrue(_c.is_set_to(_a)) + self.assertTrue(_d.is_set_to(_a)) + + def test_suppress_weights_only(self): + tensors = list(torch.arange(16).view((4, 4))) + + with tensorizer_saving(): + torch.save(tensors, self.pt_path) + with self.assertRaisesRegex( + RuntimeError, + "Can not safely load weights" + " when explicit pickle_module is specified", + ), tensorizer_loading(device="cpu"): + torch.load(self.pt_path, weights_only=True) + with tensorizer_loading(device="cpu", suppress_weights_only=True): + loaded_tensors = torch.load(self.pt_path, weights_only=True) + + self.assertTrue( + torch.equal(torch.stack(tensors), torch.stack(loaded_tensors)) + ) + + def test_meta_tensors(self): + t1 = torch.tensor((1, 2, 3, 4), dtype=torch.long) + t2 = t1.to(device="meta") + + with tensorizer_saving(): + torch.save([t1, t2], self.pt_path) + + with tensorizer_loading(device="cpu"): + loaded_t1, loaded_t2 = torch.load(self.pt_path) + + self.assertTrue(torch.equal(t1, loaded_t1)) + self.assertTrue(loaded_t2.is_meta) + self.assertEqual(t2.dtype, loaded_t2.dtype) + self.assertEqual(t2.size(), loaded_t2.size()) + self.assertEqual(t2.stride(), loaded_t2.stride()) + + def test_name_callback(self): + dynamic_tensors_path: Path = self.pt_path.with_suffix( + ".pt.tensors.dynamic" + ) + + def path_callback(f: torch.types.FileLike) -> io.BufferedIOBase: + # 
Test with an exotic function that returns a pre-opened + # stream dynamically, based on the input file's name + _path = Path(f).with_suffix(".pt.tensors.dynamic") + if not _path.exists(): + self.addCleanup(_path.unlink, missing_ok=True) + file_obj = _path.open("rb" if _path.exists() else "wb+") + self.addCleanup(file_obj.close) + return typing.cast(io.BufferedIOBase, file_obj) + + with tensorizer_saving(path_callback): + torch.save(self.model, self.pt_path) + + self.assertTrue(self.pt_path.is_file()) + self.assertFalse(self.tensors_path.exists()) + self.assertTrue(dynamic_tensors_path.is_file()) + + with tensorizer_loading(path_callback, device="cpu"): + loaded_model = torch.load(self.pt_path) + + self.check_model(loaded_model) + + def test_nested_contexts(self): + sd = { + f"layer.{i:d}": torch.randn( + (16, 16), device="cpu", dtype=torch.float32 + ) + for i in range(4) + } + keys = tuple(sd.keys()) + + def check_sd(_sd): + self.assertTupleEqual(keys, tuple(_sd.keys())) + for name, tensor in sd.items(): + self.assertTrue(torch.equal(tensor, _sd[name])) + + # These produce reusable callables with frozen args + cpu_loading = partial(tensorizer_loading, device="cpu") + saving = partial(partial, tensorizer_saving) + loading = partial(partial, cpu_loading) + + def permuted(context1, context2): + @contextlib.contextmanager + def _ctx(): + with self.subTest(f"{name1} + {name2}"), ctx1(), ctx2(): + yield + + for (name1, ctx1), (name2, ctx2) in itertools.permutations( + (context1, context2) + ): + yield _ctx + + for ctx in permuted( + ("tensorizer_saving", saving()), + ("tensorizer_loading", loading()), + ): + with ctx(): + torch.save(sd, self.pt_path) + self.assertTrue(self.pt_path.is_file()) + self.assertTrue(self.tensors_path.is_file()) + check_sd(torch.load(self.pt_path)) + self.pt_path.unlink(missing_ok=True) + self.tensors_path.unlink(missing_ok=True) + + alt_tensors_path = self.tmp_dir_path / "test-2.pt.tensors" + self.addCleanup(alt_tensors_path.unlink, missing_ok=True) + + def cleanup() -> None: + self.pt_path.unlink(missing_ok=True) + self.tensors_path.unlink(missing_ok=True) + alt_tensors_path.unlink(missing_ok=True) + + def check_saved_primary() -> None: + self.assertTrue(self.pt_path.is_file()) + self.assertTrue(self.tensors_path.is_file()) + self.assertFalse(alt_tensors_path.exists()) + + def check_saved_alt() -> None: + self.assertTrue(self.pt_path.is_file()) + self.assertFalse(self.tensors_path.exists()) + self.assertTrue(alt_tensors_path.is_file()) + + # + # Test mixing tensorizer_saving and tensorizer_loading together + # + + # Try saving to an alternate path but not loading from it + for ctx in permuted( + ("tensorizer_saving(path)", saving(alt_tensors_path)), + ("tensorizer_loading", loading()), + ): + with ctx(): + torch.save(sd, self.pt_path) + check_saved_alt() + with self.assertRaises(OSError): + torch.load(self.pt_path) + cleanup() + + # Try loading from an alternate path but not saving to it + for ctx in permuted( + ("tensorizer_saving", saving()), + ("tensorizer_loading(path)", loading(alt_tensors_path)), + ): + with ctx(): + torch.save(sd, self.pt_path) + check_saved_primary() + with self.assertRaises(OSError): + torch.load(self.pt_path) + cleanup() + + # Try both saving to and loading from an alternate path + for ctx in permuted( + ("tensorizer_saving(path)", saving(alt_tensors_path)), + ("tensorizer_loading(path)", loading(alt_tensors_path)), + ): + with ctx(): + torch.save(sd, self.pt_path) + check_saved_alt() + check_sd(torch.load(self.pt_path)) + cleanup() + + # + 
# Test nesting multiple levels of the same type of context manager + # The most recent context should take precedence + # + + # Nested saving context managers + for save_name, default_save in ( + ("tensorizer_saving", saving()), + ("tensorizer_saving(default)", saving(self.tensors_path)), + ): + with self.subTest(f"{save_name} + tensorizer_saving(path)"): + with default_save(), tensorizer_saving(alt_tensors_path): + torch.save(sd, self.pt_path) + check_saved_alt() + with cpu_loading(alt_tensors_path): + check_sd(torch.load(self.pt_path)) + cleanup() + + with self.subTest(f"tensorizer_saving(path) + {save_name}"): + with tensorizer_saving(alt_tensors_path), default_save(): + torch.save(sd, self.pt_path) + check_saved_primary() + with cpu_loading(): + check_sd(torch.load(self.pt_path)) + cleanup() + + # Make sure an outer context is restored + # correctly after leaving an inner context + with self.subTest(f"tensorizer_saving(path) after {save_name}"): + with tensorizer_saving(alt_tensors_path): + with default_save(): + # This should temporarily change the context, + # but bring it back once the block is over. + pass + torch.save(sd, self.pt_path) + check_saved_alt() + with cpu_loading(alt_tensors_path): + check_sd(torch.load(self.pt_path)) + cleanup() + + with self.subTest(f"{save_name} after tensorizer_saving(path)"): + with default_save(): + with tensorizer_saving(alt_tensors_path): + # This should temporarily change the context, + # but bring it back once the block is over. + pass + torch.save(sd, self.pt_path) + check_saved_primary() + with cpu_loading(): + check_sd(torch.load(self.pt_path)) + cleanup() + + # Nested loading context managers + for load_name, default_load in ( + ("tensorizer_loading", loading()), + ("tensorizer_loading(default)", loading(self.tensors_path)), + ): + with self.subTest(f"{load_name} + tensorizer_loading(path)"): + with tensorizer_saving(alt_tensors_path): + torch.save(sd, self.pt_path) + check_saved_alt() + with default_load(), cpu_loading(alt_tensors_path): + check_sd(torch.load(self.pt_path)) + cleanup() + + with self.subTest(f"tensorizer_loading(path) + {load_name}"): + with tensorizer_saving(): + torch.save(sd, self.pt_path) + check_saved_primary() + with cpu_loading(alt_tensors_path), default_load(): + check_sd(torch.load(self.pt_path)) + cleanup() + + # Make sure an outer context is restored + # correctly after leaving an inner context + with self.subTest(f"tensorizer_loading(path) after {save_name}"): + with tensorizer_saving(alt_tensors_path): + torch.save(sd, self.pt_path) + check_saved_alt() + with cpu_loading(alt_tensors_path): + with default_load(): + # This should temporarily change the context, + # but bring it back once the block is over. + pass + check_sd(torch.load(self.pt_path)) + cleanup() + + with self.subTest(f"{save_name} after tensorizer_loading(path)"): + with tensorizer_saving(): + torch.save(sd, self.pt_path) + check_saved_primary() + with cpu_loading(): + with cpu_loading(alt_tensors_path): + # This should temporarily change the context, + # but bring it back once the block is over. 
+ pass + check_sd(torch.load(self.pt_path)) + cleanup() + + def test_save_load_without_tensors(self): + original = [1, "2", 3.0, torch.device("meta")] + + with tensorizer_saving(): + torch.save(original, self.pt_path) + + self.assertTrue(self.pt_path.is_file()) + self.assertFalse(self.tensors_path.exists()) + + with tensorizer_loading(): + loaded = torch.load(self.pt_path) + + self.assertListEqual(original, loaded) + + def test_load_with_regular_file(self): + torch.save(self.model, self.pt_path) + + self.assertTrue(self.pt_path.is_file()) + self.assertFalse(self.tensors_path.exists()) + + with tensorizer_loading(): + loaded_model = torch.load(self.pt_path) + + self.check_model(loaded_model)