By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.
## Context parallelism
[Context parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism) reduces memory by splitting input sequences across multiple GPUs. Each GPU processes its own slice of the sequence.
The key (K) and value (V) representations are communicated between devices with [Ring Attention](https://nanotron-ultrascale-playbook.static.hf.space/index.html?section=second_optimization%3A_bucketing_gradients#ring_attention) to ensure each split can see every other token's K/V. In Ring Attention, each GPU computes attention over its local K/V chunk and passes it to the next GPU in the ring. This way, no single GPU has to hold the full sequence in memory, and the communication of K/V chunks can be overlapped with computation.
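To build intuition, below is a minimal single-process sketch of the Ring Attention arithmetic (an illustration only, not the diffusers implementation). Each simulated rank keeps its own query chunk and accumulates attention over every K/V chunk as it travels around the ring, using running softmax statistics so the chunked result exactly matches full attention.

```py
import torch

torch.manual_seed(0)
world_size, seq_len, dim = 4, 16, 8
q, k, v = (torch.randn(seq_len, dim) for _ in range(3))

q_chunks = q.chunk(world_size)  # each rank keeps its own queries
kv_chunks = list(zip(k.chunk(world_size), v.chunk(world_size)))

outputs = []
for rank in range(world_size):
    q_local = q_chunks[rank]
    # Running softmax statistics (log-sum-exp trick) let chunks combine exactly.
    numer = torch.zeros_like(q_local)
    denom = torch.zeros(q_local.shape[0], 1)
    max_score = torch.full((q_local.shape[0], 1), float("-inf"))
    for step in range(world_size):
        # At each step, a rank sees the K/V chunk handed over by its ring neighbor.
        k_remote, v_remote = kv_chunks[(rank + step) % world_size]
        scores = q_local @ k_remote.T / dim**0.5
        new_max = torch.maximum(max_score, scores.max(dim=-1, keepdim=True).values)
        scale = (max_score - new_max).exp()
        numer = numer * scale + (scores - new_max).exp() @ v_remote
        denom = denom * scale + (scores - new_max).exp().sum(dim=-1, keepdim=True)
        max_score = new_max
    outputs.append(numer / denom)

# The sharded result matches full (non-causal) attention over the whole sequence.
full = torch.softmax(q @ k.T / dim**0.5, dim=-1) @ v
assert torch.allclose(torch.cat(outputs), full, atol=1e-5)
```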
Call [`parallelize`] on the model and pass a [`ContextParallelConfig`]. This config supports the `ring_degree` argument, which determines the number of devices to use for Ring Attention.
Use the [`~ModelMixin.set_attention_backend`] method to switch to a more optimized [attention backend](../optimization/attention_backends). The example below uses the FlashAttention backend.
Pass your pipeline to [`~ModelMixin.enable_parallelism`] as a context manager to activate and coordinate context parallelism.
> [!TIP]
> Context parallelism currently supports the cuDNN, FlashAttention-2, and SageAttention backends.
```py
import torch
from diffusers import QwenImagePipeline, ContextParallelConfig, enable_parallelism

# Launch with `torchrun --nproc_per_node=2 <script>.py` so each process drives one GPU.
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

pipeline = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
).to(device)

# Use an optimized attention backend and split the sequence across two devices with Ring Attention.
pipeline.transformer.set_attention_backend("flash")
pipeline.transformer.parallelize(config=ContextParallelConfig(ring_degree=2))

# enable_parallelism coordinates the sharded attention across ranks.
with enable_parallelism(pipeline):
    generator = torch.Generator(device=device).manual_seed(42)
    image = pipeline(
        "A cat wearing a space suit", num_inference_steps=50, generator=generator
    ).images[0]

if rank == 0:
    image.save("output.png")

torch.distributed.destroy_process_group()
```
Ulysses Attention splits a sequence across GPUs and performs an *all-to-all* (every device sends/receives data to every other device) so that each GPU ends up with all the tokens for only a subset of the attention heads. Each GPU computes attention locally over all tokens for its heads and then performs another all-to-all to regroup the results by tokens, making the output ready for the next layer.
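As a toy illustration of that exchange (not the diffusers implementation), the sketch below starts with each simulated device holding a slice of the tokens for all heads and regroups the tensors, as a real all-to-all would, so each device ends up with all tokens for a slice of the heads.

```py
import torch

world_size, seq_len, num_heads, head_dim = 4, 16, 8, 2
# shards[i]: the tokens owned by device i, with all heads present.
shards = list(torch.randn(seq_len, num_heads, head_dim).chunk(world_size, dim=0))

# All-to-all: device i sends its tokens for head group j to device j.
regrouped = [
    torch.cat([shard.chunk(world_size, dim=1)[j] for shard in shards], dim=0)
    for j in range(world_size)
]

# Each device now holds every token but only num_heads / world_size heads,
# so it can compute attention locally over the full sequence for those heads.
assert regrouped[0].shape == (seq_len, num_heads // world_size, head_dim)
```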
[`ContextParallelConfig`] also supports Ulysses Attention through the `ulysses_degree` argument, which determines the number of devices to use for Ulysses Attention.
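For example, sharding the attention heads across two devices instead of using a ring only requires swapping the config passed to `parallelize` in the earlier example (assuming the same pipeline setup):

```py
pipeline.transformer.parallelize(config=ContextParallelConfig(ulysses_degree=2))
```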