Commit ac7f9fc

init
1 parent 55f0b3d commit ac7f9fc

File tree

1 file changed: +64 -0 lines changed

docs/source/en/training/distributed_inference.md

Lines changed: 64 additions & 0 deletions
@@ -237,3 +237,67 @@ with torch.no_grad():
```

By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.

## Context parallelism

[Context parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism) reduces memory by splitting input sequences across multiple GPUs. Each GPU processes its own slice of the sequence.

The key (K) and value (V) representations are communicated between devices with [Ring Attention](https://nanotron-ultrascale-playbook.static.hf.space/index.html?section=second_optimization%3A_bucketing_gradients#ring_attention) to ensure each split can see every other token's K/V. In Ring Attention, each GPU computes attention with its local K/V block and then passes that block to the next GPU in the ring. This way, no single GPU ever has to hold the full sequence, and the block exchange can be overlapped with computation to reduce communication latency.
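
To make the data flow concrete, here is a minimal single-process sketch of the ring pattern (an illustration only, not the diffusers implementation): K/V blocks rotate around the ring while each rank accumulates attention for its local queries with an online softmax, so the full score matrix is never materialized on one device.

```py
# Single-process simulation of the Ring Attention data flow.
import torch

torch.manual_seed(0)
world_size, seq_len, dim = 2, 8, 16
q = torch.randn(seq_len, dim)
k = torch.randn(seq_len, dim)
v = torch.randn(seq_len, dim)

q_chunks = list(q.chunk(world_size))
kv_chunks = list(zip(k.chunk(world_size), v.chunk(world_size)))

outputs = []
for rank in range(world_size):
    q_local = q_chunks[rank]
    # Running statistics for the online softmax.
    m = torch.full((q_local.size(0), 1), float("-inf"))
    num = torch.zeros(q_local.size(0), dim)
    den = torch.zeros(q_local.size(0), 1)
    for step in range(world_size):
        # The K/V block that "arrives" on the ring at this step.
        k_blk, v_blk = kv_chunks[(rank + step) % world_size]
        scores = q_local @ k_blk.T / dim**0.5
        m_new = torch.maximum(m, scores.max(dim=-1, keepdim=True).values)
        scale = torch.exp(m - m_new)
        p = torch.exp(scores - m_new)
        num = num * scale + p @ v_blk
        den = den * scale + p.sum(dim=-1, keepdim=True)
        m = m_new
    outputs.append(num / den)

# Matches full (non-causal) attention over the whole sequence.
reference = torch.softmax(q @ k.T / dim**0.5, dim=-1) @ v
assert torch.allclose(torch.cat(outputs), reference, atol=1e-5)
```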

Call [`parallelize`] on the model and pass a [`ContextParallelConfig`]. The config's `ring_degree` argument determines the number of devices to use for Ring Attention.

Use the [`~ModelMixin.set_attention_backend`] method to switch to a more optimized [attention backend](../optimization/attention_backends). The example below uses the FlashAttention backend.

Pass your pipeline to [`~ModelMixin.enable_parallelism`], used as a context manager, to activate and coordinate context parallelism.

> [!TIP]
> Context parallelism currently supports the cuDNN, FlashAttention-2, and SageAttention backends.

```py
import torch
from diffusers import QwenImagePipeline, ContextParallelConfig, enable_parallelism

try:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
    pipeline.to("cuda")

    pipeline.transformer.parallelize(config=ContextParallelConfig(ring_degree=2))
    pipeline.transformer.set_attention_backend("flash")

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    # Must specify generator so all ranks start with same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    with enable_parallelism(pipeline):
        image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]

    if rank == 0:
        image.save("output.png")

except Exception as e:
    print(f"An error occurred: {e}")
    torch.distributed.breakpoint()
    raise

finally:
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
```
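
Launch the script with a distributed launcher so there is one process per GPU, for example `torchrun --nproc_per_node=2 script.py` when using `ring_degree=2`.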

### Ulysses Attention

Ulysses Attention splits a sequence across GPUs and performs an *all-to-all* communication (every device sends and receives data from every other device) so that each GPU ends up with all the tokens for only a subset of the attention heads. Each GPU computes attention locally over all tokens for its heads, then performs another all-to-all to regroup the results by tokens, ready for the next layer.
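
The exchange can be pictured with a small single-process sketch (an illustration of the layout changes only, not the diffusers implementation): tensors move from a token-sharded layout to a head-sharded layout and back.

```py
# Single-process simulation of the two Ulysses all-to-all exchanges.
import torch

world_size, seq_len, heads, head_dim = 2, 8, 4, 16
x = torch.randn(seq_len, heads, head_dim)

# Start token-sharded: each rank owns seq_len // world_size tokens, all heads.
token_shards = list(x.chunk(world_size, dim=0))

# First all-to-all: rank r keeps head group r from every token shard,
# ending up with every token but only heads // world_size heads.
head_shards = [
    torch.cat([shard.chunk(world_size, dim=1)[r] for shard in token_shards], dim=0)
    for r in range(world_size)
]
assert head_shards[0].shape == (seq_len, heads // world_size, head_dim)

# Each rank would now run ordinary attention over the full sequence for its
# heads (omitted). The second all-to-all reverses the exchange so the output
# is token-sharded again for the next layer.
restored = [
    torch.cat([shard.chunk(world_size, dim=0)[t] for shard in head_shards], dim=1)
    for t in range(world_size)
]
assert torch.equal(torch.cat(restored, dim=0), x)
```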

[`ContextParallelConfig`] also supports Ulysses Attention through the `ulysses_degree` argument. This determines the number of devices to use for Ulysses Attention.

```py
pipeline.transformer.parallelize(config=ContextParallelConfig(ulysses_degree=2))
```
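
Because Ulysses Attention shards attention heads across devices, the model's number of attention heads should generally be divisible by `ulysses_degree`.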
