Skip to content

Commit 64415ab

Browse files
committed
qwen modular refactor, unpack before decode
1 parent 1d42bb2 commit 64415ab

File tree

2 files changed

+56
-11
lines changed

2 files changed

+56
-11
lines changed

src/diffusers/modular_pipelines/qwenimage/decoders.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,45 @@
2929

3030
logger = logging.get_logger(__name__)
3131

32+
class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
    """Pipeline block that unpacks the latents held in the pipeline state.

    The block reads ``latents``, ``height`` and ``width`` from its block state, hands them to
    the ``pachifier`` component's ``unpack_latents`` together with the pipeline's
    ``vae_scale_factor``, and stores the result back under ``latents``.
    """

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        """Return the human-readable summary of what this block does."""
        return "Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, channels, 1, height, width)"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Return the component specs this block relies on (only the ``pachifier``)."""
        return [
            ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        """Return the required inputs: ``height``, ``width`` and the ``latents`` tensor."""
        height_param = InputParam(name="height", required=True)
        width_param = InputParam(name="width", required=True)
        latents_param = InputParam(
            name="latents",
            required=True,
            type_hint=torch.Tensor,
            description="The latents to decode, can be generated in the denoise step",
        )
        return [height_param, width_param, latents_param]

    @torch.no_grad()
    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Unpack ``latents`` in the block state and write the result back to ``state``.

        Args:
            components: Pipeline object providing ``vae_scale_factor`` and the ``pachifier``.
            state: Pipeline state from which the block state is read and into which it is saved.

        Returns:
            The ``(components, state)`` pair.
        """
        # NOTE(review): the annotation says ``PipelineState`` but a 2-tuple is returned;
        # left untouched here so the block's interface stays exactly as it was.
        block_state = self.get_block_state(state)

        # Read the scale factor before touching the pachifier, as the original code did.
        scale = components.vae_scale_factor
        unpacked = components.pachifier.unpack_latents(
            block_state.latents,
            block_state.height,
            block_state.width,
            vae_scale_factor=scale,
        )
        block_state.latents = unpacked

        self.set_block_state(state, block_state)
        return components, state
3271

3372
class QwenImageDecoderStep(ModularPipelineBlocks):
3473
model_name = "qwenimage"
@@ -41,16 +80,13 @@ def description(self) -> str:
4180
def expected_components(self) -> List[ComponentSpec]:
4281
components = [
4382
ComponentSpec("vae", AutoencoderKLQwenImage),
44-
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
4583
]
4684

4785
return components
4886

4987
@property
5088
def inputs(self) -> List[InputParam]:
5189
return [
52-
InputParam(name="height", required=True),
53-
InputParam(name="width", required=True),
5490
InputParam(
5591
name="latents",
5692
required=True,
@@ -74,10 +110,10 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
74110
block_state = self.get_block_state(state)
75111

76112
# YiYi Notes: remove support for output_type = "latents", we can just skip decode/encode step in modular
77-
vae_scale_factor = components.vae_scale_factor
78-
block_state.latents = components.pachifier.unpack_latents(
79-
block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
80-
)
113+
if block_state.latents.ndim == 4:
114+
block_state.latents = block_state.latents.unsqueeze(dim=1)
115+
elif block_state.latents.ndim != 5:
116+
raise ValueError(f"expect latents to be a 4D or 5D tensor but got: {block_state.latents.shape}. Please make sure the latents are unpacked before decode step.")
81117
block_state.latents = block_state.latents.to(components.vae.dtype)
82118

83119
latents_mean = (

src/diffusers/modular_pipelines/qwenimage/modular_blocks.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
QwenImageSetTimestepsStep,
2727
QwenImageSetTimestepsWithStrengthStep,
2828
)
29-
from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep
29+
from .decoders import QwenImageAfterDenoiseStep, QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep
3030
from .denoise import (
3131
QwenImageControlNetDenoiseStep,
3232
QwenImageDenoiseStep,
@@ -92,6 +92,7 @@ def description(self):
9292
("set_timesteps", QwenImageSetTimestepsStep()),
9393
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
9494
("denoise", QwenImageDenoiseStep()),
95+
("after_denoise", QwenImageAfterDenoiseStep()),
9596
("decode", QwenImageDecodeStep()),
9697
]
9798
)
@@ -205,6 +206,7 @@ def description(self):
205206
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
206207
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
207208
("denoise", QwenImageInpaintDenoiseStep()),
209+
("after_denoise", QwenImageAfterDenoiseStep()),
208210
("decode", QwenImageInpaintDecodeStep()),
209211
]
210212
)
@@ -264,6 +266,7 @@ def description(self):
264266
("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
265267
("prepare_rope_inputs", QwenImageRoPEInputsStep()),
266268
("denoise", QwenImageDenoiseStep()),
269+
("after_denoise", QwenImageAfterDenoiseStep()),
267270
("decode", QwenImageDecodeStep()),
268271
]
269272
)
@@ -529,8 +532,9 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
529532
QwenImageAutoBeforeDenoiseStep,
530533
QwenImageOptionalControlNetBeforeDenoiseStep,
531534
QwenImageAutoDenoiseStep,
535+
QwenImageAfterDenoiseStep,
532536
]
533-
block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise"]
537+
block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise", "after_denoise"]
534538

535539
@property
536540
def description(self):
@@ -653,6 +657,7 @@ def description(self):
653657
("set_timesteps", QwenImageSetTimestepsStep()),
654658
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
655659
("denoise", QwenImageEditDenoiseStep()),
660+
("after_denoise", QwenImageAfterDenoiseStep()),
656661
("decode", QwenImageDecodeStep()),
657662
]
658663
)
@@ -702,6 +707,7 @@ def description(self) -> str:
702707
("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
703708
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
704709
("denoise", QwenImageEditInpaintDenoiseStep()),
710+
("after_denoise", QwenImageAfterDenoiseStep()),
705711
("decode", QwenImageInpaintDecodeStep()),
706712
]
707713
)
@@ -841,8 +847,9 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
841847
QwenImageEditAutoInputStep,
842848
QwenImageEditAutoBeforeDenoiseStep,
843849
QwenImageEditAutoDenoiseStep,
850+
QwenImageAfterDenoiseStep,
844851
]
845-
block_names = ["input", "before_denoise", "denoise"]
852+
block_names = ["input", "before_denoise", "denoise", "after_denoise"]
846853

847854
@property
848855
def description(self):
@@ -954,6 +961,7 @@ class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
954961
("set_timesteps", QwenImageSetTimestepsStep()),
955962
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
956963
("denoise", QwenImageEditDenoiseStep()),
964+
("after_denoise", QwenImageAfterDenoiseStep()),
957965
("decode", QwenImageDecodeStep()),
958966
]
959967
)
@@ -1037,8 +1045,9 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
10371045
QwenImageEditPlusAutoInputStep,
10381046
QwenImageEditPlusAutoBeforeDenoiseStep,
10391047
QwenImageEditAutoDenoiseStep,
1048+
QwenImageAfterDenoiseStep,
10401049
]
1041-
block_names = ["input", "before_denoise", "denoise"]
1050+
block_names = ["input", "before_denoise", "denoise", "after_denoise"]
10421051

10431052
@property
10441053
def description(self):

0 commit comments

Comments
 (0)