docs/source/en/optimization/memory.md (46 additions, 3 deletions)
@@ -291,13 +291,53 @@ Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://
> [!WARNING]
> Group offloading may not work with all models if the forward implementation contains weight-dependent device casting of inputs because it may clash with group offloading's device casting mechanism.

- Call [`~ModelMixin.enable_group_offload`] to enable it for standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.
-
- The `offload_type` parameter can be set to `block_level` or `leaf_level`.
+ Enable group offloading by configuring the `offload_type` parameter to `block_level` or `leaf_level`.

- `block_level` offloads groups of layers based on the `num_blocks_per_group` parameter. For example, if `num_blocks_per_group=2` on a model with 40 layers, 2 layers are onloaded and offloaded at a time (20 total onloads/offloads). This drastically reduces memory requirements.
- `leaf_level` offloads individual layers at the lowest level and is equivalent to [CPU offloading](#cpu-offloading). But it can be made faster if you use streams without giving up inference speed.

+ Group offloading is supported for entire pipelines or individual models. Applying group offloading to the entire pipeline is the easiest option, while selectively applying it to individual models gives users more flexibility to use different offloading techniques for different models.
+
+ <hfoptions id="group-offloading">
+ <hfoption id="pipeline">
+
+ Call [`~DiffusionPipeline.enable_group_offload`] on a pipeline.
+
+ ```py
+ import torch
+ from diffusers import CogVideoXPipeline
+ from diffusers.hooks import apply_group_offloading

Call [`~ModelMixin.enable_group_offload`] on standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.

The `use_stream` parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly, so ensure you have 2x the amount of memory as the model size.
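The pipeline example in the hunk above is cut off after its imports in this view. As a rough sketch of how the new pipeline-level call might fit together, assuming the `onload_device`/`offload_device` arguments from the existing `apply_group_offloading` helper and a CogVideoX checkpoint chosen purely for illustration:

```py
import torch
from diffusers import CogVideoXPipeline

# Illustrative checkpoint, not taken from this PR.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Keep parameter groups on the CPU and onload them to the GPU only when they are needed.
# use_stream=True overlaps data transfer with computation on CUDA devices that support it,
# at the cost of extra CPU memory for prefetched layers.
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)

video = pipe(prompt="A panda strumming a tiny guitar in a bamboo grove").frames[0]
```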
docs/source/en/quantization/overview.md (4 additions, 1 deletion)
@@ -34,7 +34,9 @@ Initialize [`~quantizers.PipelineQuantizationConfig`] with the following paramet
> [!TIP]
> These `quant_kwargs` arguments are different for each backend. Refer to the [Quantization API](../api/quantization) docs to view the arguments for each backend.

- `components_to_quantize` specifies which components of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.
+ `components_to_quantize` specifies which component(s) of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one, such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.
+
+ `components_to_quantize` accepts either a list for multiple models or a string for a single model.
The example below loads the bitsandbytes backend with the following arguments from [`~quantizers.quantization_config.BitsAndBytesConfig`], `load_in_4bit`, `bnb_4bit_quant_type`, and `bnb_4bit_compute_dtype`.
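The referenced example itself isn't visible in this diff. A minimal sketch of such a configuration, assuming the `bitsandbytes_4bit` backend string and a Flux checkpoint used only for illustration:

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Quantize the transformer and the T5 text encoder ("text_encoder_2") with bitsandbytes 4-bit,
# leaving the CLIP text encoder untouched.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder_2"],
)

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint for illustration
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
)
```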
The `quant_mapping` argument provides more options for how to quantize each individual component in a pipeline, like combining different quantization backends.
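A sketch of the `quant_mapping` form. The component names, the `QuantoConfig` arguments, and the idea of mixing a Diffusers backend with a Transformers backend are assumptions based on the quantization docs, not part of this diff:

```py
import torch
from diffusers import QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

# Map each component to its own quantization config, potentially from different backends.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": QuantoConfig(weights_dtype="int8"),  # assumed Quanto argument name
        "text_encoder_2": TransformersBitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
        ),
    }
)
```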
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
+       depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."

        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
+       depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."

        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
+       depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+       deprecate(
+           "enable_vae_tiling",
+           "0.40.0",
+           depr_message,
+       )
        self.vae.enable_tiling()

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling

        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
+       depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."

        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
+       depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+       deprecate(
+           "enable_vae_slicing",
+           "0.40.0",
+           depr_message,
+       )
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
+       depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."

        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
+       depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+       deprecate(
+           "enable_vae_tiling",
+           "0.40.0",
+           depr_message,
+       )
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
+       depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."

        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
+       depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+       deprecate(
+           "enable_vae_tiling",
+           "0.40.0",
+           depr_message,
+       )
        self.vae.enable_tiling()

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling

        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
+       depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+       deprecate(
+           "disable_vae_tiling",
+           "0.40.0",
+           depr_message,
+       )
        self.vae.disable_tiling()

    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents

        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
+       depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+       deprecate(
+           "enable_vae_slicing",
+           "0.40.0",
+           depr_message,
+       )
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
+       depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."

        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
+       depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+       deprecate(
+           "enable_vae_tiling",
+           "0.40.0",
+           depr_message,
+       )
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
+       depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."