Skip to content

Commit 6708f5c

Browse files
sayakpaulstevhliu
andauthored
[docs] improve distributed inference cp docs. (#12810)
* improve distributed inference cp docs. * Apply suggestions from code review Co-authored-by: Steven Liu <[email protected]> --------- Co-authored-by: Steven Liu <[email protected]>
1 parent be3c2a0 commit 6708f5c

File tree

1 file changed

+67
-24
lines changed

1 file changed

+67
-24
lines changed

docs/source/en/training/distributed_inference.md

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,8 @@ By selectively loading and unloading the models you need at a given stage and sh
237237

238238
Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.
239239

240+
Most attention backends are compatible with context parallelism. Open an [issue](https://github.com/huggingface/diffusers/issues/new) if a backend is not compatible.
241+
240242
### Ring Attention
241243

242244
Key (K) and value (V) representations communicate between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each split sees every other token's K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. No single GPU holds the full sequence, which reduces communication latency.
@@ -245,40 +247,60 @@ Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transf
245247

246248
```py
247249
import torch
248-
from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig
249-
250-
try:
251-
torch.distributed.init_process_group("nccl")
252-
rank = torch.distributed.get_rank()
253-
device = torch.device("cuda", rank % torch.cuda.device_count())
250+
from torch import distributed as dist
251+
from diffusers import DiffusionPipeline, ContextParallelConfig
252+
253+
def setup_distributed():
254+
if not dist.is_initialized():
255+
dist.init_process_group(backend="nccl")
256+
rank = dist.get_rank()
257+
device = torch.device(f"cuda:{rank}")
254258
torch.cuda.set_device(device)
255-
256-
transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2))
257-
pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda")
258-
pipeline.transformer.set_attention_backend("flash")
259+
return device
260+
261+
def main():
262+
device = setup_distributed()
263+
world_size = dist.get_world_size()
264+
265+
pipeline = DiffusionPipeline.from_pretrained(
266+
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
267+
)
268+
pipeline.transformer.set_attention_backend("_native_cudnn")
269+
270+
cp_config = ContextParallelConfig(ring_degree=world_size)
271+
pipeline.transformer.enable_parallelism(config=cp_config)
259272

260273
prompt = """
261274
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
262275
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
263276
"""
264-
277+
265278
# Must specify generator so all ranks start with same latents (or pass your own)
266279
generator = torch.Generator().manual_seed(42)
267-
image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]
268-
269-
if rank == 0:
270-
image.save("output.png")
271-
272-
except Exception as e:
273-
print(f"An error occurred: {e}")
274-
torch.distributed.breakpoint()
275-
raise
276-
277-
finally:
278-
if torch.distributed.is_initialized():
279-
torch.distributed.destroy_process_group()
280+
image = pipeline(
281+
prompt,
282+
guidance_scale=3.5,
283+
num_inference_steps=50,
284+
generator=generator,
285+
).images[0]
286+
287+
if dist.get_rank() == 0:
288+
image.save(f"output.png")
289+
290+
if dist.is_initialized():
291+
dist.destroy_process_group()
292+
293+
294+
if __name__ == "__main__":
295+
main()
280296
```
281297

298+
The script above needs to be run with a distributed launcher, such as [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html), that is compatible with PyTorch. `--nproc-per-node` is set to the number of GPUs available.
299+
300+
/```shell
301+
`torchrun --nproc-per-node 2 above_script.py`.
302+
/```
303+
282304
### Ulysses Attention
283305

284306
[Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends/receives data to every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally on all tokens for its head, then performs another all-to-all to regroup results by tokens for the next layer.
@@ -288,5 +310,26 @@ finally:
288310
Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
289311

290312
```py
313+
# Depending on the number of GPUs available.
291314
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
315+
```
316+
317+
### parallel_config
318+
319+
Pass `parallel_config` during model initialization to enable context parallelism.
320+
321+
```py
322+
CKPT_ID = "black-forest-labs/FLUX.1-dev"
323+
324+
cp_config = ContextParallelConfig(ring_degree=2)
325+
transformer = AutoModel.from_pretrained(
326+
CKPT_ID,
327+
subfolder="transformer",
328+
torch_dtype=torch.bfloat16,
329+
parallel_config=cp_config
330+
)
331+
332+
pipeline = DiffusionPipeline.from_pretrained(
333+
CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
334+
).to(device)
292335
```

0 commit comments

Comments
 (0)