
Commit 34f578c

Merge pull request #194 from coreweave/eta/torch-compat

feat(torch_compat): Add `torch_compat` module

2 parents b7ffda8 + 21c6434

6 files changed: +1522 -2 lines changed

CHANGELOG.md

Lines changed: 14 additions & 0 deletions
@@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- `tensorizer.torch_compat` is a new module for using `tensorizer` as a backend
  for handling tensor data during standard `torch.save` and `torch.load` calls
  - To use `tensorizer` as a backend for `torch.save`,
    wrap the call in the `tensorizer_saving` context manager
    - The file created must then be loaded using `tensorizer_loading`
  - To use `tensorizer` as a backend for `torch.load`,
    wrap the call in the `tensorizer_loading` context manager
    - The file to load must have been created using `tensorizer_saving`

## [2.10.1] - 2025-06-27

### Fixed
@@ -472,6 +485,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `get_gpu_name`
- `no_init_or_tensor`

[Unreleased]: https://github.com/coreweave/tensorizer/compare/v2.10.1...HEAD
[2.10.1]: https://github.com/coreweave/tensorizer/compare/v2.10.0...v2.10.1
[2.10.0]: https://github.com/coreweave/tensorizer/compare/v2.9.3...v2.10.0
[2.9.3]: https://github.com/coreweave/tensorizer/compare/v2.9.2...v2.9.3

README.md

Lines changed: 191 additions & 0 deletions
@@ -281,6 +281,197 @@ An example command line tool to add or remove encryption from existing
serialized models is also available as
[examples/encryption.py](examples/encrypt_existing.py).

## PyTorch Compatibility

`tensorizer`'s `TensorSerializer` and `TensorDeserializer` classes are designed
to replace the use of `torch.save` and `torch.load` in model saving and loading
pipelines; however, they are not drop-in replacements. The API for
serialization and deserialization with `tensorizer` offers more parameters to
control performance, resource usage, and additional features like encryption,
so they are invoked differently.
For drop-in replacements, see the next section.

The examples below show typical usages of
`torch.save` and `torch.load`, and how they may be replaced with `tensorizer`
serialization.

```py
from tensorizer import TensorDeserializer, TensorSerializer
import torch

model: torch.nn.Module = ...

# Saving with torch.save
state_dict = model.state_dict()
torch.save(state_dict, "model.pt")

# Loading with torch.load
state_dict = torch.load("model.pt", map_location="cuda:0")
model.load_state_dict(state_dict)

# Saving with tensorizer.TensorSerializer
state_dict = model.state_dict()
serializer = TensorSerializer("model.tensors")
serializer.write_state_dict(state_dict)
serializer.close()

# Loading with tensorizer.TensorDeserializer
with TensorDeserializer("model.tensors", device="cuda:0") as state_dict:
    model.load_state_dict(state_dict)
```

> [!NOTE]
>
> `TensorDeserializer` is a context manager because it supports lazy-loading,
> where the context controls how long its source file will remain open to read
> more tensors. This behaviour is optional and can be enabled by using
> `TensorDeserializer(..., lazy_load=True)`.

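To illustrate why lazy loading ties the open file to a context, here is a toy sketch of the idea in plain Python. The `LazyStateDict` class and its one-line JSON index are invented for this example and are **not** `tensorizer`'s actual file format; the point is only that values are read from the file on first access, so the file must stay open while the context is active.

```python
import io
import json


class LazyStateDict:
    # Toy illustration of lazy loading: each value is read from the file
    # only on first access, so the file must stay open while the context
    # is active -- the reason a deserializer can be a context manager.
    def __init__(self, file):
        self._file = file
        self._index = json.loads(file.readline())  # key -> [offset, length]
        self._base = file.tell()  # payloads start right after the header
        self._cache = {}

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self._file.close()

    def __getitem__(self, key):
        if key not in self._cache:
            offset, length = self._index[key]
            self._file.seek(self._base + offset)
            self._cache[key] = self._file.read(length)
        return self._cache[key]


# Build a tiny two-entry file in memory: a JSON header line, then raw payloads
payloads = {"weight": b"\x01\x02\x03", "bias": b"\x04"}
index, body, pos = {}, b"", 0
for name, raw in payloads.items():
    index[name] = [pos, len(raw)]
    body += raw
    pos += len(raw)
f = io.BytesIO(json.dumps(index).encode() + b"\n" + body)

with LazyStateDict(f) as state_dict:
    print(state_dict["bias"])  # read on demand -> b'\x04'
```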
### Drop-In PyTorch Compatibility Layer, `tensorizer.torch_compat`

Note that, since `tensorizer` only serializes tensors and not other Python
types, it is more similar to `safetensors` than to `torch`'s own saving:
`torch` bases its serialization on the `pickle` module, which allows
serialization of arbitrary Python objects.

The `tensorizer.torch_compat` module exists to address this and another common
integration challenge:

- Use case 1: You need to serialize Python objects other than tensors,
  like `torch.save` does.
- Use case 2: You need to adapt existing code that uses `torch.save` internally
  where it is not easy to swap out, like in an external framework or library.

**`tensorizer.torch_compat` enables calls to `torch.save` and `torch.load`
to use `tensorizer` as a backend for the serialization and deserialization
of tensor data, separate from other data being serialized.**

The interface to `tensorizer.torch_compat` is its two context
managers, `tensorizer_saving` and `tensorizer_loading`. These take similar
arguments to the `TensorSerializer` and `TensorDeserializer` classes,
respectively, and temporarily swap out the `torch.save` and `torch.load`
functions for ones with special behaviour while their context is active.
Saving this way produces two files: one for tensors, and one for all other data.

```py
import torch
from tensorizer.torch_compat import tensorizer_loading, tensorizer_saving

model: torch.nn.Module = ...

state_dict = model.state_dict()

# Saving with torch.save, internally using tensorizer.TensorSerializer
with tensorizer_saving("model.pt.tensors"):
    torch.save(state_dict, "model.pt")

# Loading with torch.load, internally using tensorizer.TensorDeserializer
with tensorizer_loading("model.pt.tensors", device="cuda:0"):
    state_dict = torch.load("model.pt")
model.load_state_dict(state_dict)
```

For existing code that uses `torch.save` or `torch.load` internally, the
recommended usage pattern is to wrap the relevant section of code in one of
these context managers so that it can use `tensorizer` automatically.

For instance, with a `transformers.Trainer` object, part of adapting it to
use `tensorizer` may be:

```py
from tensorizer.torch_compat import tensorizer_saving

with tensorizer_saving():
    # In case this module saves references to torch.save at import time
    import transformers

trainer: transformers.Trainer = ...

with tensorizer_saving():
    # This method may call torch.save internally at some point,
    # so activating this context around it will intercept it when it does
    trainer.train()
```

#### `torch_compat` Usage Considerations

If the filename to use is difficult to determine in advance, the first
argument, `file_obj`, to `tensorizer_loading` and `tensorizer_saving` is
allowed to be a callback that receives the path passed to `torch.save` or
`torch.load` and returns a place to output the sidecar `.tensors` file.

The `.tensors` path can be anything normally supported by `tensorizer`,
including pre-opened file-like objects and `s3://` URIs.
The default `file_obj` callback simply appends `.tensors` to the path.

```py
import torch
from tensorizer.torch_compat import tensorizer_loading, tensorizer_saving


def tensors_path(f: torch.types.FileLike) -> str | None:
    if isinstance(f, str):
        return f.replace(".pt", "-tensor-data.tensors", 1)
    else:
        # Returning None will save normally, without using tensorizer.
        # This is useful for file-like objects like io.BytesIO,
        # where sidecar files don't make sense.
        return None


model: torch.nn.Module = ...
state_dict = model.state_dict()

with tensorizer_saving(tensors_path):
    # Will save to model.pt and model-tensor-data.tensors
    torch.save(state_dict, "model.pt")

with tensorizer_loading(tensors_path, device="cuda:0"):
    # Will load from model.pt and model-tensor-data.tensors
    state_dict = torch.load("model.pt")
model.load_state_dict(state_dict)
```

The `tensorizer_saving` and `tensorizer_loading` contexts are also thread-safe
and async-safe, in that their effects are local to one thread and coroutine.
They may also be activated at the same time as each other, or even nested
to temporarily change the arguments one is using.

> [!WARNING]
>
> Even though `tensorizer` itself only handles data and does not execute
> arbitrary code, `torch.load` still uses the `pickle` module internally.
> Loading untrusted `pickle` files **can** execute arbitrary code, so take
> appropriate precautions when using these wrappers.
>
> Additionally, for technical reasons, `torch.load(..., weights_only=True)`
> is incompatible with these wrappers. `weights_only` can be forced to `False`
> by using `tensorizer_loading(..., suppress_weights_only=True)`,
> but this disables some safety checks in `torch`, so this is opt-in only.

Finally, since the `tensorizer_saving` and `tensorizer_loading` contexts
temporarily swap out the `torch.save` and `torch.load` functions, note that they
will not affect already-saved references to those functions, e.g.:

```py
from tensorizer.torch_compat import tensorizer_saving
from torch import save as original_torch_save

with tensorizer_saving():
    # This won't work, but torch.save(..., "model.pt") would work
    original_torch_save(..., "model.pt")
```

This can sometimes be worked around by wrapping import blocks
in `tensorizer_saving` and/or `tensorizer_loading` as well.
The wrappers will behave the same as the default `torch.save` and `torch.load`
functions unless their respective contexts are active, so this will usually
have no side effects.

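The reason a saved reference escapes the patch is ordinary Python name binding: `from torch import save as original_torch_save` copies the function object to a new name, so later replacing the `torch.save` attribute does not change what the old name refers to. A self-contained sketch of that mechanism, using a stand-in module rather than `torch` itself:

```python
import types

# Stand-in for a library module with a patchable function
lib = types.ModuleType("lib")
lib.save = lambda obj: "original"

# A reference saved before patching, as `from lib import save as ...` would do
saved_reference = lib.save

# What a patching context manager effectively does on entry
lib.save = lambda obj: "patched"

print(lib.save(None))         # -> patched   (attribute lookup sees the patch)
print(saved_reference(None))  # -> original  (old binding is unaffected)
```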
For additional parameters, caveats, and advanced usage information,
refer to the docstrings for `tensorizer_saving` and `tensorizer_loading` in
the file [tensorizer/torch_compat.py](/tensorizer/torch_compat.py),
or view their function documentation inline in an IDE.

## Benchmarks

You can run your own benchmarks on CoreWeave or your own Kubernetes cluster

tensorizer/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,10 +1,11 @@
-from . import serialization, stream_io, utils
+from . import serialization, stream_io, torch_compat, utils
 from ._version import __version__
 from .serialization import *

 __all__ = [
     *serialization.__all__,
     "stream_io",
+    "torch_compat",
     "utils",
     "protobuf",
     "tensors_pb2",

tensorizer/_version.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "2.10.1"
+__version__ = "2.11.0a0"
