An example command line tool to add or remove encryption from existing
serialized models is also available as
[examples/encrypt_existing.py](examples/encrypt_existing.py).

## PyTorch Compatibility

`tensorizer`'s `TensorSerializer` and `TensorDeserializer` classes are designed
to replace the use of `torch.save` and `torch.load` in model saving and loading
pipelines; however, they are not drop-in replacements. The `tensorizer`
serialization and deserialization APIs offer more parameters to control
performance and resource usage, plus additional features like encryption,
so they are invoked differently.
For drop-in replacements, see the next section.

The examples below show typical usages of `torch.save` and `torch.load`,
and how they may be replaced with `tensorizer` serialization.

```py
from tensorizer import TensorDeserializer, TensorSerializer
import torch

model: torch.nn.Module = ...

# Saving with torch.save
state_dict = model.state_dict()
torch.save(state_dict, "model.pt")

# Loading with torch.load
state_dict = torch.load("model.pt", map_location="cuda:0")
model.load_state_dict(state_dict)

# Saving with tensorizer.TensorSerializer
state_dict = model.state_dict()
serializer = TensorSerializer("model.tensors")
serializer.write_state_dict(state_dict)
serializer.close()

# Loading with tensorizer.TensorDeserializer
with TensorDeserializer("model.tensors", device="cuda:0") as state_dict:
    model.load_state_dict(state_dict)
```

> [!NOTE]
>
> `TensorDeserializer` is a context manager because it supports lazy loading,
> where the context controls how long its source file remains open to read
> more tensors. This behaviour is optional and can be enabled with
> `TensorDeserializer(..., lazy_load=True)`.

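As a minimal sketch of that lazy-loading mode (the file name and device here are placeholders, and the model assignment is elided as in the examples above):

```py
from tensorizer import TensorDeserializer
import torch

model: torch.nn.Module = ...

# With lazy_load=True, tensors are only read from the file as they are
# accessed, so the file must stay open for the duration of the context
with TensorDeserializer("model.tensors", device="cuda:0",
                        lazy_load=True) as state_dict:
    model.load_state_dict(state_dict)
# Once the context exits, the source file is closed
```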
### Drop-In PyTorch Compatibility Layer, `tensorizer.torch_compat`

Since `tensorizer` only serializes tensors and not other Python types, it is
more similar to `safetensors` than to `torch`'s own saving, as `torch` bases
its serialization on the `pickle` module, which allows serialization of
arbitrary Python objects.

The `tensorizer.torch_compat` module exists to address this and another common
integration challenge:
- Use case 1: You need to serialize Python objects other than tensors,
  like `torch.save` does.
- Use case 2: You need to adapt existing code that uses `torch.save` internally
  where it is not easy to swap out, like in an external framework or library.

**`tensorizer.torch_compat` enables calls to `torch.save` and `torch.load`
to use `tensorizer` as a backend for the serialization and deserialization
of tensor data, separate from other data being serialized.**

The interface to `tensorizer.torch_compat` is its two context managers,
`tensorizer_saving` and `tensorizer_loading`. These take similar arguments
to the `TensorSerializer` and `TensorDeserializer` classes, respectively,
and temporarily swap out the `torch.save` and `torch.load` functions for
versions with special behaviour while their context is active.
Saving this way produces two files: one for tensors, and one for all other
data.

```py
import torch
from tensorizer.torch_compat import tensorizer_loading, tensorizer_saving

model: torch.nn.Module = ...

state_dict = model.state_dict()

# Saving with torch.save, internally using tensorizer.TensorSerializer
with tensorizer_saving("model.pt.tensors"):
    torch.save(state_dict, "model.pt")

# Loading with torch.load, internally using tensorizer.TensorDeserializer
with tensorizer_loading("model.pt.tensors", device="cuda:0"):
    state_dict = torch.load("model.pt")
    model.load_state_dict(state_dict)
```

For existing code that uses `torch.save` or `torch.load` internally, the
recommended usage pattern is to wrap the relevant section of code in one of
these context managers so that it can use `tensorizer` automatically.

For instance, with a `transformers.Trainer` object, part of adapting it to
use `tensorizer` may be:

```py
from tensorizer.torch_compat import tensorizer_saving

with tensorizer_saving():
    # In case this module saves references to torch.save at import time
    import transformers

trainer: transformers.Trainer = ...

with tensorizer_saving():
    # This method may call torch.save internally at some point,
    # so activating this context around it will intercept it when it does
    trainer.train()
```

#### `torch_compat` Usage Considerations

If the filename to use is difficult to determine in advance, the first
`file_obj` argument to `tensorizer_loading` and `tensorizer_saving` is allowed
to be a callback that receives the path passed to `torch.save` and returns
a place to output the sidecar `.tensors` file.

The `.tensors` path can be anything normally supported by `tensorizer`,
including pre-opened file-like objects and `s3://` URIs.
The default `file_obj` callback simply appends `.tensors` to the path.

```py
import torch
from tensorizer.torch_compat import tensorizer_loading, tensorizer_saving


def tensors_path(f: torch.types.FileLike) -> str | None:
    if isinstance(f, str):
        return f.replace(".pt", "-tensor-data.tensors", 1)
    else:
        # Returning None will save normally, without using tensorizer.
        # This is useful for file-like objects like io.BytesIO,
        # where sidecar files don't make sense.
        return None


model: torch.nn.Module = ...
state_dict = model.state_dict()

with tensorizer_saving(tensors_path):
    # Will save to model.pt and model-tensor-data.tensors
    torch.save(state_dict, "model.pt")

with tensorizer_loading(tensors_path, device="cuda:0"):
    # Will load from model.pt and model-tensor-data.tensors
    state_dict = torch.load("model.pt")
    model.load_state_dict(state_dict)
```

The `tensorizer_saving` and `tensorizer_loading` contexts are also thread-safe
and async-safe, in that their effects are local to one thread and coroutine.
They may also be activated at the same time as each other, or even nested
to temporarily change the arguments one is using.

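For instance, nesting one saving context inside another can temporarily redirect where the sidecar tensor file goes. A sketch, with placeholder file names and an elided `state_dict`:

```py
import torch
from tensorizer.torch_compat import tensorizer_saving

state_dict = ...

with tensorizer_saving("outer.tensors"):
    # Saved with the outer context's arguments: sidecar file outer.tensors
    torch.save(state_dict, "outer.pt")
    with tensorizer_saving("inner.tensors"):
        # The innermost active context's arguments apply here instead
        torch.save(state_dict, "inner.pt")
```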
> [!WARNING]
>
> Even though `tensorizer` itself only handles data and does not execute
> arbitrary code, `torch.load` still uses the `pickle` module internally.
> Loading untrusted `pickle` files **can** execute arbitrary code, so take
> appropriate precautions when using these wrappers.
>
> Additionally, for technical reasons, `torch.load(..., weights_only=True)`
> is incompatible with these wrappers. `weights_only` can be forced to `False`
> by using `tensorizer_loading(..., suppress_weights_only=True)`,
> but this disables some safety checks in `torch`, so it is opt-in only.

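A sketch of that opt-in, assuming the `.pt` and `.tensors` files come from a trusted source (file names here are placeholders):

```py
import torch
from tensorizer.torch_compat import tensorizer_loading

# Forcing weights_only to False re-enables pickle-based loading of
# arbitrary objects, so only do this for files you trust
with tensorizer_loading("model.pt.tensors", device="cuda:0",
                        suppress_weights_only=True):
    state_dict = torch.load("model.pt")
```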
Finally, since the `tensorizer_saving` and `tensorizer_loading` contexts
temporarily swap out the `torch.save` and `torch.load` functions, note that
they will not affect already-saved references to those functions, e.g.:

```py
from tensorizer.torch_compat import tensorizer_saving
from torch import save as original_torch_save

with tensorizer_saving():
    # This won't work, but torch.save(..., "model.pt") would work
    original_torch_save(..., "model.pt")
```

This can sometimes be worked around by wrapping import blocks
in `tensorizer_saving` and/or `tensorizer_loading` as well.
The wrappers behave the same as the default `torch.save` and `torch.load`
functions unless their respective contexts are active, so this will usually
have no side effects.

For additional parameters, caveats, and advanced usage information,
refer to the docstrings for `tensorizer_saving` and `tensorizer_loading` in
the file [tensorizer/torch_compat.py](/tensorizer/torch_compat.py),
or view their function documentation inline in an IDE.

## Benchmarks

You can run your own benchmarks on CoreWeave or your own Kubernetes cluster