
Commit d91aa18
Commit message: up
1 parent cc5b31f

5 files changed, +462 −4 lines changed

src/diffusers/modular_pipelines/qwenimage/__init__.py

Lines changed: 11 additions & 1 deletion
@@ -29,11 +29,14 @@
         "EDIT_AUTO_BLOCKS",
         "EDIT_BLOCKS",
         "EDIT_INPAINT_BLOCKS",
+        "EDIT_PLUS_AUTO_BLOCKS",
+        "EDIT_PLUS_BLOCKS",
         "IMAGE2IMAGE_BLOCKS",
         "INPAINT_BLOCKS",
         "TEXT2IMAGE_BLOCKS",
         "QwenImageAutoBlocks",
         "QwenImageEditAutoBlocks",
+        "QwenImageEditPlusAutoBlocks",
     ]
     _import_structure["modular_pipeline"] = ["QwenImageEditModularPipeline", "QwenImageModularPipeline"]
@@ -54,13 +57,20 @@
             EDIT_AUTO_BLOCKS,
             EDIT_BLOCKS,
             EDIT_INPAINT_BLOCKS,
+            EDIT_PLUS_AUTO_BLOCKS,
+            EDIT_PLUS_BLOCKS,
             IMAGE2IMAGE_BLOCKS,
             INPAINT_BLOCKS,
             TEXT2IMAGE_BLOCKS,
             QwenImageAutoBlocks,
             QwenImageEditAutoBlocks,
+            QwenImageEditPlusAutoBlocks,
+        )
+        from .modular_pipeline import (
+            QwenImageEditModularPipeline,
+            QwenImageEditPlusModularPipeline,
+            QwenImageModularPipeline,
         )
-        from .modular_pipeline import QwenImageEditModularPipeline, QwenImageModularPipeline
 else:
     import sys
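Editor's note: these exports follow the lazy-import pattern already used for the Edit blocks. A minimal sketch of what the new names enable, not taken from the commit — it assumes QwenImageEditPlusAutoBlocks, like QwenImageEditAutoBlocks, takes no required constructor arguments, and that EDIT_PLUS_AUTO_BLOCKS is the dict-like preset used elsewhere in modular_blocks:

# Hedged smoke test for the new Edit Plus exports (assumptions noted above).
from diffusers.modular_pipelines.qwenimage import (
    EDIT_PLUS_AUTO_BLOCKS,
    QwenImageEditPlusAutoBlocks,
)

blocks = QwenImageEditPlusAutoBlocks()
print(blocks)  # human-readable repr of the auto-assembled Edit Plus block graph
print(list(EDIT_PLUS_AUTO_BLOCKS.keys()))  # named sub-blocks in the preset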

src/diffusers/modular_pipelines/qwenimage/before_denoise.py

Lines changed: 45 additions & 1 deletion
@@ -571,7 +571,7 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):

     @property
     def description(self) -> str:
-        return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be place after prepare_latents step"
+        return "Step that prepares the RoPE inputs for the denoising process. This is used in QwenImage Edit. Should be placed after the prepare_latents step"

     @property
     def inputs(self) -> List[InputParam]:
@@ -641,6 +641,50 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
         return components, state


+class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
+    model_name = "qwenimage"
+    # TODO: Is there a better way to handle this name? It's used in
+    # `QwenImageEditPlusResizeDynamicStep` as well. We could later
+    # make these module-level constants.
+    _image_size_output_name = "image_sizes"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        inputs_list = super().inputs
+        return inputs_list + [
+            InputParam(name=self._image_size_output_name, required=True),
+        ]
+
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        vae_image_sizes = getattr(block_state, self._image_size_output_name)
+        height, width = block_state.image_height, block_state.image_width
+
+        # for edit, the image size can differ from the target size (height/width)
+        block_state.img_shapes = [
+            [
+                (1, height // components.vae_scale_factor // 2, width // components.vae_scale_factor // 2),
+                *[
+                    (1, vae_height // components.vae_scale_factor // 2, vae_width // components.vae_scale_factor // 2)
+                    for vae_width, vae_height in vae_image_sizes
+                ],
+            ]
+        ] * block_state.batch_size
+
+        block_state.txt_seq_lens = (
+            block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
+        )
+        block_state.negative_txt_seq_lens = (
+            block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
+            if block_state.negative_prompt_embeds_mask is not None
+            else None
+        )
+
+        self.set_block_state(state, block_state)
+
+        return components, state
+
+
 ## ControlNet inputs for denoiser
 class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
     model_name = "qwenimage"
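Editor's note: the `img_shapes` comprehension above packs one shape tuple for the target latents plus one per VAE-encoded reference image. A self-contained sketch with illustrative numbers — `vae_scale_factor = 8` and the sizes below are assumptions, not stated in this diff:

# Standalone sketch of QwenImageEditPlusRoPEInputsStep's img_shapes math.
vae_scale_factor = 8
batch_size = 2

height, width = 1024, 1024                     # target generation size
vae_image_sizes = [(1024, 1024), (768, 1360)]  # (vae_width, vae_height) per reference image

img_shapes = [
    [
        (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2),
        *[
            (1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
            for vae_width, vae_height in vae_image_sizes
        ],
    ]
] * batch_size

# Each tuple is (frames, latent_height // 2, latent_width // 2), i.e. the
# 2x2-packed latent grid the transformer's RoPE indexes over.
print(img_shapes[0])  # [(1, 64, 64), (1, 64, 64), (1, 85, 48)]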

src/diffusers/modular_pipelines/qwenimage/encoders.py

Lines changed: 222 additions & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 import PIL
 import torch
@@ -128,6 +128,63 @@ def get_qwen_prompt_embeds_edit(
     return prompt_embeds, encoder_attention_mask


+def get_qwen_prompt_embeds_edit_plus(
+    text_encoder,
+    processor,
+    prompt: Union[str, List[str]] = None,
+    image: Optional[Union[torch.Tensor, List[PIL.Image.Image], PIL.Image.Image]] = None,
+    prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+    img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+    prompt_template_encode_start_idx: int = 64,
+    device: Optional[torch.device] = None,
+):
+    prompt = [prompt] if isinstance(prompt, str) else prompt
+    if isinstance(image, list):
+        base_img_prompt = ""
+        for i, img in enumerate(image):
+            base_img_prompt += img_template_encode.format(i + 1)
+    elif image is not None:
+        base_img_prompt = img_template_encode.format(1)
+    else:
+        base_img_prompt = ""
+
+    template = prompt_template_encode
+
+    drop_idx = prompt_template_encode_start_idx
+    txt = [template.format(base_img_prompt + e) for e in prompt]
+
+    model_inputs = processor(
+        text=txt,
+        images=image,
+        padding=True,
+        return_tensors="pt",
+    ).to(device)
+
+    outputs = text_encoder(
+        input_ids=model_inputs.input_ids,
+        attention_mask=model_inputs.attention_mask,
+        pixel_values=model_inputs.pixel_values,
+        image_grid_thw=model_inputs.image_grid_thw,
+        output_hidden_states=True,
+    )
+
+    hidden_states = outputs.hidden_states[-1]
+    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+    max_seq_len = max([e.size(0) for e in split_hidden_states])
+    prompt_embeds = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+    )
+    encoder_attention_mask = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+    )
+
+    prompt_embeds = prompt_embeds.to(device=device)
+
+    return prompt_embeds, encoder_attention_mask
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
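Editor's note: the string assembly at the top of `get_qwen_prompt_embeds_edit_plus` is easiest to see without the model in the loop. A plain-strings sketch — the image template is copied from the default above; the long system prompt is elided, and the example instruction is invented:

# Sketch of the text that get_qwen_prompt_embeds_edit_plus hands to the processor.
img_template_encode = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
prompt_template_encode = "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"

num_images = 2  # e.g. two reference images passed to Edit Plus
base_img_prompt = "".join(img_template_encode.format(i + 1) for i in range(num_images))

prompt = "Put the cat from Picture 1 onto the sofa in Picture 2."
txt = prompt_template_encode.format(base_img_prompt + prompt)
print(txt)
# The user turn reads:
#   Picture 1: <|vision_start|><|image_pad|><|vision_end|>Picture 2: ...Put the cat ...
# After encoding, the first prompt_template_encode_start_idx (64) hidden states --
# the templated system prefix -- are dropped before padding and stacking.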
@@ -266,6 +323,102 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
     return components, state


+class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
+    model_name = "qwenimage"
+
+    def __init__(
+        self,
+        input_name: str = "image",
+        output_name: str = "resized_image",
+        vae_image_output_name: str = "resize_vae_image",
+    ):
+        """Create a configurable step for resizing images to the target condition area (384 * 384) and VAE area
+        (1024 * 1024) while maintaining the aspect ratio.
+
+        This block resizes an input image or a list of input images and exposes the resized results under
+        configurable input and output names. Use this when you need to wire the resize step to different image
+        fields (e.g., "image", "control_image").
+
+        Args:
+            input_name (str, optional): Name of the image field to read from the
+                pipeline state. Defaults to "image".
+            output_name (str, optional): Name of the resized image field to write
+                back to the pipeline state. Defaults to "resized_image".
+            vae_image_output_name (str, optional): Name of the resized image field
+                to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage
+                Edit Plus resizes the input image(s) differently for the VL encoder and the VAE.
+        """
+        if not isinstance(input_name, str) or not isinstance(output_name, str):
+            raise ValueError(
+                f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
+            )
+        self.condition_image_size = 384 * 384
+        self.vae_image_size = 1024 * 1024
+        self._image_input_name = input_name
+        self._resized_image_output_name = output_name
+        self._resized_image_vae_output_name = vae_image_output_name
+        self._image_size_output_name = "image_sizes"
+        super().__init__()
+
+    @property
+    def description(self) -> str:
+        return f"Image Resize step that resizes the {self._image_input_name} to the target areas of {self.condition_image_size} and {self.vae_image_size} while maintaining the aspect ratio."
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
+            ),
+            OutputParam(
+                name=self._resized_image_vae_output_name,
+                type_hint=List[PIL.Image.Image],
+                description="The resized images to be used by the VAE encoder.",
+            ),
+            OutputParam(
+                name=self._image_size_output_name,
+                type_hint=List[Tuple[int, int]],
+                description="Sizes of images fed to the VAE encoder. To be used with RoPE.",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        images = getattr(block_state, self._image_input_name)
+
+        if not is_valid_image_imagelist(images):
+            raise ValueError(f"Images must be image or list of images but are {type(images)}")
+
+        if (
+            not isinstance(images, torch.Tensor)
+            and isinstance(images, PIL.Image.Image)
+            and not isinstance(images, list)
+        ):
+            images = [images]
+
+        # TODO: revisit this when the inputs are `torch.Tensor`s
+        image_width, image_height = images[-1].size
+        condition_images = []
+        vae_image_sizes = []
+        vae_images = []
+        for img in images:
+            image_width, image_height = img.size
+            condition_width, condition_height, _ = calculate_dimensions(
+                self.condition_image_size, image_width / image_height
+            )
+            vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, image_width / image_height)
+            vae_image_sizes.append((vae_width, vae_height))
+            condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
+            vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+
+        setattr(block_state, self._resized_image_output_name, condition_images)
+        setattr(block_state, self._resized_image_vae_output_name, vae_images)
+        setattr(block_state, self._image_size_output_name, vae_image_sizes)
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageTextEncoderStep(ModularPipelineBlocks):
     model_name = "qwenimage"
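Editor's note: both resizes are driven by a target pixel area rather than fixed side lengths. `calculate_dimensions` itself is not shown in this diff, so here is a hypothetical stand-in to illustrate the idea — the snapping to multiples of 32 and the third return slot are assumptions, not confirmed by the commit:

import math

def calculate_dimensions_sketch(target_area: int, ratio: float, multiple: int = 32):
    # Pick a width/height whose product is roughly target_area at the given
    # aspect ratio, snapped to a multiple the VAE/transformer can consume.
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    width = round(width / multiple) * multiple
    height = round(height / multiple) * multiple
    return width, height, None  # third slot mirrors the `_` unpacked in the step

# A 1536x1024 (3:2) input: small resize for the VL condition, large one for the VAE.
for area in (384 * 384, 1024 * 1024):
    print(calculate_dimensions_sketch(area, 1536 / 1024))
# (480, 320, None) and (1248, 832, None)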

@@ -511,6 +664,74 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
     return components, state


+class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation.\n"
+
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return [
+            ConfigSpec(
+                name="prompt_template_encode",
+                default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+            ),
+            ConfigSpec(
+                name="img_template_encode",
+                default="Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+            ),
+            ConfigSpec(name="prompt_template_encode_start_idx", default=64),
+        ]
+
+    @staticmethod
+    def check_inputs(prompt, negative_prompt):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if (
+            negative_prompt is not None
+            and not isinstance(negative_prompt, str)
+            and not isinstance(negative_prompt, list)
+        ):
+            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        self.check_inputs(block_state.prompt, block_state.negative_prompt)
+
+        device = components._execution_device
+
+        block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus(
+            components.text_encoder,
+            components.processor,
+            prompt=block_state.prompt,
+            image=block_state.resized_image,
+            prompt_template_encode=components.config.prompt_template_encode,
+            img_template_encode=components.config.img_template_encode,
+            prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+            device=device,
+        )
+
+        if components.requires_unconditional_embeds:
+            negative_prompt = block_state.negative_prompt or " "
+            block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
+                components.text_encoder,
+                components.processor,
+                prompt=negative_prompt,
+                image=block_state.resized_image,
+                prompt_template_encode=components.config.prompt_template_encode,
+                prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+                device=device,
+            )
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
     model_name = "qwenimage"
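Editor's note: the `prompt_embeds_mask` this step stores is what the earlier `QwenImageEditPlusRoPEInputsStep` reduces into per-sample text lengths. A tiny sketch tying the two together, with an invented toy mask:

import torch

# prompt_embeds_mask is (batch, max_seq_len): 1 for real tokens, 0 for padding,
# exactly the stacked attn_mask_list built in get_qwen_prompt_embeds_edit_plus.
prompt_embeds_mask = torch.tensor([
    [1, 1, 1, 1, 1, 0, 0],  # 5 real tokens
    [1, 1, 1, 0, 0, 0, 0],  # 3 real tokens
])

txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
print(txt_seq_lens)  # [5, 3] -- per-sample lengths fed to the transformer's RoPE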
