Skip to content

Commit d16b7b9

Browse files
committed
refactor!
1 parent 8dce330 commit d16b7b9

File tree

5 files changed

+336
-209
lines changed

5 files changed

+336
-209
lines changed

src/diffusers/modular_pipelines/qwenimage/before_denoise.py

Lines changed: 157 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel
2222
from ...schedulers import FlowMatchEulerDiscreteScheduler
2323
from ...utils.torch_utils import randn_tensor, unwrap_module
24-
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
24+
from ..modular_pipeline import ModularPipelineBlocks, PipelineState, SequentialPipelineBlocks
2525
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
2626
from .modular_pipeline import QwenImageModularPipeline
2727

@@ -121,6 +121,93 @@ def pack_latents(latents, batch_size, num_channels_latents, height, width):
121121
return latents
122122

123123

124+
# Prepare Latents steps
125+
126+
127+
class QwenImagePackLatentsDynamicStep(ModularPipelineBlocks):
    """Patchify the configured latent inputs and expand them to the final batch size.

    For every input name given at construction time, the tensor stored under that name in the block state is
    validated, given a fifth dimension if it only has four, repeated along dim 0 up to
    `batch_size * num_images_per_prompt`, packed with `pack_latents`, and written back under the same name.
    """

    model_name = "qwenimage"

    def __init__(self, input_names: List[str] = ["image_latents"]):
        """Initialize a dynamic latents packing step.

        Args:
            input_names (List[str], optional): Names of latent tensors to patchify and expand.
                Can be a single string or a list/tuple of strings. Defaults to ["image_latents"]. Examples:
                ["image_latents"], ["control_image_latents"]
        """
        if isinstance(input_names, (list, tuple)):
            # Copy so this instance never aliases the shared default list or a caller-owned sequence.
            names = list(input_names)
        else:
            names = [input_names]
        # Must be set before `super().__init__()` is called, as in the original ordering.
        self._latents_input_names = names
        super().__init__()

    @property
    def description(self) -> str:
        return "Step that patchifies latents and expands batch dimension. Works with outputs from QwenImageVaeEncoderDynamicStep."

    @property
    def inputs(self) -> List[InputParam]:
        """Fixed batch-related inputs followed by one `InputParam` per configured latent name."""
        fixed_inputs = [
            InputParam(name="num_images_per_prompt", default=1),
            InputParam(
                name="batch_size",
                required=True,
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
            ),
        ]
        latent_inputs = [InputParam(name=input_name) for input_name in self._latents_input_names]
        return fixed_inputs + latent_inputs

    @staticmethod
    def check_input_shape(latents_input, latents_input_name, batch_size):
        """Validate one latent tensor before it is packed.

        Args:
            latents_input: The tensor read from the block state (may be `None` if it was never provided).
            latents_input_name: Name of the input, used only in error messages.
            batch_size: Number of prompts; the tensor's dim 0 must be 1 or this value.

        Raises:
            ValueError: If the input is `None`, has a batch size other than 1 or `batch_size`, or does not have
                4 or 5 dimensions.
        """
        # The latent inputs are declared without `required=True`, so they can legitimately be missing here.
        # Fail with an explicit message instead of an AttributeError on `None.ndim`.
        if latents_input is None:
            raise ValueError(f"`{latents_input_name}` is required to pack latents, but got None")

        if latents_input.shape[0] != 1 and latents_input.shape[0] != batch_size:
            raise ValueError(
                f"`{latents_input_name}` must have batch size 1 or {batch_size}, but got {latents_input.shape[0]}"
            )

        if latents_input.ndim != 5 and latents_input.ndim != 4:
            raise ValueError(f"`{latents_input_name}` must have 4 or 5 dimensions, but got {latents_input.ndim}")

    @torch.no_grad()
    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Pack every configured latent input in place on the block state.

        Returns the `(components, state)` pair after the block state has been written back to `state`.
        """
        block_state = self.get_block_state(state)

        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

        for input_name in self._latents_input_names:
            latents_input = getattr(block_state, input_name)

            self.check_input_shape(
                latents_input=latents_input,
                latents_input_name=input_name,
                batch_size=block_state.batch_size,
            )

            # 4-dim inputs get a singleton dim inserted at index 2 so every tensor is 5-dim from here on
            # (presumably the frame axis -- TODO confirm against the VAE encoder output layout).
            if latents_input.ndim == 4:
                latents_input = latents_input.unsqueeze(2)

            # Dim 0 is 1 or `batch_size` (validated above); tile it up to the final batch size.
            latents_input = latents_input.repeat(final_batch_size // latents_input.shape[0], 1, 1, 1, 1)

            height_latents, width_latents = latents_input.shape[3:]

            latents_input = pack_latents(
                latents=latents_input,
                batch_size=latents_input.shape[0],
                num_channels_latents=components.num_channels_latents,
                height=height_latents,
                width=width_latents,
            )

            setattr(block_state, input_name, latents_input)

        self.set_block_state(state, block_state)

        return components, state
209+
210+
124211
class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
125212
model_name = "qwenimage"
126213

@@ -233,7 +320,7 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
233320

234321
@property
235322
def description(self) -> str:
236-
return "Step that add noise to the image latents for the image-to-image/inpainting process. Should be run after prepare latents step."
323+
return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should already be patchified."
237324

238325
@property
239326
def expected_components(self) -> List[ComponentSpec]:
@@ -326,7 +413,7 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
326413

327414
@property
328415
def description(self) -> str:
329-
return "Step that create the mask latents for the inpainting process. Should be run with the pachify latents step."
416+
return "Step that creates mask latents from preprocessed mask_image by interpolating to latent space. Output is not patchified."
330417

331418
@property
332419
def inputs(self) -> List[InputParam]:
@@ -376,80 +463,57 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
376463
return components, state
377464

378465

379-
class QwenImagePackLatentsDynamicStep(ModularPipelineBlocks):
466+
class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
380467
model_name = "qwenimage"
468+
"""This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:
469+
- Patchify the image latents.
470+
- Add noise to the image latents to create the `latents` input for the denoiser.
471+
- Create the latents `mask` based on the processed `mask_image`.
472+
- Patchify the `mask` to match the shape of the image latents.
473+
474+
Components:
475+
scheduler (`FlowMatchEulerDiscreteScheduler`) [subfolder=]
476+
477+
Inputs:
478+
height (`None`, optional): width (`None`, optional): num_images_per_prompt (`None`, optional, defaults to 1):
479+
batch_size (`int`):
480+
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
481+
be generated in input step.
482+
image_latents (`None`, optional): latents (`Tensor`):
483+
The initial random noised, can be generated in prepare latent step.
484+
timesteps (`Tensor`):
485+
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
486+
mask_image (`Tensor`):
487+
The mask image to use for the inpainting process.
488+
489+
Outputs:
490+
init_noise (`Tensor`):
491+
The initial random noised used for inpainting denoising.
492+
mask (`Tensor`):
493+
The mask latents to use for the inpainting process.
494+
"""
381495

382-
@property
383-
def description(self) -> str:
384-
return "Step that pachify the latents inputs. Should be used with outputs from vae encoder step."
385-
386-
@property
387-
def inputs(self) -> List[InputParam]:
388-
additional_inputs = []
389-
for input_name in self._latents_input_names:
390-
additional_inputs.append(InputParam(name=input_name))
391-
392-
return [
393-
InputParam(name="num_images_per_prompt", default=1),
394-
InputParam(
395-
name="batch_size",
396-
required=True,
397-
type_hint=int,
398-
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
399-
),
400-
] + additional_inputs
401-
402-
def __init__(self, input_names: List[str] = ["image_latents"]):
403-
if not isinstance(input_names, list):
404-
input_names = [input_names]
405-
self._latents_input_names = input_names
406-
super().__init__()
407-
408-
@staticmethod
409-
def check_input_shape(latents_input, latents_input_name, batch_size):
410-
if latents_input is not None and latents_input.shape[0] != 1 and latents_input.shape[0] != batch_size:
411-
raise ValueError(
412-
f"`{latents_input_name}` must have have batch size 1 or {batch_size}, but got {latents_input.shape[0]}"
413-
)
414-
415-
if latents_input.ndim != 5 and latents_input.ndim != 4:
416-
raise ValueError(f"`{latents_input_name}` must have 4 or 5 dimensions, but got {latents_input.ndim}")
417-
418-
@torch.no_grad()
419-
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
420-
block_state = self.get_block_state(state)
421-
422-
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
423-
424-
for input_name in self._latents_input_names:
425-
latents_input = getattr(block_state, input_name)
426-
427-
self.check_input_shape(
428-
latents_input=latents_input,
429-
latents_input_name=input_name,
430-
batch_size=block_state.batch_size,
431-
)
432-
433-
if latents_input.ndim == 4:
434-
latents_input = latents_input.unsqueeze(2)
435-
436-
latents_input = latents_input.repeat(final_batch_size // latents_input.shape[0], 1, 1, 1, 1)
437-
438-
height_latents, width_latents = latents_input.shape[3:]
496+
block_classes = [
497+
QwenImagePackLatentsDynamicStep("image_latents"),
498+
QwenImagePrepareLatentsWithStrengthStep,
499+
QwenImageCreateMaskLatentsStep,
500+
QwenImagePackLatentsDynamicStep("mask"),
501+
]
439502

440-
latents_input = pack_latents(
441-
latents=latents_input,
442-
batch_size=latents_input.shape[0],
443-
num_channels_latents=components.num_channels_latents,
444-
height=height_latents,
445-
width=width_latents,
446-
)
503+
block_names = ["pack_image_latents", "add_noise_to_latents", "create_mask_latents", "pack_mask"]
447504

448-
setattr(block_state, input_name, latents_input)
505+
@property
506+
def description(self) -> str:
507+
return (
508+
"This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
509+
" - Patchify the image latents.\n"
510+
" - Add noise to the image latents to create the latents input for the denoiser.\n"
511+
" - Create the latents `mask` based on the processedmask image.\n"
512+
" - Patchify the mask latents to match the shape of the image latents."
513+
)
449514

450-
self.set_block_state(state, block_state)
451515

452-
return components, state
516+
# Set Timesteps steps
453517

454518

455519
class QwenImageSetTimestepsStep(ModularPipelineBlocks):
@@ -591,6 +655,11 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
591655
return components, state
592656

593657

658+
# other inputs for denoiser
659+
660+
## RoPE inputs for denoiser
661+
662+
594663
class QwenImageRoPEInputsStep(ModularPipelineBlocks):
595664
model_name = "qwenimage"
596665

@@ -728,7 +797,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
728797
return components, state
729798

730799

731-
class QwenImageControlNetPrepareInputsStep(ModularPipelineBlocks):
800+
## ControlNet inputs for denoiser
801+
class QwenImageControlNetAdditionalInputsStep(ModularPipelineBlocks):
732802
model_name = "qwenimage"
733803

734804
@property
@@ -809,3 +879,18 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
809879
self.set_block_state(state, block_state)
810880

811881
return components, state
882+
883+
884+
class QwenImageControlNetInputsStep(SequentialPipelineBlocks):
    """Sequential block that packs `control_image_latents` and then prepares the remaining controlnet inputs."""

    model_name = "qwenimage"

    # One name per entry in `block_classes`, listed in the same order.
    block_names = ["prepare_control_image_latent", "prepare_controlnet_inputs"]

    block_classes = [
        # prepare control image latents
        QwenImagePackLatentsDynamicStep("control_image_latents"),
        # prepare the controlnet inputs e.g. controlnet_keep, controlnet_conditioning_scale, etc.
        QwenImageControlNetAdditionalInputsStep,
    ]

    @property
    def description(self) -> str:
        summary = "Step that prepares the controlnet inputs. Insert before the Denoise Step, after set_timesteps step."
        return summary
return "Step that prepares the controlnet inputs. Insert before the Denoise Step, after set_timesteps step."

src/diffusers/modular_pipelines/qwenimage/denoise.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
9696
return components, block_state
9797

9898

99-
class QwenImageControlNetLoopBeforeDenoiser(ModularPipelineBlocks):
99+
class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
100100
model_name = "qwenimage"
101101

102102
@property
@@ -571,7 +571,7 @@ def description(self) -> str:
571571
class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
572572
block_classes = [
573573
QwenImageLoopBeforeDenoiser,
574-
QwenImageControlNetLoopBeforeDenoiser,
574+
QwenImageLoopBeforeDenoiserControlNet,
575575
QwenImageLoopDenoiser,
576576
QwenImageLoopAfterDenoiser,
577577
]
@@ -584,13 +584,45 @@ def description(self) -> str:
584584
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
585585
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
586586
" - `QwenImageLoopBeforeDenoiser`\n"
587-
" - `QwenImageControlNetLoopBeforeDenoiser`\n"
587+
" - `QwenImageLoopBeforeDenoiserControlNet`\n"
588588
" - `QwenImageLoopDenoiser`\n"
589589
" - `QwenImageLoopAfterDenoiser`\n"
590590
"This block supports text2img tasks."
591591
)
592592

593593

594+
# composing the controlnet denoising loops
595+
class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
    """Denoising loop composition for inpainting with controlnet; the loop itself lives in the wrapper base class."""

    # One name per entry in `block_classes`, listed in the same order.
    block_names = [
        "before_denoiser",
        "before_denoiser_controlnet",
        "denoiser",
        "after_denoiser",
        "after_denoiser_inpaint",
    ]

    block_classes = [
        QwenImageLoopBeforeDenoiser,
        QwenImageLoopBeforeDenoiserControlNet,
        QwenImageLoopDenoiser,
        QwenImageLoopAfterDenoiser,
        QwenImageLoopAfterDenoiserInpaint,
    ]

    @property
    def description(self) -> str:
        # Kept as explicit lines (trailing spaces included) so the joined text matches the sibling descriptions.
        description_lines = [
            "Denoise step that iteratively denoise the latents. ",
            "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method ",
            "At each iteration, it runs blocks defined in `sub_blocks` sequencially:",
            " - `QwenImageLoopBeforeDenoiser`",
            " - `QwenImageLoopBeforeDenoiserControlNet`",
            " - `QwenImageLoopDenoiser`",
            " - `QwenImageLoopAfterDenoiser`",
            " - `QwenImageLoopAfterDenoiserInpaint`",
            "This block supports inpainting tasks with controlnet.",
        ]
        return "\n".join(description_lines)
624+
625+
594626
# composing the denoising loops
595627
class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
596628
block_classes = [

0 commit comments

Comments
 (0)