Commit e85d9f1 (parent 816fd32)

docs(torch_compat): Document the torch_compat module in the README

1 file changed: README.md (+191 −0)
An example command line tool to add or remove encryption from existing
serialized models is also available as
[examples/encryption.py](examples/encrypt_existing.py).

## PyTorch Compatibility

`tensorizer`'s `TensorSerializer` and `TensorDeserializer` classes are designed
to be able to replace the use of `torch.save` and `torch.load` in model saving
and loading pipelines; however, they are not drop-in replacements. The API for
serialization and deserialization with `tensorizer` offers more parameters to
control performance, resource usage, and additional features like encryption,
so the two are invoked differently.
For drop-in replacements, see the next section.

The examples below show typical usages of `torch.save` and `torch.load`,
and how they may be replaced with `tensorizer` serialization.

```py
from tensorizer import TensorDeserializer, TensorSerializer
import torch

model: torch.nn.Module = ...

# Saving with torch.save
state_dict = model.state_dict()
torch.save(state_dict, "model.pt")

# Loading with torch.load
state_dict = torch.load("model.pt", map_location="cuda:0")
model.load_state_dict(state_dict)

# Saving with tensorizer.TensorSerializer
state_dict = model.state_dict()
serializer = TensorSerializer("model.tensors")
serializer.write_state_dict(state_dict)
serializer.close()

# Loading with tensorizer.TensorDeserializer
with TensorDeserializer("model.tensors", device="cuda:0") as state_dict:
    model.load_state_dict(state_dict)
```

> [!NOTE]
>
> `TensorDeserializer` is a context manager because it supports lazy-loading,
> where the context controls how long its source file will remain open to read
> more tensors. This behaviour is optional and can be engaged by using
> `TensorDeserializer(..., lazy_load=True)`.

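The lazy-loading idea can be illustrated with a toy, stdlib-only sketch (this is not `tensorizer`'s actual implementation, and `LazyFileMapping` is a hypothetical class): a mapping defers reading each value from its backing file until the key is accessed, which is why the file must remain open while entries may still be fetched.

```python
import io
from collections.abc import Mapping


class LazyFileMapping(Mapping):
    """Toy stand-in for lazy loading: each value is read from the
    backing file only when its key is accessed, so the file must
    stay open for as long as entries may still be fetched."""

    def __init__(self, file, index):
        self._file = file    # kept open; closing it breaks later reads
        self._index = index  # key -> (offset, length) within the file

    def __getitem__(self, key):
        offset, length = self._index[key]
        self._file.seek(offset)
        return self._file.read(length)

    def __iter__(self):
        return iter(self._index)

    def __len__(self):
        return len(self._index)


f = io.BytesIO(b"weightsbiases")
lazy = LazyFileMapping(f, {"w": (0, 7), "b": (7, 6)})
assert lazy["w"] == b"weights"  # the read happens here, not at construction
```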
### Drop-In PyTorch Compatibility Layer, `tensorizer.torch_compat`

Note that, since `tensorizer` only serializes tensors and not other Python
types, it is more similar to `safetensors` than to `torch`'s own saving:
`torch` bases its serialization on the `pickle` module, which allows
serialization of arbitrary Python objects.

The `tensorizer.torch_compat` module exists to address this and another common
integration challenge:
- Use case 1: You need to serialize Python objects other than tensors,
  like `torch.save` does.
- Use case 2: You need to adapt existing code that uses `torch.save` internally
  where it is not easy to swap out, such as in an external framework or library.

**`tensorizer.torch_compat` enables calls to `torch.save` and `torch.load`
to use `tensorizer` as a backend for the serialization and deserialization
of tensor data, separate from other data being serialized.**

The interface to `tensorizer.torch_compat` is its two context managers,
`tensorizer_saving` and `tensorizer_loading`. These take similar arguments to
the `TensorSerializer` and `TensorDeserializer` classes, respectively, and
temporarily swap out the `torch.save` and `torch.load` functions for ones with
special behaviour while their context is active.
Saving this way produces two files: one for tensors, and one for all other data.

```py
import torch
from tensorizer.torch_compat import tensorizer_loading, tensorizer_saving

model: torch.nn.Module = ...

state_dict = model.state_dict()

# Saving with torch.save, internally using tensorizer.TensorSerializer
with tensorizer_saving("model.pt.tensors"):
    torch.save(state_dict, "model.pt")

# Loading with torch.load, internally using tensorizer.TensorDeserializer
with tensorizer_loading("model.pt.tensors", device="cuda:0"):
    state_dict = torch.load("model.pt")
    model.load_state_dict(state_dict)
```

For existing code that uses `torch.save` or `torch.load` internally, the
recommended usage pattern is to wrap the relevant section of code in one of
these context managers so that it can use `tensorizer` automatically.

For instance, with a `transformers.Trainer` object, part of adapting it to
use `tensorizer` may be:

```py
from tensorizer.torch_compat import tensorizer_saving

with tensorizer_saving():
    # In case this module saves references to torch.save at import time
    import transformers

trainer: transformers.Trainer = ...

with tensorizer_saving():
    # This method may call torch.save internally at some point,
    # so activating this context around it will intercept it when it does
    trainer.train()
```

#### `torch_compat` Usage Considerations

If the filename to use is difficult to determine in advance, the first
`file_obj` argument to `tensorizer_loading` and `tensorizer_saving` is allowed
to be a callback that receives the path passed to `torch.save` and returns
a place to output the sidecar `.tensors` file.

The `.tensors` path can be anything normally supported by `tensorizer`,
including pre-opened file-like objects and `s3://` URIs.
The default `file_obj` callback simply appends `.tensors` to the path.

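That default behaviour amounts to something like the following sketch (the helper name `default_sidecar_path` is hypothetical, used only for illustration):

```python
# Hypothetical helper mirroring the documented default: derive the
# sidecar path by appending ".tensors" to the path given to torch.save.
def default_sidecar_path(path: str) -> str:
    return path + ".tensors"


assert default_sidecar_path("model.pt") == "model.pt.tensors"
```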
```py
import torch
from tensorizer.torch_compat import tensorizer_loading, tensorizer_saving


def tensors_path(f: torch.types.FileLike) -> str | None:
    if isinstance(f, str):
        return f.replace(".pt", "-tensor-data.tensors", 1)
    else:
        # Returning None will save normally, without using tensorizer.
        # This is useful for file-like objects like io.BytesIO,
        # where sidecar files don't make sense.
        return None


model: torch.nn.Module = ...
state_dict = model.state_dict()

with tensorizer_saving(tensors_path):
    # Will save to model.pt and model-tensor-data.tensors
    torch.save(state_dict, "model.pt")

with tensorizer_loading(tensors_path, device="cuda:0"):
    # Will load from model.pt and model-tensor-data.tensors
    state_dict = torch.load("model.pt")
    model.load_state_dict(state_dict)
```

The `tensorizer_saving` and `tensorizer_loading` contexts are also thread-safe
and async-safe: their effects are local to the thread and coroutine that
activated them. They may also be activated at the same time as each other, or
even nested to temporarily change the arguments one is using.

> [!WARNING]
>
> Even though `tensorizer` itself only handles data and does not execute
> arbitrary code, `torch.load` still uses the `pickle` module internally.
> Loading untrusted `pickle` files **can** execute arbitrary code, so take
> appropriate precautions when using these wrappers.
>
> Additionally, for technical reasons, `torch.load(..., weights_only=True)`
> is incompatible with these wrappers. `weights_only` can be forced to `False`
> by using `tensorizer_loading(..., suppress_weights_only=True)`,
> but this disables some safety checks in `torch`, so it is opt-in only.

Finally, since the `tensorizer_saving` and `tensorizer_loading` contexts
temporarily swap out the `torch.save` and `torch.load` functions, note that they
will not affect already-saved references to those functions, e.g.:

```py
from tensorizer.torch_compat import tensorizer_saving
from torch import save as original_torch_save

with tensorizer_saving():
    # This won't work, but torch.save(..., "model.pt") would work
    original_torch_save(..., "model.pt")
```

This can sometimes be worked around by wrapping import blocks
in `tensorizer_saving` and/or `tensorizer_loading` as well.
The wrappers behave the same as the default `torch.save` and `torch.load`
functions unless their respective contexts are active, so this will usually
have no side effects.

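The early-binding issue itself is plain Python and can be demonstrated without `torch` (a toy sketch; `fake_torch` is a stand-in namespace, not the real module):

```python
import types

# Stand-in for the torch module; "save" is the function to be swapped.
fake_torch = types.SimpleNamespace(save=lambda obj, path: "original")

saved_reference = fake_torch.save  # like `from torch import save`

# What a context manager effectively does while active: rebind the
# module attribute to a wrapper function.
fake_torch.save = lambda obj, path: "patched"

assert fake_torch.save(None, "model.pt") == "patched"    # sees the swap
assert saved_reference(None, "model.pt") == "original"   # early-bound: does not
```

Attribute lookups like `torch.save(...)` go through the module each time, so they observe the swap; a reference saved before the context was entered still points at the original function object.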
For additional parameters, caveats, and advanced usage information,
refer to the docstrings for `tensorizer_saving` and `tensorizer_loading` in
the file `tensorizer/torch_compat.py`, or view their function documentation
inline in an IDE.

## Benchmarks

You can run your own benchmarks on CoreWeave or your own Kubernetes cluster
