Skip to content

Commit 8b4722d

Browse files
authored
Fix Qwen Edit Plus modular for multi-image input (#12601)
* Try to fix Qwen Edit Plus multi-image input (modular)
* up
* up
* test
* up
* up
1 parent 07ea078 commit 8b4722d

File tree

5 files changed

+247
-35
lines changed

5 files changed

+247
-35
lines changed

src/diffusers/modular_pipelines/qwenimage/before_denoise.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
610610
block_state = self.get_block_state(state)
611611

612612
# for edit, image size can be different from the target size (height/width)
613-
614613
block_state.img_shapes = [
615614
[
616615
(
@@ -640,6 +639,37 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
640639
return components, state
641640

642641

642+
class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
    model_name = "qwenimage-edit-plus"

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Assemble RoPE-related inputs (latent-grid shapes and text sequence lengths).

        Unlike the base edit step, Edit Plus allows each reference image to have its own
        resolution, so one latent-grid shape is emitted per reference image in addition to
        the target shape.
        """
        block_state = self.get_block_state(state)

        scale = components.vae_scale_factor
        # Target latent-grid shape first, then one entry per (possibly differently sized) reference image.
        target_shape = (1, block_state.height // scale // 2, block_state.width // scale // 2)
        reference_shapes = [
            (1, ref_h // scale // 2, ref_w // scale // 2)
            for ref_h, ref_w in zip(block_state.image_height, block_state.image_width)
        ]
        block_state.img_shapes = [[target_shape, *reference_shapes]] * block_state.batch_size

        # Per-sample token counts derived from the attention masks; None when a mask is absent.
        if block_state.prompt_embeds_mask is None:
            block_state.txt_seq_lens = None
        else:
            block_state.txt_seq_lens = block_state.prompt_embeds_mask.sum(dim=1).tolist()

        if block_state.negative_prompt_embeds_mask is None:
            block_state.negative_txt_seq_lens = None
        else:
            block_state.negative_txt_seq_lens = block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()

        self.set_block_state(state, block_state)

        return components, state
643673
## ControlNet inputs for denoiser
644674
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
645675
model_name = "qwenimage"

src/diffusers/modular_pipelines/qwenimage/encoders.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def __init__(
330330
output_name: str = "resized_image",
331331
vae_image_output_name: str = "vae_image",
332332
):
333-
"""Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
333+
"""Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio.
334334
335335
This block resizes an input image or a list input images and exposes the resized result under configurable
336336
input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
@@ -809,9 +809,7 @@ def inputs(self) -> List[InputParam]:
809809

810810
@property
811811
def intermediate_outputs(self) -> List[OutputParam]:
812-
return [
813-
OutputParam(name="processed_image"),
814-
]
812+
return [OutputParam(name="processed_image")]
815813

816814
@staticmethod
817815
def check_inputs(height, width, vae_scale_factor):
@@ -851,7 +849,10 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
851849

852850
class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
853851
model_name = "qwenimage-edit-plus"
854-
vae_image_size = 1024 * 1024
852+
853+
def __init__(self):
854+
self.vae_image_size = 1024 * 1024
855+
super().__init__()
855856

856857
@property
857858
def description(self) -> str:
@@ -868,6 +869,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
868869
if block_state.vae_image is None and block_state.image is None:
869870
raise ValueError("`vae_image` and `image` cannot be None at the same time")
870871

872+
vae_image_sizes = None
871873
if block_state.vae_image is None:
872874
image = block_state.image
873875
self.check_inputs(
@@ -879,12 +881,19 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
879881
image=image, height=height, width=width
880882
)
881883
else:
882-
width, height = block_state.vae_image[0].size
883-
image = block_state.vae_image
884+
# QwenImage Edit Plus can allow multiple input images with varied resolutions
885+
processed_images = []
886+
vae_image_sizes = []
887+
for img in block_state.vae_image:
888+
width, height = img.size
889+
vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height)
890+
vae_image_sizes.append((vae_width, vae_height))
891+
processed_images.append(
892+
components.image_processor.preprocess(image=img, height=vae_height, width=vae_width)
893+
)
894+
block_state.processed_image = processed_images
884895

885-
block_state.processed_image = components.image_processor.preprocess(
886-
image=image, height=height, width=width
887-
)
896+
block_state.vae_image_sizes = vae_image_sizes
888897

889898
self.set_block_state(state, block_state)
890899
return components, state
@@ -926,17 +935,12 @@ def description(self) -> str:
926935

927936
@property
928937
def expected_components(self) -> List[ComponentSpec]:
929-
components = [
930-
ComponentSpec("vae", AutoencoderKLQwenImage),
931-
]
938+
components = [ComponentSpec("vae", AutoencoderKLQwenImage)]
932939
return components
933940

934941
@property
935942
def inputs(self) -> List[InputParam]:
936-
inputs = [
937-
InputParam(self._image_input_name, required=True),
938-
InputParam("generator"),
939-
]
943+
inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
940944
return inputs
941945

942946
@property
@@ -974,6 +978,50 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
974978
return components, state
975979

976980

981+
class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep):
    model_name = "qwenimage-edit-plus"

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        # Reference images may come in different resolutions, so the latents are
        # exposed as a list of tensors rather than a single batched tensor.
        return [
            OutputParam(
                self._image_latents_output_name,
                type_hint=List[torch.Tensor],
                description="The latents representing the reference image(s).",
            )
        ]

    @torch.no_grad()
    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """VAE-encode each reference image independently and store the per-image latents."""
        block_state = self.get_block_state(state)

        target_device = components._execution_device
        target_dtype = components.vae.dtype
        images = getattr(block_state, self._image_input_name)

        # One encode call per image so mixed resolutions are supported.
        per_image_latents = [
            encode_vae_image(
                image=single_image,
                vae=components.vae,
                generator=block_state.generator,
                device=target_device,
                dtype=target_dtype,
                latent_channels=components.num_channels_latents,
            )
            for single_image in images
        ]
        setattr(block_state, self._image_latents_output_name, per_image_latents)

        self.set_block_state(state, block_state)

        return components, state
9771025
class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
9781026
model_name = "qwenimage"
9791027

src/diffusers/modular_pipelines/qwenimage/inputs.py

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
224224
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
225225
model_name = "qwenimage"
226226

227-
def __init__(
228-
self,
229-
image_latent_inputs: List[str] = ["image_latents"],
230-
additional_batch_inputs: List[str] = [],
231-
):
227+
def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []):
232228
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
233229
234230
This step handles multiple common tasks to prepare inputs for the denoising step:
@@ -372,6 +368,76 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
372368
return components, state
373369

374370

371+
class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
    model_name = "qwenimage-edit-plus"

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"),
            OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"),
        ]

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Standardize denoiser inputs when reference latents of varied sizes are present.

        For every configured latent input: record each latent's pixel-space height/width,
        patchify it, expand it to the effective batch size, and concatenate all packed
        latents along the sequence dimension. Additional batch inputs only get batch
        expansion.
        """
        block_state = self.get_block_state(state)

        for latent_name in self._image_latent_inputs:
            latents_list = getattr(block_state, latent_name)
            if latents_list is None:
                continue

            # Latents may differ in resolution, so sizes are tracked per image.
            heights = []
            widths = []
            packed_latents = []

            for single_latents in latents_list:
                # 1. Recover pixel-space dimensions from this latent tensor.
                latent_h, latent_w = calculate_dimension_from_latents(single_latents, components.vae_scale_factor)
                heights.append(latent_h)
                widths.append(latent_w)

                # 2. Patchify, then 3. tile to the effective batch size.
                packed = components.pachifier.pack_latents(single_latents)
                packed = repeat_tensor_to_batch_size(
                    input_name=latent_name,
                    input_tensor=packed,
                    num_images_per_prompt=block_state.num_images_per_prompt,
                    batch_size=block_state.batch_size,
                )
                packed_latents.append(packed)

            # All reference latents are joined along the sequence dimension.
            block_state.image_height = heights
            block_state.image_width = widths
            setattr(block_state, latent_name, torch.cat(packed_latents, dim=1))

            # Default the target size to the last reference image's size when unset.
            block_state.height = block_state.height or heights[-1]
            block_state.width = block_state.width or widths[-1]

        # Remaining inputs only need their batch dimension expanded.
        for extra_name in self._additional_batch_inputs:
            extra_tensor = getattr(block_state, extra_name)
            if extra_tensor is None:
                continue

            expanded = repeat_tensor_to_batch_size(
                input_name=extra_name,
                input_tensor=extra_tensor,
                num_images_per_prompt=block_state.num_images_per_prompt,
                batch_size=block_state.batch_size,
            )
            setattr(block_state, extra_name, expanded)

        self.set_block_state(state, block_state)
        return components, state
375441
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
376442
model_name = "qwenimage"
377443

0 commit comments

Comments
 (0)