From f62fc20bbf3cb66d8b589c1085e782d376575fc9 Mon Sep 17 00:00:00 2001 From: Stuart Abercrombie Date: Tue, 5 Aug 2025 16:13:40 +0000 Subject: [PATCH] Use the older FileLike definition where necessary --- tensorizer/torch_compat.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tensorizer/torch_compat.py b/tensorizer/torch_compat.py index 8345dc4..9ff463a 100644 --- a/tensorizer/torch_compat.py +++ b/tensorizer/torch_compat.py @@ -55,6 +55,12 @@ logger = logging.getLogger(__name__) +if hasattr(torch.serialization, "FILE_LIKE"): + # Pre torch 2.7.1 + FileLike = torch.serialization.FILE_LIKE +else: + FileLike = torch.types.FileLike + _tensorizer_file_obj_type: "typing.TypeAlias" = Union[ io.BufferedIOBase, io.RawIOBase, @@ -67,7 +73,7 @@ _wrapper_file_obj_type: "typing.TypeAlias" = Union[ _tensorizer_file_obj_type, - Callable[[torch.types.FileLike], _tensorizer_file_obj_type], + Callable[[FileLike], _tensorizer_file_obj_type], ] _save_func_type: "typing.TypeAlias" = Callable[ @@ -397,7 +403,7 @@ def _pickle_attr(name): _ORIG_TORCH_LOAD: Final[callable] = torch.load -def _infer_tensor_ext_name(f: torch.types.FileLike): +def _infer_tensor_ext_name(f: FileLike): if isinstance(f, io.BytesIO): logger.warning( "Cannot infer .tensors location from io.BytesIO;" @@ -418,7 +424,7 @@ def _infer_tensor_ext_name(f: torch.types.FileLike): @contextlib.contextmanager def _contextual_torch_filename( - f: torch.types.FileLike, + f: FileLike, filename_ctx_var: ContextVar[Optional[_wrapper_file_obj_type]], ): if filename_ctx_var.get() is None: @@ -462,7 +468,7 @@ def _contextual_torch_filename( @functools.wraps(_ORIG_TORCH_SAVE) def _save_wrapper( obj: object, - f: torch.types.FileLike, + f: FileLike, pickle_module: Any = pickle, *args, **kwargs, @@ -489,7 +495,7 @@ def _save_wrapper( @functools.wraps(_ORIG_TORCH_LOAD) def _load_wrapper( - f: torch.types.FileLike, + f: FileLike, map_location: torch.serialization.MAP_LOCATION = None, pickle_module: Any = _LOAD_WRAPPER_DEFAULT_MODULE, *args, @@ -550,7 +556,7 @@ def tensorizer_saving( that dynamically generates the file path or file object based on the file path or file-like object ``f`` passed to the ``torch.save`` call. When using a callable, it should take a single argument of - the type ``torch.types.FileLike``, and output a type accepted + the type ``FileLike``, and output a type accepted by a `TensorSerializer`. The default behaviour is to use a callable that appends ``".tensors"`` to any filename passed as ``f``. If a provided callable returns ``None``, tensorizer deserialization @@ -620,7 +626,7 @@ def tensorizer_loading( callable that dynamically generates the file path or file object based on the file path or file-like object `f` passed to the ``torch.load`` call. When using a callable, it should take a single - argument of the type ``torch.types.FileLike``, and output a type + argument of the type ``FileLike``, and output a type accepted by a `TensorDeserializer`. The default behaviour is to use a callable that appends ``".tensors"`` to any filename passed as ``f``. If a provided callable returns ``None``, tensorizer