Commit 793ddc2

Merge branch 'main' into v0.36.0-post
2 parents: 97da150 + 6708f5c

31 files changed (+4030, -118 lines)

docs/source/en/training/distributed_inference.md

Lines changed: 67 additions & 24 deletions

````diff
@@ -237,6 +237,8 @@ By selectively loading and unloading the models you need at a given stage and sh
 
 Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.
 
+Most attention backends are compatible with context parallelism. Open an [issue](https://github.com/huggingface/diffusers/issues/new) if a backend is not compatible.
+
 ### Ring Attention
 
 Key (K) and value (V) representations communicate between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each split sees every other token's K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. No single GPU holds the full sequence, which reduces communication latency.
@@ -245,40 +247,60 @@ Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transf
 
 ```py
 import torch
-from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig
-
-try:
-    torch.distributed.init_process_group("nccl")
-    rank = torch.distributed.get_rank()
-    device = torch.device("cuda", rank % torch.cuda.device_count())
+from torch import distributed as dist
+from diffusers import DiffusionPipeline, ContextParallelConfig
+
+def setup_distributed():
+    if not dist.is_initialized():
+        dist.init_process_group(backend="nccl")
+    rank = dist.get_rank()
+    device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-
-    transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2))
-    pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda")
-    pipeline.transformer.set_attention_backend("flash")
+    return device
+
+def main():
+    device = setup_distributed()
+    world_size = dist.get_world_size()
+
+    pipeline = DiffusionPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
+    )
+    pipeline.transformer.set_attention_backend("_native_cudnn")
+
+    cp_config = ContextParallelConfig(ring_degree=world_size)
+    pipeline.transformer.enable_parallelism(config=cp_config)
 
     prompt = """
     cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
     highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
     """
-
+
     # Must specify generator so all ranks start with same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
-    image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]
-
-    if rank == 0:
-        image.save("output.png")
-
-except Exception as e:
-    print(f"An error occurred: {e}")
-    torch.distributed.breakpoint()
-    raise
-
-finally:
-    if torch.distributed.is_initialized():
-        torch.distributed.destroy_process_group()
+    image = pipeline(
+        prompt,
+        guidance_scale=3.5,
+        num_inference_steps=50,
+        generator=generator,
+    ).images[0]
+
+    if dist.get_rank() == 0:
+        image.save("output.png")
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
 ```
 
+The script above needs to be run with a distributed launcher, such as [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html), that is compatible with PyTorch. `--nproc-per-node` is set to the number of GPUs available.
+
+```shell
+torchrun --nproc-per-node 2 above_script.py
+```
+
 ### Ulysses Attention
 
 [Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends/receives data to every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally on all tokens for its head, then performs another all-to-all to regroup results by tokens for the next layer.
@@ -288,5 +310,26 @@ finally:
 Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
 
 ```py
+# Depending on the number of GPUs available.
 pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
+```
+
+### parallel_config
+
+Pass `parallel_config` during model initialization to enable context parallelism.
+
+```py
+CKPT_ID = "black-forest-labs/FLUX.1-dev"
+
+cp_config = ContextParallelConfig(ring_degree=2)
+transformer = AutoModel.from_pretrained(
+    CKPT_ID,
+    subfolder="transformer",
+    torch_dtype=torch.bfloat16,
+    parallel_config=cp_config
+)
+
+pipeline = DiffusionPipeline.from_pretrained(
+    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
+).to(device)
 ```
````
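The diff above only shows the Ulysses variant as a one-liner. As a point of reference, here is a minimal end-to-end sketch of the Ulysses path (not part of the commit): it reuses the `ContextParallelConfig`, `enable_parallelism`, and `set_attention_backend` APIs exactly as documented above, and assumes two GPUs launched with `torchrun --nproc-per-node 2`. The checkpoint and attention backend are illustrative choices.

```py
# Sketch only (not from the commit): Ulysses attention with the APIs shown in the diff above.
# Assumed launch: torchrun --nproc-per-node 2 ulysses_example.py
import torch
from torch import distributed as dist

from diffusers import ContextParallelConfig, DiffusionPipeline


def main():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
    )
    pipeline.transformer.set_attention_backend("_native_cudnn")

    # ulysses_degree splits attention heads across ranks via all-to-all, as described above.
    pipeline.transformer.enable_parallelism(
        config=ContextParallelConfig(ulysses_degree=dist.get_world_size())
    )

    # Same seed on every rank so all ranks start from the same latents.
    generator = torch.Generator().manual_seed(42)
    image = pipeline(
        "cinematic film still of a cat sipping a margarita in a pool in Palm Springs",
        num_inference_steps=50,
        generator=generator,
    ).images[0]

    if dist.get_rank() == 0:
        image.save("ulysses_output.png")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
```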

src/diffusers/__init__.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -404,6 +404,8 @@
 else:
     _import_structure["modular_pipelines"].extend(
         [
+            "Flux2AutoBlocks",
+            "Flux2ModularPipeline",
             "FluxAutoBlocks",
             "FluxKontextAutoBlocks",
             "FluxKontextModularPipeline",
@@ -419,6 +421,8 @@
             "Wan22AutoBlocks",
             "WanAutoBlocks",
             "WanModularPipeline",
+            "ZImageAutoBlocks",
+            "ZImageModularPipeline",
         ]
     )
     _import_structure["pipelines"].extend(
@@ -1109,6 +1113,8 @@
         from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
         from .modular_pipelines import (
+            Flux2AutoBlocks,
+            Flux2ModularPipeline,
             FluxAutoBlocks,
             FluxKontextAutoBlocks,
             FluxKontextModularPipeline,
@@ -1124,6 +1130,8 @@
             Wan22AutoBlocks,
             WanAutoBlocks,
             WanModularPipeline,
+            ZImageAutoBlocks,
+            ZImageModularPipeline,
         )
         from .pipelines import (
             AllegroPipeline,
```
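With the registrations above in place, the new modular-pipeline classes become importable from the package root through diffusers' lazy-import machinery. A quick illustrative check (assuming an installed build that includes this commit):

```py
# Illustrative only: these names are registered in _import_structure and in the runtime
# import block above, so they resolve from the top-level package.
from diffusers import (
    Flux2AutoBlocks,
    Flux2ModularPipeline,
    ZImageAutoBlocks,
    ZImageModularPipeline,
)

print(Flux2ModularPipeline.__name__, ZImageModularPipeline.__name__)
```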

src/diffusers/models/transformers/transformer_prx.py

Lines changed: 30 additions & 5 deletions

```diff
@@ -16,7 +16,6 @@
 
 import torch
 from torch import nn
-from torch.nn.functional import fold, unfold
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
@@ -532,7 +531,19 @@ def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
         Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
         // patch_size)` is the number of patches.
     """
-    return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
+    b, c, h, w = img.shape
+    p = patch_size
+
+    # Reshape to (B, C, H//p, p, W//p, p) separating grid and patch dimensions
+    img = img.reshape(b, c, h // p, p, w // p, p)
+
+    # Permute to (B, H//p, W//p, C, p, p) using einsum
+    # n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width
+    img = torch.einsum("nchpwq->nhwcpq", img)
+
+    # Flatten to (B, L, C * p * p)
+    img = img.reshape(b, -1, c * p * p)
+    return img
 
 
 def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
@@ -554,12 +565,26 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te
         Reconstructed image tensor of shape `(B, C, H, W)`.
     """
     if isinstance(shape, tuple):
-        shape = shape[-2:]
+        h, w = shape[-2:]
     elif isinstance(shape, torch.Tensor):
-        shape = (int(shape[0]), int(shape[1]))
+        h, w = (int(shape[0]), int(shape[1]))
     else:
         raise NotImplementedError(f"shape type {type(shape)} not supported")
-    return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
+
+    b, l, d = seq.shape
+    p = patch_size
+    c = d // (p * p)
+
+    # Reshape back to grid structure: (B, H//p, W//p, C, p, p)
+    seq = seq.reshape(b, h // p, w // p, c, p, p)
+
+    # Permute back to image layout: (B, C, H//p, p, W//p, p)
+    # n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width
+    seq = torch.einsum("nhwcpq->nchpwq", seq)
+
+    # Final reshape to (B, C, H, W)
+    seq = seq.reshape(b, c, h, w)
+    return seq
 
 
 class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
```
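The rewritten `img2seq`/`seq2img` are pure reshape-and-permute equivalents of the removed `unfold`/`fold` calls, so the change should be numerically a no-op. A small standalone check of that equivalence (a sketch based on the diff above; the tensor sizes are illustrative):

```py
import torch
from torch.nn.functional import fold, unfold


def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, L, C * p * p) with L = (H // p) * (W // p)
    b, c, h, w = img.shape
    p = patch_size
    img = img.reshape(b, c, h // p, p, w // p, p)
    img = torch.einsum("nchpwq->nhwcpq", img)
    return img.reshape(b, -1, c * p * p)


def seq2img(seq: torch.Tensor, patch_size: int, shape: tuple) -> torch.Tensor:
    # (B, L, C * p * p) -> (B, C, H, W)
    b, l, d = seq.shape
    p = patch_size
    h, w = shape[-2:]
    c = d // (p * p)
    seq = seq.reshape(b, h // p, w // p, c, p, p)
    seq = torch.einsum("nhwcpq->nchpwq", seq)
    return seq.reshape(b, c, h, w)


img = torch.randn(2, 4, 32, 32)
seq = img2seq(img, patch_size=2)

# Matches the removed unfold/fold path exactly and round-trips back to the image.
assert torch.equal(seq, unfold(img, kernel_size=2, stride=2).transpose(1, 2))
assert torch.equal(fold(seq.transpose(1, 2), (32, 32), kernel_size=2, stride=2), img)
assert torch.equal(seq2img(seq, patch_size=2, shape=(32, 32)), img)
```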

src/diffusers/modular_pipelines/__init__.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -52,6 +52,10 @@
         "FluxKontextAutoBlocks",
         "FluxKontextModularPipeline",
     ]
+    _import_structure["flux2"] = [
+        "Flux2AutoBlocks",
+        "Flux2ModularPipeline",
+    ]
     _import_structure["qwenimage"] = [
         "QwenImageAutoBlocks",
         "QwenImageModularPipeline",
@@ -60,6 +64,10 @@
         "QwenImageEditPlusModularPipeline",
         "QwenImageEditPlusAutoBlocks",
     ]
+    _import_structure["z_image"] = [
+        "ZImageAutoBlocks",
+        "ZImageModularPipeline",
+    ]
     _import_structure["components_manager"] = ["ComponentsManager"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -71,6 +79,7 @@
     else:
         from .components_manager import ComponentsManager
         from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
+        from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
         from .modular_pipeline import (
             AutoPipelineBlocks,
             BlockState,
@@ -91,6 +100,7 @@
         )
         from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
         from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
+        from .z_image import ZImageAutoBlocks, ZImageModularPipeline
 else:
     import sys
 
```

Lines changed: 111 additions & 0 deletions

```diff
@@ -0,0 +1,111 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["encoders"] = [
+        "Flux2TextEncoderStep",
+        "Flux2RemoteTextEncoderStep",
+        "Flux2VaeEncoderStep",
+    ]
+    _import_structure["before_denoise"] = [
+        "Flux2SetTimestepsStep",
+        "Flux2PrepareLatentsStep",
+        "Flux2RoPEInputsStep",
+        "Flux2PrepareImageLatentsStep",
+    ]
+    _import_structure["denoise"] = [
+        "Flux2LoopDenoiser",
+        "Flux2LoopAfterDenoiser",
+        "Flux2DenoiseLoopWrapper",
+        "Flux2DenoiseStep",
+    ]
+    _import_structure["decoders"] = ["Flux2DecodeStep"]
+    _import_structure["inputs"] = [
+        "Flux2ProcessImagesInputStep",
+        "Flux2TextInputStep",
+    ]
+    _import_structure["modular_blocks"] = [
+        "ALL_BLOCKS",
+        "AUTO_BLOCKS",
+        "REMOTE_AUTO_BLOCKS",
+        "TEXT2IMAGE_BLOCKS",
+        "IMAGE_CONDITIONED_BLOCKS",
+        "Flux2AutoBlocks",
+        "Flux2AutoVaeEncoderStep",
+        "Flux2BeforeDenoiseStep",
+        "Flux2VaeEncoderSequentialStep",
+    ]
+    _import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
+    else:
+        from .before_denoise import (
+            Flux2PrepareImageLatentsStep,
+            Flux2PrepareLatentsStep,
+            Flux2RoPEInputsStep,
+            Flux2SetTimestepsStep,
+        )
+        from .decoders import Flux2DecodeStep
+        from .denoise import (
+            Flux2DenoiseLoopWrapper,
+            Flux2DenoiseStep,
+            Flux2LoopAfterDenoiser,
+            Flux2LoopDenoiser,
+        )
+        from .encoders import (
+            Flux2RemoteTextEncoderStep,
+            Flux2TextEncoderStep,
+            Flux2VaeEncoderStep,
+        )
+        from .inputs import (
+            Flux2ProcessImagesInputStep,
+            Flux2TextInputStep,
+        )
+        from .modular_blocks import (
+            ALL_BLOCKS,
+            AUTO_BLOCKS,
+            IMAGE_CONDITIONED_BLOCKS,
+            REMOTE_AUTO_BLOCKS,
+            TEXT2IMAGE_BLOCKS,
+            Flux2AutoBlocks,
+            Flux2AutoVaeEncoderStep,
+            Flux2BeforeDenoiseStep,
+            Flux2VaeEncoderSequentialStep,
+        )
+        from .modular_pipeline import Flux2ModularPipeline
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
```
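This new file (by all appearances the `__init__.py` of the new `flux2` modular-pipelines package, matching the `from .flux2 import ...` hook added earlier; the extraction dropped the file name) follows the standard diffusers lazy-import pattern: at runtime the module is swapped for a `_LazyModule`, and submodules are only imported on first attribute access. An illustrative sketch of what that means in practice, assuming a build containing this commit:

```py
# Illustrative only: importing the package is cheap because it is backed by _LazyModule;
# the torch/transformers-heavy submodules load on first attribute access.
import importlib

flux2 = importlib.import_module("diffusers.modular_pipelines.flux2")
print(type(flux2).__name__)  # _LazyModule

pipeline_cls = flux2.Flux2ModularPipeline  # triggers the import of .modular_pipeline
blocks_cls = flux2.Flux2AutoBlocks        # triggers the import of .modular_blocks
```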
