|
21 | 21 | from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel |
22 | 22 | from ...schedulers import FlowMatchEulerDiscreteScheduler |
23 | 23 | from ...utils.torch_utils import randn_tensor, unwrap_module |
24 | | -from ..modular_pipeline import ModularPipelineBlocks, PipelineState |
| 24 | +from ..modular_pipeline import ModularPipelineBlocks, PipelineState, SequentialPipelineBlocks |
25 | 25 | from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam |
26 | 26 | from .modular_pipeline import QwenImageModularPipeline |
27 | 27 |
|
@@ -121,6 +121,93 @@ def pack_latents(latents, batch_size, num_channels_latents, height, width): |
121 | 121 | return latents |
122 | 122 |
|
123 | 123 |
|
| 124 | +# Prepare Latents steps |
| 125 | + |
| 126 | + |
class QwenImagePackLatentsDynamicStep(ModularPipelineBlocks):
    """Patchify named latent tensors and expand them to the final batch size.

    For every configured input name the step reads the tensor from the block
    state, validates its batch size and rank, inserts a singleton axis at
    position 2 when the tensor is 4-D, repeats it along the batch axis up to
    ``batch_size * num_images_per_prompt``, packs it with ``pack_latents`` and
    writes the result back to the block state under the same name. Inputs
    whose value is ``None`` are left untouched.
    """

    model_name = "qwenimage"

    def __init__(self, input_names: List[str] = ["image_latents"]):
        """Initialize a dynamic latents packing step.

        Args:
            input_names (List[str], optional): Names of latent tensors to patchify and expand.
                Can be a single string or a list/tuple of strings. Defaults to ["image_latents"].
                Examples: ["image_latents"], ["control_image_latents"]
        """
        if isinstance(input_names, (list, tuple)):
            # Copy so that neither the shared default list nor a caller-owned
            # sequence is aliased on the instance.
            names = list(input_names)
        else:
            names = [input_names]
        # Must be set before super().__init__(): the `inputs` property reads it.
        self._latents_input_names = names
        super().__init__()

    @property
    def description(self) -> str:
        return "Step that patchifies latents and expands batch dimension. Works with outputs from QwenImageVaeEncoderDynamicStep."

    @property
    def inputs(self) -> List[InputParam]:
        """Fixed batch-related inputs followed by one optional input per configured latent name."""
        return [
            InputParam(name="num_images_per_prompt", default=1),
            InputParam(
                name="batch_size",
                required=True,
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
            ),
            *(InputParam(name=input_name) for input_name in self._latents_input_names),
        ]

    @staticmethod
    def check_input_shape(latents_input, latents_input_name, batch_size):
        """Validate the batch size and rank of a latent tensor.

        Args:
            latents_input: Tensor to validate, or ``None`` (nothing is checked).
            latents_input_name: Name used in error messages.
            batch_size: Expected batch size; a batch size of 1 is also accepted.

        Raises:
            ValueError: If the leading dimension is neither 1 nor ``batch_size``,
                or if the tensor does not have 4 or 5 dimensions.
        """
        if latents_input is None:
            # The inputs are declared optional; without this guard the rank
            # check below would fail with AttributeError on None.
            return

        if latents_input.shape[0] not in (1, batch_size):
            raise ValueError(
                f"`{latents_input_name}` must have batch size 1 or {batch_size}, but got {latents_input.shape[0]}"
            )

        if latents_input.ndim not in (4, 5):
            raise ValueError(f"`{latents_input_name}` must have 4 or 5 dimensions, but got {latents_input.ndim}")

    @torch.no_grad()
    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Pack every configured latent input in place on the block state.

        Returns the ``(components, state)`` pair after writing the updated
        block state back into ``state``.
        """
        block_state = self.get_block_state(state)

        final_batch_size = block_state.batch_size * block_state.num_images_per_prompt

        for input_name in self._latents_input_names:
            latents_input = getattr(block_state, input_name)

            if latents_input is None:
                # Optional input that was not provided: nothing to pack.
                continue

            self.check_input_shape(
                latents_input=latents_input,
                latents_input_name=input_name,
                batch_size=block_state.batch_size,
            )

            # Promote 4-D input to 5-D by inserting a singleton axis at position 2.
            if latents_input.ndim == 4:
                latents_input = latents_input.unsqueeze(2)

            # The validated leading dimension is 1 or batch_size, so this
            # repeat factor brings it up to final_batch_size.
            latents_input = latents_input.repeat(final_batch_size // latents_input.shape[0], 1, 1, 1, 1)

            height_latents, width_latents = latents_input.shape[3:]

            latents_input = pack_latents(
                latents=latents_input,
                batch_size=latents_input.shape[0],
                num_channels_latents=components.num_channels_latents,
                height=height_latents,
                width=width_latents,
            )

            setattr(block_state, input_name, latents_input)

        self.set_block_state(state, block_state)

        return components, state
| 209 | + |
| 210 | + |
124 | 211 | class QwenImagePrepareLatentsStep(ModularPipelineBlocks): |
125 | 212 | model_name = "qwenimage" |
126 | 213 |
|
@@ -233,7 +320,7 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks): |
233 | 320 |
|
@property
def description(self) -> str:
    """Human-readable summary of this step, including where it belongs in the pipeline."""
    return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should already be patchified."
237 | 324 |
|
238 | 325 | @property |
239 | 326 | def expected_components(self) -> List[ComponentSpec]: |
@@ -326,7 +413,7 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks): |
326 | 413 |
|
@property
def description(self) -> str:
    """Human-readable summary of what this step produces."""
    summary = (
        "Step that creates mask latents from preprocessed mask_image by interpolating to latent space. "
        "Output is not patchified."
    )
    return summary
330 | 417 |
|
331 | 418 | @property |
332 | 419 | def inputs(self) -> List[InputParam]: |
@@ -376,80 +463,57 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - |
376 | 463 | return components, state |
377 | 464 |
|
378 | 465 |
|
class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
    """Prepare the latents/image_latents and mask inputs for the inpainting denoising step.

    The sub-blocks run in this order:
      - Patchify the image latents.
      - Add noise to the image latents to create the `latents` input for the denoiser.
      - Create the latents `mask` based on the processed `mask_image`.
      - Patchify the `mask` to match the shape of the image latents.

      Components:
          scheduler (`FlowMatchEulerDiscreteScheduler`) [subfolder=]

      Inputs:
          height (`None`, optional):
          width (`None`, optional):
          num_images_per_prompt (`None`, optional, defaults to 1):
          batch_size (`int`):
              Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
              be generated in input step.
          image_latents (`None`, optional):
          latents (`Tensor`):
              The initial random noise, can be generated in prepare latent step.
          timesteps (`Tensor`):
              The timesteps to use for the denoising process. Can be generated in set_timesteps step.
          mask_image (`Tensor`):
              The mask image to use for the inpainting process.

      Outputs:
          init_noise (`Tensor`):
              The initial random noise used for inpainting denoising.
          mask (`Tensor`):
              The mask latents to use for the inpainting process.
    """

    model_name = "qwenimage"

    # Entries pair positionally with `block_names` below. The pack step is
    # listed as pre-configured instances because the tensor it operates on
    # differs between the two uses.
    block_classes = [
        QwenImagePackLatentsDynamicStep("image_latents"),
        QwenImagePrepareLatentsWithStrengthStep,
        QwenImageCreateMaskLatentsStep,
        QwenImagePackLatentsDynamicStep("mask"),
    ]

    block_names = ["pack_image_latents", "add_noise_to_latents", "create_mask_latents", "pack_mask"]

    @property
    def description(self) -> str:
        return (
            "This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
            " - Patchify the image latents.\n"
            " - Add noise to the image latents to create the latents input for the denoiser.\n"
            " - Create the latents `mask` based on the processed mask image.\n"
            " - Patchify the mask latents to match the shape of the image latents."
        )
| 516 | +# Set Timesteps steps |
453 | 517 |
|
454 | 518 |
|
455 | 519 | class QwenImageSetTimestepsStep(ModularPipelineBlocks): |
@@ -591,6 +655,11 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - |
591 | 655 | return components, state |
592 | 656 |
|
593 | 657 |
|
| 658 | +# other inputs for denoiser |
| 659 | + |
| 660 | +## RoPE inputs for denoiser |
| 661 | + |
| 662 | + |
594 | 663 | class QwenImageRoPEInputsStep(ModularPipelineBlocks): |
595 | 664 | model_name = "qwenimage" |
596 | 665 |
|
@@ -728,7 +797,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - |
728 | 797 | return components, state |
729 | 798 |
|
730 | 799 |
|
731 | | -class QwenImageControlNetPrepareInputsStep(ModularPipelineBlocks): |
| 800 | +## ControlNet inputs for denoiser |
| 801 | +class QwenImageControlNetAdditionalInputsStep(ModularPipelineBlocks): |
732 | 802 | model_name = "qwenimage" |
733 | 803 |
|
734 | 804 | @property |
@@ -809,3 +879,18 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - |
809 | 879 | self.set_block_state(state, block_state) |
810 | 880 |
|
811 | 881 | return components, state |
| 882 | + |
| 883 | + |
class QwenImageControlNetInputsStep(SequentialPipelineBlocks):
    """Sequential block that prepares the ControlNet inputs for the denoiser.

    Sub-blocks, in execution order:
      - ``prepare_control_image_latent``: packs ``control_image_latents``.
      - ``prepare_controlnet_inputs``: prepares the remaining controlnet inputs.
    """

    model_name = "qwenimage"

    # Names pair positionally with the entries of `block_classes`.
    block_names = ["prepare_control_image_latent", "prepare_controlnet_inputs"]

    block_classes = [
        # prepare control image latents
        QwenImagePackLatentsDynamicStep("control_image_latents"),
        # prepare the controlnet inputs e.g. controlnet_keep, controlnet_conditioning_scale, etc.
        QwenImageControlNetAdditionalInputsStep,
    ]

    @property
    def description(self) -> str:
        return (
            "Step that prepares the controlnet inputs. "
            "Insert before the Denoise Step, after set_timesteps step."
        )
0 commit comments