docs/source/en/optimization/memory.md (73 additions, 20 deletions)
Modern diffusion models like [Flux](../api/pipelines/flux) and [Wan](../api/pipelines/wan) have billions of parameters that take up a lot of memory on your hardware for inference.
This guide will show you how to reduce your memory usage.
> [!TIP]
> Keep in mind these techniques may need to be adjusted depending on the model. For example, a transformer-based diffusion model may not benefit from these memory optimizations to the same degree as a UNet-based model.

> [!WARNING]
> Device placement is an experimental feature and the API may change. Only the `balanced` strategy is supported at the moment. We plan to support additional mapping strategies in the future.
The `device_map` parameter controls how the model components in a pipeline or the layers in an individual model are distributed across devices.

<hfoptions id="device-map">
<hfoption id="pipeline level">

The `balanced` device placement strategy evenly splits the pipeline across all available devices.
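
For example, a minimal sketch of the pipeline-level strategy; the checkpoint is only an illustration and any multi-GPU setup works.

```py
import torch
from diffusers import DiffusionPipeline

# "balanced" splits the pipeline components evenly across the available GPUs
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    device_map="balanced",
)
```
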
</hfoption>
<hfoption id="model level">

The `device_map` parameter is useful for loading large models, such as the Flux diffusion transformer which has 12.5B parameters. Set it to `"auto"` to automatically distribute a model across the fastest device first before moving to slower devices. Refer to the [Model sharding](../training/distributed_inference#model-sharding) docs for more details.
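
A sketch of model-level placement, assuming the Flux transformer is loaded directly with `FluxTransformer2DModel`:

```py
import torch
from diffusers import FluxTransformer2DModel

# "auto" fills the fastest device first, then spills over to slower devices (CPU, disk)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```
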
You can inspect a model's device map with `hf_device_map`.
```py
print(transformer.hf_device_map)
```
</hfoption>
</hfoptions>
When designing your own `device_map`, it should be a dictionary of a model's specific module name or layer and a device identifier (an integer for GPUs, `cpu` for CPUs, and `disk` for disk).
Call `hf_device_map` on a model to see how model layers are distributed and then design your own.
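
For illustration only, a custom map might look like the sketch below; the module names are hypothetical placeholders and should be copied from your own model's `hf_device_map` output.

```py
# Hypothetical map (placeholder module names): an integer selects a GPU,
# "cpu" and "disk" are the other valid device identifiers
custom_device_map = {
    "transformer_blocks": 0,
    "single_transformer_blocks": "cpu",
    "norm_out": "cpu",
    "proj_out": "cpu",
}
```

The dictionary is then passed as `device_map` to `from_pretrained`.
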
Pass a dictionary mapping maximum memory usage to each device to enforce a limit. If a device is not in `max_memory`, it is ignored and pipeline components won't be distributed to it.
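
For example, a sketch that caps usage on two GPUs; the checkpoint and memory sizes are illustrative.

```py
import torch
from diffusers import DiffusionPipeline

# Devices not listed in max_memory (a third GPU, for example) receive no components at all
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    device_map="balanced",
    max_memory={0: "8GB", 1: "8GB"},
)
```
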
> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support tiling.
## Offloading
Offloading strategies move layers or models that aren't currently active to the CPU to avoid increasing GPU memory. These strategies can be combined with quantization and torch.compile to balance inference speed and memory usage.
Refer to the [Compile and offloading quantized models](./speed-memory-optims) guide for more details.
### CPU offloading
CPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU and when it isn't required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.
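
A minimal sketch with [`~DiffusionPipeline.enable_sequential_cpu_offload`]; the checkpoint and prompt are only illustrations.

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)
# Submodules are moved to the GPU right before they run and back to the CPU afterwards;
# don't call pipeline.to("cuda") first, offloading manages device placement itself
pipeline.enable_sequential_cpu_offload()

image = pipeline("a photo of an astronaut riding a horse on the moon").images[0]
```
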
### Model offloading

Model offloading moves entire models to the GPU instead of selectively moving *some* layers or model components. One of the main pipeline components, usually the text encoder, UNet, or VAE, is placed on the GPU while the other components are held on the CPU. Components like the UNet that run multiple times stay on the GPU until they're completely finished and no longer needed. This eliminates the communication overhead of [CPU offloading](#cpu-offloading) and makes model offloading a faster alternative. The tradeoff is that the memory savings won't be as large.
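
A sketch with [`~DiffusionPipeline.enable_model_cpu_offload`]; again, the checkpoint and prompt are illustrative.

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)
# Whole models (text encoders, UNet, VAE) are swapped between CPU and GPU as needed
pipeline.enable_model_cpu_offload()

image = pipeline("a photo of an astronaut riding a horse on the moon").images[0]
```
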
[`~DiffusionPipeline.enable_model_cpu_offload`] also helps when you're using the [`~StableDiffusionXLPipeline.encode_prompt`] method on its own to generate the text encoder's hidden states.
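
For example, a sketch assuming SDXL's `encode_prompt`, which returns the prompt, negative prompt, and pooled embeddings; the checkpoint and prompt are illustrative.

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

# Only the text encoders need to be moved to the GPU for this call
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipeline.encode_prompt(
    "a photo of an astronaut riding a horse on the moon"
)
```
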
### Group offloading

Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html) or [torch.nn.Sequential](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)) to the CPU. It uses less memory than [model offloading](#model-offloading) and it is faster than [CPU offloading](#cpu-offloading) because it reduces communication overhead.
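
A sketch of enabling it on a single model with [`~ModelMixin.enable_group_offload`]; the checkpoint and the `leaf_level` choice are illustrative.

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
# Offload the transformer's layer groups to the CPU and onload them onto the GPU as needed
pipeline.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)
```
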
The `use_stream` parameter can be enabled for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation by prefetching layers. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly, so ensure you have at least 2x the model size in CPU memory.
The `low_cpu_mem_usage` parameter can be set to `True` to reduce CPU memory usage when using streams during group offloading. It is best for `leaf_level` offloading and when CPU memory is bottlenecked. Memory is saved by creating pinned tensors on the fly instead of pre-pinning them. However, this may increase overall execution time.
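
Continuing the sketch above, both options are passed to the same call.

```py
# Prefetch the next layer group on a CUDA stream and pin tensors on the fly
pipeline.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    low_cpu_mem_usage=True,
)
```
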
#### Offloading to disk
Group offloading can consume significant system memory depending on the model size. On systems with limited memory, try offloading the groups to disk as secondary memory instead.
Set the `offload_to_disk_path` argument in either [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`] to offload the model to the disk.
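
Continuing the group offloading sketch from above, only the (illustrative) path is new.

```py
# Layer groups are written to this directory instead of being kept in CPU RAM
pipeline.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    offload_to_disk_path="./flux-group-offload",  # illustrative path
)
```
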
Refer to these [two](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) [tables](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) to compare the speed and memory trade-offs.
## Layerwise casting
> [!TIP]
> Combine layerwise casting with [group offloading](#group-offloading) for even more memory savings.
Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.
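
A sketch of [`~ModelMixin.enable_layerwise_casting`] on the Flux transformer, using the storage and compute dtypes named above; the checkpoint is illustrative.

```py
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
# Weights are stored in fp8 and upcast to bf16 only while a layer computes
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)
```
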
> [!WARNING]
## Memory-efficient attention
> [!TIP]
> Memory-efficient attention optimizes for memory usage *and* [inference speed](./fp16#scaled-dot-product-attention)!
The Transformer attention mechanism is memory-intensive, especially for long sequences, so you can try using different, more memory-efficient attention types.
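
One such option is xFormers; a sketch, assuming the `xformers` package is installed and the checkpoint is only an illustration.

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")
# Replace the default attention processor with xFormers' memory-efficient attention
pipeline.enable_xformers_memory_efficient_attention()
```
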