
Commit f72763c

add support for qwen edit
1 parent 49e683f commit f72763c

File tree

5 files changed: +418 -104 lines changed


src/diffusers/modular_pipelines/qwenimage/before_denoise.py

Lines changed: 152 additions & 3 deletions
@@ -18,6 +18,11 @@
 import torch
 import inspect
 
+from ...image_processor import VaeImageProcessor
+from ...configuration_utils import FrozenDict
+
+from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions
+
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
 from .modular_pipeline import QwenImageModularPipeline
@@ -110,6 +115,41 @@ def pack_latents(latents, batch_size, num_channels_latents, height, width):
     return latents
 
 
+class QwenImageImageResizeStep(ModularPipelineBlocks):
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return "Image Resize step that resizes the image to the target area while maintaining the aspect ratio"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 16}), default_creation_method="from_config"),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="image", required=True, type_hint=torch.Tensor, description="The image to resize"),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+
+        if not isinstance(block_state.image, list):
+            block_state.image = [block_state.image]
+
+        image_width, image_height = block_state.image[0].size
+        calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height)
+
+        block_state.image = components.image_processor.resize(block_state.image, height=calculated_height, width=calculated_width)
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageInputStep(ModularPipelineBlocks):
 
     model_name = "qwenimage"
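For context, `calculate_dimensions` (imported above from `pipeline_qwenimage_edit`) picks a width/height pair whose product stays close to the target area while preserving the input aspect ratio, so a 3:2 photo is resized toward a ~1024x1024 pixel budget rather than squashed. A minimal sketch of that idea, assuming snapping to multiples of 32; `sketch_calculate_dimensions` is a hypothetical name and the exact rounding in diffusers may differ:

import math

def sketch_calculate_dimensions(target_area: int, ratio: float, multiple: int = 32):
    # solve w * h = target_area subject to w / h = ratio, then snap to `multiple`
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    return round(width / multiple) * multiple, round(height / multiple) * multiple

print(sketch_calculate_dimensions(1024 * 1024, 3 / 2))  # (1248, 832)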
@@ -123,6 +163,7 @@ def description(self) -> str:
             "All input tensors are expected to have either batch_size=1 or match the batch_size\n"
             "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
             "have a final batch_size of batch_size * num_images_per_prompt."
+            " 3. If `image_latents` is provided and `height` and `width` are not provided, it will update the `height` and `width` parameters."
         )
 
     @property
@@ -133,7 +174,9 @@ def inputs(self) -> List[InputParam]:
             InputParam(name="prompt_embeds_mask", required=True, kwargs_type="guider_input_fields"),
             InputParam(name="negative_prompt_embeds", kwargs_type="guider_input_fields"),
             InputParam(name="negative_prompt_embeds_mask", kwargs_type="guider_input_fields"),
-
+            InputParam(name="image_latents"),
+            InputParam(name="height"),
+            InputParam(name="width"),
         ]
 
     @property
@@ -152,7 +195,7 @@ def intermediate_outputs(self) -> List[str]:
         ]
 
     @staticmethod
-    def check_inputs(prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask):
+    def check_inputs(prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask, image_latents):
 
         if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
             raise ValueError("`negative_prompt_embeds_mask` is required when `negative_prompt_embeds` is not None")
@@ -168,6 +211,9 @@ def check_inputs(prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, nega
 
         elif negative_prompt_embeds_mask is not None and negative_prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]:
             raise ValueError("`negative_prompt_embeds_mask` must have the same batch size as `prompt_embeds`")
+
+        if image_latents is not None and image_latents.shape[0] != 1 and image_latents.shape[0] != prompt_embeds.shape[0]:
+            raise ValueError(f"`image_latents` must have batch size 1 or {prompt_embeds.shape[0]}, but got {image_latents.shape[0]}")
 
 
 
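The rule enforced by the new `image_latents` check: reference latents may either be shared across the batch (batch size 1) or supplied per prompt. A quick illustration under assumed shapes (embedding dimensions are illustrative):

import torch

prompt_embeds = torch.randn(4, 77, 3584)  # batch size 4
for b in (1, 4, 2):
    image_latents = torch.randn(b, 16, 1, 64, 64)
    ok = image_latents.shape[0] in (1, prompt_embeds.shape[0])
    print(b, "accepted" if ok else "raises ValueError")  # 1 and 4 pass, 2 raises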
@@ -180,6 +226,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
             prompt_embeds_mask=block_state.prompt_embeds_mask,
             negative_prompt_embeds=block_state.negative_prompt_embeds,
             negative_prompt_embeds_mask=block_state.negative_prompt_embeds_mask,
+            image_latents=block_state.image_latents,
         )
 
         block_state.batch_size = block_state.prompt_embeds.shape[0]
@@ -204,7 +251,20 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
             block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.repeat(1, block_state.num_images_per_prompt, 1)
             block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.view(
                 block_state.batch_size * block_state.num_images_per_prompt, seq_len)
-
+
+        if block_state.image_latents is not None:
+            final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
+            block_state.image_latents = block_state.image_latents.repeat(
+                final_batch_size // block_state.image_latents.shape[0], 1, 1, 1, 1
+            )
+
+            height_image_latent, width_image_latent = block_state.image_latents.shape[3:]
+
+            if block_state.height is None:
+                block_state.height = height_image_latent * components.vae_scale_factor
+            if block_state.width is None:
+                block_state.width = width_image_latent * components.vae_scale_factor
+
         self.set_block_state(state, block_state)
 
         return components, state
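The batch expansion above relies on the earlier check: since `image_latents` arrives with batch size 1 or `batch_size`, integer division yields a whole repeat count. A standalone sketch of the same arithmetic, using the (B, C, T, H, W) latent layout implied by the five `repeat` arguments (shapes illustrative):

import torch

batch_size, num_images_per_prompt = 2, 2
image_latents = torch.randn(1, 16, 1, 64, 64)  # one shared reference latent

final_batch_size = batch_size * num_images_per_prompt
# repeat tiles the whole batch: 1 latent -> 4 copies (a batch of 4 would repeat once, i.e. no-op)
image_latents = image_latents.repeat(final_batch_size // image_latents.shape[0], 1, 1, 1, 1)
assert image_latents.shape[0] == final_batch_size

# when height/width are not given, fall back to the latent grid times the VAE scale factor
vae_scale_factor = 16  # assumed value, consistent with the ComponentSpec above
height = image_latents.shape[3] * vae_scale_factor  # 1024
width = image_latents.shape[4] * vae_scale_factor   # 1024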
@@ -312,6 +372,43 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
         return components, state
 
 
+class QwenImagePrepareImageLatentsStep(ModularPipelineBlocks):
+
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return "Prepare image latents step that packs the reference image latents for the denoising process"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image, can be generated in vae encoder step"),
+        ]
+
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+
+        block_state = self.get_block_state(state)
+
+        height_image_latent, width_image_latent = block_state.image_latents.shape[3:]
+
+        block_state.image_latents = pack_latents(
+            latents=block_state.image_latents,
+            batch_size=block_state.image_latents.shape[0],
+            num_channels_latents=components.num_channels_latents,
+            height=height_image_latent,
+            width=width_image_latent,
+        )
+
+
+        self.set_block_state(state, block_state)
+
+        return components, state
+
+
+
 
 class QwenImageSetTimestepsStep(ModularPipelineBlocks):
 
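`pack_latents` (defined earlier in this file) flattens each 2x2 patch of the latent grid into a single token, the sequence layout the transformer consumes. A re-implementation sketch for illustration, matching the packing used by the standard QwenImage pipelines and assuming the (B, C, T=1, H, W) layout used above:

import torch

def sketch_pack_latents(latents, batch_size, num_channels_latents, height, width):
    # (B, C, 1, H, W) -> (B, C, H/2, 2, W/2, 2): carve the grid into 2x2 patches
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    # move patch coordinates ahead of channels, then flatten each patch into one token
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

packed = sketch_pack_latents(torch.randn(1, 16, 1, 64, 64), 1, 16, 64, 64)
print(packed.shape)  # torch.Size([1, 1024, 64]): 32*32 tokens, each of dim 16*4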
@@ -410,6 +507,58 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
         )
 
 
+        self.set_block_state(state, block_state)
+
+        return components, state
+
+
+class QwenImageEditPrepareAdditionalInputsStep(ModularPipelineBlocks):
+
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the additional inputs for the image editing generation process"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="batch_size", required=True),
+            InputParam(name="image", required=True, type_hint=torch.Tensor, description="The resized image input"),
+            InputParam(name="height", required=True),
+            InputParam(name="width", required=True),
+            InputParam(name="prompt_embeds_mask"),
+            InputParam(name="negative_prompt_embeds_mask"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(name="img_shapes", type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the image latents, used for RoPE calculation"),
+            OutputParam(name="txt_seq_lens", kwargs_type="guider_input_fields", type_hint=List[int], description="The sequence lengths of the prompt embeds, used for RoPE calculation"),
+            OutputParam(name="negative_txt_seq_lens", kwargs_type="guider_input_fields", type_hint=List[int], description="The sequence lengths of the negative prompt embeds, used for RoPE calculation"),
+        ]
+
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+
+        block_state = self.get_block_state(state)
+
+        image = block_state.image[0] if isinstance(block_state.image, list) else block_state.image
+        image_width, image_height = image.size
+
+        block_state.img_shapes = [
+            [
+                (1, block_state.height // components.vae_scale_factor // 2, block_state.width // components.vae_scale_factor // 2),
+                (1, image_height // components.vae_scale_factor // 2, image_width // components.vae_scale_factor // 2),
+            ]
+        ] * block_state.batch_size
+
+        block_state.txt_seq_lens = block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
+        block_state.negative_txt_seq_lens = (
+            block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() if block_state.negative_prompt_embeds_mask is not None else None
+        )
+
+
         self.set_block_state(state, block_state)
 
         return components, state
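Each batch entry thus carries two (frame, height, width) triples: one for the latents being denoised and one for the reference-image latents, so RoPE positions cover the concatenated token sequence used in the denoise loop below. A worked example under assumed sizes:

vae_scale_factor = 16  # assumed
height, width = 1024, 1024             # target generation size
image_height, image_width = 832, 1248  # resized reference image
batch_size = 2

img_shapes = [
    [
        (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2),
        (1, image_height // vae_scale_factor // 2, image_width // vae_scale_factor // 2),
    ]
] * batch_size
print(img_shapes[0])  # [(1, 32, 32), (1, 26, 39)]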

src/diffusers/modular_pipelines/qwenimage/denoise.py

Lines changed: 125 additions & 2 deletions
@@ -52,6 +52,34 @@ def inputs(self) -> List[InputParam]:
     def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
         # one timestep
         block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
+        block_state.latent_model_input = block_state.latents
+        return components, block_state
+
+
+class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step within the denoising loop that prepares the latent input for the denoiser. "
+            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+            "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+        )
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."),
+            InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step."),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+        # one timestep
+
+        block_state.latent_model_input = torch.cat([block_state.latents, block_state.image_latents], dim=1)
+        block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
         return components, block_state
 
 
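The edit variant differs from the text-to-image loop only in its model input: the packed reference-image tokens are appended to the noise latents along the sequence dimension (dim=1), the transformer attends over both, and the denoiser step later slices the prediction back to the noise-token range. A shape-level sketch with illustrative sizes:

import torch

latents = torch.randn(2, 1024, 64)        # (batch, noise tokens, packed channels)
image_latents = torch.randn(2, 1014, 64)  # (batch, reference tokens, packed channels)

# concatenate along the token axis, not the channel axis
latent_model_input = torch.cat([latents, image_latents], dim=1)
print(latent_model_input.shape)  # torch.Size([2, 2038, 64])

# downstream, only the first latents.size(1) tokens are kept as the prediction
noise_pred = latent_model_input[:, : latents.size(1)]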
@@ -107,7 +135,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
 
             # YiYi TODO: add cache context
             guider_state_batch.noise_pred = components.transformer(
-                hidden_states=block_state.latents,
+                hidden_states=block_state.latent_model_input,
                 timestep=block_state.timestep / 1000,
                 img_shapes=block_state.img_shapes,
                 attention_kwargs=block_state.attention_kwargs,
@@ -128,7 +156,80 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState
         return components, block_state
 
 
-
+class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
+    model_name = "qwenimage"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step within the denoising loop that denoises the latent input with the denoiser. "
+            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+            "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+        )
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec(
+                "guider",
+                ClassifierFreeGuidance,
+                config=FrozenDict({"guidance_scale": 4.0}),
+                default_creation_method="from_config",
+            ),
+            ComponentSpec("transformer", QwenImageTransformer2DModel),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("attention_kwargs"),
+            InputParam("latents", required=True, type_hint=torch.Tensor, description="The latents to use for the denoising process. Can be generated in prepare_latents step."),
+            InputParam("num_inference_steps", required=True, type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."),
+            InputParam(kwargs_type="guider_input_fields", description="All conditional model inputs that need to be prepared with guider: e.g. prompt_embeds, negative_prompt_embeds, etc."),
+            InputParam("img_shapes", required=True, type_hint=List[Tuple[int, int]], description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step."),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+
+        guider_input_fields = {
+            "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
+            "encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
+            "txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
+        }
+
+        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+        guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+
+        for guider_state_batch in guider_state:
+            components.guider.prepare_models(components.transformer)
+            cond_kwargs = guider_state_batch.as_dict()
+            cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+
+            # YiYi TODO: add cache context
+            guider_state_batch.noise_pred = components.transformer(
+                hidden_states=block_state.latent_model_input,
+                timestep=block_state.timestep / 1000,
+                img_shapes=block_state.img_shapes,
+                attention_kwargs=block_state.attention_kwargs,
+                return_dict=False,
+                **cond_kwargs,
+            )[0]
+
+            components.guider.cleanup_models(components.transformer)
+
+        guider_output = components.guider(guider_state)
+
+        pred = guider_output.pred[:, : block_state.latents.size(1)]
+        pred_cond = guider_output.pred_cond[:, : block_state.latents.size(1)]
+
+        # apply guidance rescale
+        pred_cond_norm = torch.norm(pred_cond, dim=-1, keepdim=True)
+        pred_norm = torch.norm(pred, dim=-1, keepdim=True)
+        block_state.noise_pred = pred * (pred_cond_norm / pred_norm)
+
+
+        return components, block_state
 
 
 class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
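The norm-based rescale at the end of the loop keeps the direction of the guided prediction but restores the per-token magnitude of the conditional branch, a common remedy for CFG over-saturation: noise_pred = pred * (||pred_cond|| / ||pred||), with norms taken over the channel dimension. A standalone sketch:

import torch

def rescale_guided_pred(pred: torch.Tensor, pred_cond: torch.Tensor) -> torch.Tensor:
    # match each token's magnitude to that of the conditional branch
    pred_cond_norm = torch.norm(pred_cond, dim=-1, keepdim=True)
    pred_norm = torch.norm(pred, dim=-1, keepdim=True)
    return pred * (pred_cond_norm / pred_norm)

guided = torch.randn(2, 1024, 64) * 3.0  # CFG output with inflated norms
cond = torch.randn(2, 1024, 64)          # conditional-only output
rescaled = rescale_guided_pred(guided, cond)
print(torch.allclose(rescaled.norm(dim=-1), cond.norm(dim=-1), atol=1e-4))  # True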
@@ -237,4 +338,26 @@ def description(self) -> str:
             " - `QwenImageLoopDenoiser`\n"
             " - `QwenImageLoopAfterDenoiser`\n"
             "This block supports text2img tasks."
+        )
+
+
+# composing the denoising loops
+class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
+    block_classes = [
+        QwenImageEditLoopBeforeDenoiser,
+        QwenImageEditLoopDenoiser,
+        QwenImageLoopAfterDenoiser,
+    ]
+    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoises the latents. \n"
+            "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
+            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+            " - `QwenImageEditLoopBeforeDenoiser`\n"
+            " - `QwenImageEditLoopDenoiser`\n"
+            " - `QwenImageLoopAfterDenoiser`\n"
+            "This block supports text2img and img2img tasks."
         )
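For orientation, `QwenImageDenoiseLoopWrapper` owns the timestep loop and calls the composed sub-blocks in order at every step; `QwenImageEditDenoiseStep` only swaps in the edit-aware before/denoise blocks. A hedged sketch of the loop shape (illustrative only; the real logic lives in `QwenImageDenoiseLoopWrapper.__call__`):

# illustrative pseudo-loop, not the actual diffusers implementation
for i, t in enumerate(timesteps):
    # runs QwenImageEditLoopBeforeDenoiser, QwenImageEditLoopDenoiser,
    # then QwenImageLoopAfterDenoiser, threading block_state through
    for block in sub_blocks:
        components, block_state = block(components, block_state, i, t)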
