
Commit c3675d4

[core] support QwenImage Edit Plus in modular (huggingface#12416)
* up
* up
* up
* up
* up
* up
* remove saves
* move things around a bit.
* get ready.
1 parent 2b7deff commit c3675d4

File tree: 10 files changed, +449 -13 lines changed

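For orientation, here is a minimal usage sketch of what this commit enables. Only the class name comes from this diff; the repository id, device handling, and call signature are assumptions modeled on the existing QwenImage Edit modular pipeline.

```python
# Hypothetical usage sketch -- repo id, kwargs, and output selector are assumptions.
from diffusers import QwenImageEditPlusModularPipeline
from diffusers.utils import load_image

pipe = QwenImageEditPlusModularPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509")  # assumed repo id
pipe.to("cuda")  # assumes device placement mirrors standard pipelines

image = load_image("https://example.com/input.png")  # placeholder URL
images = pipe(
    prompt="replace the background with a snowy mountain range",
    image=[image],           # Edit Plus accepts a single image or a list of reference images
    num_inference_steps=28,
    output="images",         # assumed output selector for modular pipelines
)
images[0].save("edited.png")
```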

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -390,6 +390,8 @@
         "QwenImageAutoBlocks",
         "QwenImageEditAutoBlocks",
         "QwenImageEditModularPipeline",
+        "QwenImageEditPlusAutoBlocks",
+        "QwenImageEditPlusModularPipeline",
         "QwenImageModularPipeline",
         "StableDiffusionXLAutoBlocks",
         "StableDiffusionXLModularPipeline",
@@ -1052,6 +1054,8 @@
         QwenImageAutoBlocks,
         QwenImageEditAutoBlocks,
         QwenImageEditModularPipeline,
+        QwenImageEditPlusAutoBlocks,
+        QwenImageEditPlusModularPipeline,
         QwenImageModularPipeline,
         StableDiffusionXLAutoBlocks,
         StableDiffusionXLModularPipeline,

src/diffusers/modular_pipelines/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -52,6 +52,8 @@
         "QwenImageModularPipeline",
         "QwenImageEditModularPipeline",
         "QwenImageEditAutoBlocks",
+        "QwenImageEditPlusModularPipeline",
+        "QwenImageEditPlusAutoBlocks",
     ]
     _import_structure["components_manager"] = ["ComponentsManager"]
 
@@ -78,6 +80,8 @@
         QwenImageAutoBlocks,
         QwenImageEditAutoBlocks,
         QwenImageEditModularPipeline,
+        QwenImageEditPlusAutoBlocks,
+        QwenImageEditPlusModularPipeline,
         QwenImageModularPipeline,
     )
     from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 3 additions & 1 deletion
@@ -59,6 +59,7 @@
         ("flux", "FluxModularPipeline"),
         ("qwenimage", "QwenImageModularPipeline"),
         ("qwenimage-edit", "QwenImageEditModularPipeline"),
+        ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
     ]
 )
 
@@ -1628,7 +1629,8 @@ def from_pretrained(
             blocks = ModularPipelineBlocks.from_pretrained(
                 pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
             )
-        except EnvironmentError:
+        except EnvironmentError as e:
+            logger.debug(f"EnvironmentError: {e}")
             blocks = None
 
         cache_dir = kwargs.pop("cache_dir", None)
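The first hunk registers the new pipeline in the auto-resolution mapping. As a self-contained illustration (plain Python, not diffusers internals), this is how an OrderedDict registry like the one above resolves a model tag to a pipeline class name:

```python
from collections import OrderedDict

# Mirror of the mapping extended above; the fallback name is an assumption.
MODULAR_PIPELINE_MAPPING = OrderedDict(
    [
        ("flux", "FluxModularPipeline"),
        ("qwenimage", "QwenImageModularPipeline"),
        ("qwenimage-edit", "QwenImageEditModularPipeline"),
        ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),  # new in this commit
    ]
)

def resolve_pipeline_class(tag: str) -> str:
    # Unregistered tags fall back to the generic base class (assumption).
    return MODULAR_PIPELINE_MAPPING.get(tag, "ModularPipeline")

assert resolve_pipeline_class("qwenimage-edit-plus") == "QwenImageEditPlusModularPipeline"
```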

src/diffusers/modular_pipelines/qwenimage/__init__.py

Lines changed: 16 additions & 2 deletions
@@ -29,13 +29,20 @@
         "EDIT_AUTO_BLOCKS",
         "EDIT_BLOCKS",
         "EDIT_INPAINT_BLOCKS",
+        "EDIT_PLUS_AUTO_BLOCKS",
+        "EDIT_PLUS_BLOCKS",
         "IMAGE2IMAGE_BLOCKS",
         "INPAINT_BLOCKS",
         "TEXT2IMAGE_BLOCKS",
         "QwenImageAutoBlocks",
         "QwenImageEditAutoBlocks",
+        "QwenImageEditPlusAutoBlocks",
+    ]
+    _import_structure["modular_pipeline"] = [
+        "QwenImageEditModularPipeline",
+        "QwenImageEditPlusModularPipeline",
+        "QwenImageModularPipeline",
     ]
-    _import_structure["modular_pipeline"] = ["QwenImageEditModularPipeline", "QwenImageModularPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -54,13 +61,20 @@
         EDIT_AUTO_BLOCKS,
         EDIT_BLOCKS,
         EDIT_INPAINT_BLOCKS,
+        EDIT_PLUS_AUTO_BLOCKS,
+        EDIT_PLUS_BLOCKS,
         IMAGE2IMAGE_BLOCKS,
         INPAINT_BLOCKS,
         TEXT2IMAGE_BLOCKS,
         QwenImageAutoBlocks,
         QwenImageEditAutoBlocks,
+        QwenImageEditPlusAutoBlocks,
+    )
+    from .modular_pipeline import (
+        QwenImageEditModularPipeline,
+        QwenImageEditPlusModularPipeline,
+        QwenImageModularPipeline,
     )
-    from .modular_pipeline import QwenImageEditModularPipeline, QwenImageModularPipeline
 else:
     import sys

src/diffusers/modular_pipelines/qwenimage/before_denoise.py

Lines changed: 1 addition & 2 deletions
@@ -203,7 +203,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
         block_state.latents = components.pachifier.pack_latents(block_state.latents)
 
         self.set_block_state(state, block_state)
-
         return components, state
 
 
@@ -571,7 +570,7 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
 
     @property
     def description(self) -> str:
-        return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be place after prepare_latents step"
+        return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after prepare_latents step"
 
     @property
     def inputs(self) -> List[InputParam]:

src/diffusers/modular_pipelines/qwenimage/encoders.py

Lines changed: 229 additions & 7 deletions
@@ -128,6 +128,61 @@ def get_qwen_prompt_embeds_edit(
     return prompt_embeds, encoder_attention_mask
 
 
+def get_qwen_prompt_embeds_edit_plus(
+    text_encoder,
+    processor,
+    prompt: Union[str, List[str]] = None,
+    image: Optional[Union[torch.Tensor, List[PIL.Image.Image], PIL.Image.Image]] = None,
+    prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+    img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+    prompt_template_encode_start_idx: int = 64,
+    device: Optional[torch.device] = None,
+):
+    prompt = [prompt] if isinstance(prompt, str) else prompt
+    if isinstance(image, list):
+        base_img_prompt = ""
+        for i, img in enumerate(image):
+            base_img_prompt += img_template_encode.format(i + 1)
+    elif image is not None:
+        base_img_prompt = img_template_encode.format(1)
+    else:
+        base_img_prompt = ""
+
+    template = prompt_template_encode
+
+    drop_idx = prompt_template_encode_start_idx
+    txt = [template.format(base_img_prompt + e) for e in prompt]
+
+    model_inputs = processor(
+        text=txt,
+        images=image,
+        padding=True,
+        return_tensors="pt",
+    ).to(device)
+    outputs = text_encoder(
+        input_ids=model_inputs.input_ids,
+        attention_mask=model_inputs.attention_mask,
+        pixel_values=model_inputs.pixel_values,
+        image_grid_thw=model_inputs.image_grid_thw,
+        output_hidden_states=True,
+    )
+
+    hidden_states = outputs.hidden_states[-1]
+    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+    max_seq_len = max([e.size(0) for e in split_hidden_states])
+    prompt_embeds = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+    )
+    encoder_attention_mask = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+    )
+
+    prompt_embeds = prompt_embeds.to(device=device)
+    return prompt_embeds, encoder_attention_mask
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
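The multi-image handling above is plain string templating. This standalone snippet reproduces just that part to show how reference images are numbered in the prompt:

```python
# Reproducing only the templating from get_qwen_prompt_embeds_edit_plus above.
img_template_encode = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"

images = ["img_a", "img_b"]  # stand-ins for two PIL images
base_img_prompt = "".join(img_template_encode.format(i + 1) for i in range(len(images)))
print(base_img_prompt)
# Picture 1: <|vision_start|><|image_pad|><|vision_end|>Picture 2: <|vision_start|><|image_pad|><|vision_end|>
```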
@@ -266,6 +321,83 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
     return components, state
 
 
+class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
+    model_name = "qwenimage"
+
+    def __init__(
+        self,
+        input_name: str = "image",
+        output_name: str = "resized_image",
+        vae_image_output_name: str = "vae_image",
+    ):
+        """Create a configurable step for resizing images to the target condition area (384 * 384) while maintaining the aspect ratio.
+
+        This block resizes an input image or a list of input images and exposes the resized result under
+        configurable input and output names. Use this when you need to wire the resize step to different image
+        fields (e.g., "image", "control_image").
+
+        Args:
+            input_name (str, optional): Name of the image field to read from the
+                pipeline state. Defaults to "image".
+            output_name (str, optional): Name of the resized image field to write
+                back to the pipeline state. Defaults to "resized_image".
+            vae_image_output_name (str, optional): Name of the image field
+                to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage Edit Plus
+                processes the input image(s) differently for the VL and the VAE.
+        """
+        if not isinstance(input_name, str) or not isinstance(output_name, str):
+            raise ValueError(
+                f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
+            )
+        self.condition_image_size = 384 * 384
+        self._image_input_name = input_name
+        self._resized_image_output_name = output_name
+        self._vae_image_output_name = vae_image_output_name
+        super().__init__()
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return super().intermediate_outputs + [
+            OutputParam(
+                name=self._vae_image_output_name,
+                type_hint=List[PIL.Image.Image],
+                description="The images to be processed which will be further used by the VAE encoder.",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        images = getattr(block_state, self._image_input_name)
+
+        if not is_valid_image_imagelist(images):
+            raise ValueError(f"Images must be an image or a list of images but are {type(images)}")
+
+        if (
+            not isinstance(images, torch.Tensor)
+            and isinstance(images, PIL.Image.Image)
+            and not isinstance(images, list)
+        ):
+            images = [images]
+
+        # TODO (sayakpaul): revisit this when the inputs are `torch.Tensor`s
+        condition_images = []
+        vae_images = []
+        for img in images:
+            image_width, image_height = img.size
+            condition_width, condition_height, _ = calculate_dimensions(
+                self.condition_image_size, image_width / image_height
+            )
+            condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width))
+            vae_images.append(img)
+
+        setattr(block_state, self._resized_image_output_name, condition_images)
+        setattr(block_state, self._vae_image_output_name, vae_images)
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageTextEncoderStep(ModularPipelineBlocks):
     model_name = "qwenimage"
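calculate_dimensions (called above) fits each condition image to a 384 * 384 target area at its original aspect ratio. Below is a hypothetical sketch of such a helper; the real diffusers implementation may round differently, and the multiple-of-32 constraint is an assumption.

```python
import math

def calc_dims(target_area: int, aspect_ratio: float, multiple: int = 32):
    # Pick width/height whose product is close to target_area at the given
    # aspect ratio, rounded to a multiple (assumption) for model friendliness.
    width = math.sqrt(target_area * aspect_ratio)
    height = width / aspect_ratio
    width = round(width / multiple) * multiple
    height = round(height / multiple) * multiple
    return int(width), int(height)

# A 1920x1080 source mapped onto the 384 * 384 condition area:
print(calc_dims(384 * 384, 1920 / 1080))  # (512, 288)
```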

@@ -511,6 +643,61 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
     return components, state
 
 
+class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
+    model_name = "qwenimage"
+
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return [
+            ConfigSpec(
+                name="prompt_template_encode",
+                default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+            ),
+            ConfigSpec(
+                name="img_template_encode",
+                default="Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+            ),
+            ConfigSpec(name="prompt_template_encode_start_idx", default=64),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        self.check_inputs(block_state.prompt, block_state.negative_prompt)
+
+        device = components._execution_device
+
+        block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus(
+            components.text_encoder,
+            components.processor,
+            prompt=block_state.prompt,
+            image=block_state.resized_image,
+            prompt_template_encode=components.config.prompt_template_encode,
+            img_template_encode=components.config.img_template_encode,
+            prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+            device=device,
+        )
+
+        if components.requires_unconditional_embeds:
+            negative_prompt = block_state.negative_prompt or " "
+            block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = (
+                get_qwen_prompt_embeds_edit_plus(
+                    components.text_encoder,
+                    components.processor,
+                    prompt=negative_prompt,
+                    image=block_state.resized_image,
+                    prompt_template_encode=components.config.prompt_template_encode,
+                    img_template_encode=components.config.img_template_encode,
+                    prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+                    device=device,
+                )
+            )
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
     model_name = "qwenimage"
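When classifier-free guidance requires unconditional embeds, the step above reuses the same encoder path with the negative prompt, falling back to a single space. A toy, self-contained illustration of that fallback (names are illustrative only):

```python
from typing import Optional

def build_prompt_batch(prompt: str, negative_prompt: Optional[str], needs_uncond: bool):
    # Mirrors the branch above: the unconditional pass falls back to " " when
    # no negative prompt is given, so the encoder always sees non-empty text.
    prompts = [prompt]
    if needs_uncond:
        prompts.append(negative_prompt or " ")
    return prompts

print(build_prompt_batch("make it snowy", None, needs_uncond=True))  # ['make it snowy', ' ']
```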

@@ -612,12 +799,7 @@ def expected_components(self) -> List[ComponentSpec]:
 
     @property
     def inputs(self) -> List[InputParam]:
-        return [
-            InputParam("resized_image"),
-            InputParam("image"),
-            InputParam("height"),
-            InputParam("width"),
-        ]
+        return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]
 
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
@@ -661,6 +843,47 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
     return components, state
 
 
+class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
+    model_name = "qwenimage-edit-plus"
+    vae_image_size = 1024 * 1024
+
+    @property
+    def description(self) -> str:
+        return "Image Preprocess step for QwenImage Edit Plus. Unlike QwenImage Edit, QwenImage Edit Plus doesn't use the same resized image for further preprocessing."
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        if block_state.vae_image is None and block_state.image is None:
+            raise ValueError("`vae_image` and `image` cannot be None at the same time")
+
+        if block_state.vae_image is None:
+            image = block_state.image
+            self.check_inputs(
+                height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+            )
+            height = block_state.height or components.default_height
+            width = block_state.width or components.default_width
+            block_state.processed_image = components.image_processor.preprocess(
+                image=image, height=height, width=width
+            )
+        else:
+            width, height = block_state.vae_image[0].size
+            image = block_state.vae_image
+
+            block_state.processed_image = components.image_processor.preprocess(
+                image=image, height=height, width=width
+            )
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
     model_name = "qwenimage"
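The branch above gives `vae_image` priority: when present, its own resolution defines the VAE preprocessing size; otherwise the requested or default height/width apply. A toy, self-contained mirror of that decision (function name and defaults are illustrative):

```python
from PIL import Image

def pick_vae_target(vae_image, height=None, width=None, default_hw=(1024, 1024)):
    # Mirrors the branch above: if `vae_image` is present, its own size wins;
    # otherwise the requested (or default) height/width are used.
    if vae_image is None:
        return height or default_hw[0], width or default_hw[1]
    w, h = vae_image[0].size  # PIL reports (width, height)
    return h, w

print(pick_vae_target([Image.new("RGB", (1024, 768))]))  # (768, 1024)
```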

@@ -738,7 +961,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
             dtype=dtype,
             latent_channels=components.num_channels_latents,
         )
-
         setattr(block_state, self._image_latents_output_name, image_latents)
 
         self.set_block_state(state, block_state)