Skip to content

Commit 07ea078

Browse files
authored
[Modular]z-image (#12808)
* initiL * up up * fix: z_image -> z-image * style * copy * fix more * some docstring fix
1 parent 54fa074 commit 07ea078

File tree

12 files changed

+1730
-2
lines changed

12 files changed

+1730
-2
lines changed

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,8 @@
419419
"Wan22AutoBlocks",
420420
"WanAutoBlocks",
421421
"WanModularPipeline",
422+
"ZImageAutoBlocks",
423+
"ZImageModularPipeline",
422424
]
423425
)
424426
_import_structure["pipelines"].extend(
@@ -1124,6 +1126,8 @@
11241126
Wan22AutoBlocks,
11251127
WanAutoBlocks,
11261128
WanModularPipeline,
1129+
ZImageAutoBlocks,
1130+
ZImageModularPipeline,
11271131
)
11281132
from .pipelines import (
11291133
AllegroPipeline,

src/diffusers/modular_pipelines/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@
6060
"QwenImageEditPlusModularPipeline",
6161
"QwenImageEditPlusAutoBlocks",
6262
]
63+
_import_structure["z_image"] = [
64+
"ZImageAutoBlocks",
65+
"ZImageModularPipeline",
66+
]
6367
_import_structure["components_manager"] = ["ComponentsManager"]
6468

6569
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -91,6 +95,7 @@
9195
)
9296
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
9397
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
98+
from .z_image import ZImageAutoBlocks, ZImageModularPipeline
9499
else:
95100
import sys
96101

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
("qwenimage", "QwenImageModularPipeline"),
6262
("qwenimage-edit", "QwenImageEditModularPipeline"),
6363
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
64+
("z-image", "ZImageModularPipeline"),
6465
]
6566
)
6667

src/diffusers/modular_pipelines/wan/encoders.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
530530

531531
device = components._execution_device
532532
dtype = torch.float32
533+
vae_dtype = components.vae.dtype
533534

534535
height = block_state.height or components.default_height
535536
width = block_state.width or components.default_width
@@ -555,7 +556,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
555556
vae=components.vae,
556557
generator=block_state.generator,
557558
device=device,
558-
dtype=dtype,
559+
dtype=vae_dtype,
559560
latent_channels=components.num_channels_latents,
560561
)
561562

@@ -627,6 +628,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
627628

628629
device = components._execution_device
629630
dtype = torch.float32
631+
vae_dtype = components.vae.dtype
630632

631633
height = block_state.height or components.default_height
632634
width = block_state.width or components.default_width
@@ -659,7 +661,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
659661
vae=components.vae,
660662
generator=block_state.generator,
661663
device=device,
662-
dtype=dtype,
664+
dtype=vae_dtype,
663665
latent_channels=components.num_channels_latents,
664666
)
665667

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from typing import TYPE_CHECKING
2+
3+
from ...utils import (
4+
DIFFUSERS_SLOW_IMPORT,
5+
OptionalDependencyNotAvailable,
6+
_LazyModule,
7+
get_objects_from_module,
8+
is_torch_available,
9+
is_transformers_available,
10+
)
11+
12+
13+
_dummy_objects = {}
14+
_import_structure = {}
15+
16+
try:
17+
if not (is_transformers_available() and is_torch_available()):
18+
raise OptionalDependencyNotAvailable()
19+
except OptionalDependencyNotAvailable:
20+
from ...utils import dummy_torch_and_transformers_objects # noqa F403
21+
22+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
23+
else:
24+
_import_structure["decoders"] = ["ZImageVaeDecoderStep"]
25+
_import_structure["encoders"] = ["ZImageTextEncoderStep", "ZImageVaeImageEncoderStep"]
26+
_import_structure["modular_blocks"] = [
27+
"ALL_BLOCKS",
28+
"ZImageAutoBlocks",
29+
]
30+
_import_structure["modular_pipeline"] = ["ZImageModularPipeline"]
31+
32+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
33+
try:
34+
if not (is_transformers_available() and is_torch_available()):
35+
raise OptionalDependencyNotAvailable()
36+
except OptionalDependencyNotAvailable:
37+
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
38+
else:
39+
from .decoders import ZImageVaeDecoderStep
40+
from .encoders import ZImageTextEncoderStep
41+
from .modular_blocks import (
42+
ALL_BLOCKS,
43+
ZImageAutoBlocks,
44+
)
45+
from .modular_pipeline import ZImageModularPipeline
46+
else:
47+
import sys
48+
49+
sys.modules[__name__] = _LazyModule(
50+
__name__,
51+
globals()["__file__"],
52+
_import_structure,
53+
module_spec=__spec__,
54+
)
55+
56+
for name, value in _dummy_objects.items():
57+
setattr(sys.modules[__name__], name, value)

0 commit comments

Comments
 (0)