|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -from typing import Dict, List, Optional, Union |
| 15 | +from typing import Dict, List, Optional, Tuple, Union |
16 | 16 |
|
17 | 17 | import PIL |
18 | 18 | import torch |
@@ -128,6 +128,63 @@ def get_qwen_prompt_embeds_edit( |
128 | 128 | return prompt_embeds, encoder_attention_mask |
129 | 129 |
|
130 | 130 |
|
| 131 | +def get_qwen_prompt_embeds_edit_plus( |
| 132 | + text_encoder, |
| 133 | + processor, |
| 134 | + prompt: Union[str, List[str]] = None, |
| 135 | + image: Optional[Union[torch.Tensor, List[PIL.Image.Image], PIL.Image.Image]] = None, |
| 136 | + prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", |
| 137 | + img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>", |
| 138 | + prompt_template_encode_start_idx: int = 64, |
| 139 | + device: Optional[torch.device] = None, |
| 140 | +): |
| 141 | + prompt = [prompt] if isinstance(prompt, str) else prompt |
| 142 | + if isinstance(image, list): |
| 143 | + base_img_prompt = "" |
| 144 | + for i, img in enumerate(image): |
| 145 | + base_img_prompt += img_template_encode.format(i + 1) |
| 146 | + elif image is not None: |
| 147 | + base_img_prompt = img_template_encode.format(1) |
| 148 | + else: |
| 149 | + base_img_prompt = "" |
| 150 | + |
| 151 | + template = prompt_template_encode |
| 152 | + |
| 153 | + drop_idx = prompt_template_encode_start_idx |
| 154 | + txt = [template.format(base_img_prompt + e) for e in prompt] |
| 155 | + |
| 156 | + model_inputs = processor( |
| 157 | + text=txt, |
| 158 | + images=image, |
| 159 | + padding=True, |
| 160 | + return_tensors="pt", |
| 161 | + ).to(device) |
| 162 | + |
| 163 | + outputs = text_encoder( |
| 164 | + input_ids=model_inputs.input_ids, |
| 165 | + attention_mask=model_inputs.attention_mask, |
| 166 | + pixel_values=model_inputs.pixel_values, |
| 167 | + image_grid_thw=model_inputs.image_grid_thw, |
| 168 | + output_hidden_states=True, |
| 169 | + ) |
| 170 | + |
| 171 | + hidden_states = outputs.hidden_states[-1] |
| 172 | + split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask) |
| 173 | + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] |
| 174 | + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] |
| 175 | + max_seq_len = max([e.size(0) for e in split_hidden_states]) |
| 176 | + prompt_embeds = torch.stack( |
| 177 | + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] |
| 178 | + ) |
| 179 | + encoder_attention_mask = torch.stack( |
| 180 | + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] |
| 181 | + ) |
| 182 | + |
| 183 | + prompt_embeds = prompt_embeds.to(device=device) |
| 184 | + |
| 185 | + return prompt_embeds, encoder_attention_mask |
| 186 | + |
| 187 | + |
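For orientation, a minimal usage sketch of the new helper (assuming a Qwen2.5-VL style `text_encoder` and its matching `processor` are already loaded; the file names and prompt are made up for illustration):

```python
import PIL.Image
import torch

# Assumed to be in scope: `text_encoder` (a Qwen2.5-VL style multimodal model) and
# `processor` (its matching processor); both are hypothetical handles for this sketch.
images = [PIL.Image.open("reference_1.png"), PIL.Image.open("reference_2.png")]

prompt_embeds, prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus(
    text_encoder,
    processor,
    prompt="Put the cat from Picture 1 on the sofa from Picture 2.",
    image=images,  # each image contributes a "Picture i: <vision tokens>" prefix
    device=torch.device("cuda"),
)

# prompt_embeds: (batch, max_seq_len, hidden_dim), zero-padded per prompt after the
# first 64 template tokens (prompt_template_encode_start_idx) are dropped.
# prompt_embeds_mask: (batch, max_seq_len), 1 on real tokens and 0 on padding.
```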
131 | 188 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents |
132 | 189 | def retrieve_latents( |
133 | 190 | encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" |
@@ -266,6 +323,102 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): |
266 | 323 | return components, state |
267 | 324 |
|
268 | 325 |
|
| 326 | +class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep): |
| 327 | + model_name = "qwenimage" |
| 328 | + |
| 329 | + def __init__( |
| 330 | + self, |
| 331 | + input_name: str = "image", |
| 332 | + output_name: str = "resized_image", |
| 333 | + vae_image_output_name: str = "resize_vae_image", |
| 334 | + ): |
| 335 | + """Create a configurable step for resizing images to the target areas (384 * 384 for conditioning and 1024 * 1024 for the VAE) while maintaining the aspect ratio. |
| 336 | +
|
| 337 | + This block resizes an input image or a list of input images and exposes the resized results under configurable |
| 338 | + input and output names. Use this when you need to wire the resize step to different image fields (e.g., |
| 339 | + "image", "control_image"). |
| 340 | +
|
| 341 | + Args: |
| 342 | + input_name (str, optional): Name of the image field to read from the |
| 343 | + pipeline state. Defaults to "image". |
| 344 | + output_name (str, optional): Name of the resized image field to write |
| 345 | + back to the pipeline state. Defaults to "resized_image". |
| 346 | + vae_image_output_name (str, optional): Name of the resized image field |
| 347 | + to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage Edit Plus |
| 348 | + resizes the input image(s) differently for the VL text encoder and the VAE. |
| 349 | + """ |
| 350 | + if not isinstance(input_name, str) or not isinstance(output_name, str): |
| 351 | + raise ValueError( |
| 352 | + f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}" |
| 353 | + ) |
| 354 | + self.condition_image_size = 384 * 384 |
| 355 | + self.vae_image_size = 1024 * 1024 |
| 356 | + self._image_input_name = input_name |
| 357 | + self._resized_image_output_name = output_name |
| 358 | + self._resized_image_vae_output_name = vae_image_output_name |
| 359 | + self._image_size_output_name = "image_sizes" |
| 360 | + super().__init__(input_name=input_name, output_name=output_name) |
| 361 | + |
| 362 | + @property |
| 363 | + def description(self) -> str: |
| 364 | + return f"Image Resize step that resizes the {self._image_input_name} to the target areas of {self.condition_image_size} and {self.vae_image_size} while maintaining the aspect ratio." |
| 365 | + |
| 366 | + @property |
| 367 | + def intermediate_outputs(self) -> List[OutputParam]: |
| 368 | + return [ |
| 369 | + OutputParam( |
| 370 | + name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images to be used by the text encoder." |
| 371 | + ), |
| 372 | + OutputParam( |
| 373 | + name=self._resized_image_vae_output_name, |
| 374 | + type_hint=List[PIL.Image.Image], |
| 375 | + description="The resized images to be used by the VAE encoder.", |
| 376 | + ), |
| 377 | + OutputParam( |
| 378 | + name=self._image_size_output_name, |
| 379 | + type_hint=List[Tuple[int, int]], |
| 380 | + description="Sizes of images fed to the VAE encoder. To be used with RoPE.", |
| 381 | + ), |
| 382 | + ] |
| 383 | + |
| 384 | + @torch.no_grad() |
| 385 | + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): |
| 386 | + block_state = self.get_block_state(state) |
| 387 | + |
| 388 | + images = getattr(block_state, self._image_input_name) |
| 389 | + |
| 390 | + if not is_valid_image_imagelist(images): |
| 391 | + raise ValueError(f"Images must be image or list of images but are {type(images)}") |
| 392 | + |
| 393 | + if ( |
| 394 | + not isinstance(images, torch.Tensor) |
| 395 | + and isinstance(images, PIL.Image.Image) |
| 396 | + and not isinstance(images, list) |
| 397 | + ): |
| 398 | + images = [images] |
| 399 | + |
| 400 | + # TODO: revisit this when the inputs are `torch.Tensor`s |
| 401 | + image_width, image_height = images[-1].size |
| 402 | + condition_images = [] |
| 403 | + vae_image_sizes = [] |
| 404 | + vae_images = [] |
| 405 | + for img in images: |
| 406 | + image_width, image_height = img.size |
| 407 | + condition_width, condition_height, _ = calculate_dimensions( |
| 408 | + self.condition_image_size, image_width / image_height |
| 409 | + ) |
| 410 | + vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, image_width / image_height) |
| 411 | + vae_image_sizes.append((vae_width, vae_height)) |
| 412 | + condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) |
| 413 | + vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) |
| 414 | + |
| 415 | + setattr(block_state, self._resized_image_output_name, condition_images) |
| 416 | + setattr(block_state, self._resized_image_vae_output_name, vae_images) |
| 417 | + setattr(block_state, self._image_size_output_name, vae_image_sizes) |
| 418 | + self.set_block_state(state, block_state) |
| 419 | + return components, state |
| 420 | + |
| 421 | + |
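To make the two target areas concrete, here is a standalone sketch of an aspect-preserving resize to a target pixel area. It is not the library's `calculate_dimensions` (the snapping multiple and rounding are assumptions of this sketch), but it shows why the VL conditioning path and the VAE path see differently sized copies of the same input:

```python
import math

def resize_to_area(width: int, height: int, target_area: int, multiple: int = 32):
    """Pick (w, h) close to target_area while keeping the width/height ratio.

    The snapping multiple is an assumption of this sketch; the pipeline helper
    may round differently.
    """
    aspect = width / height
    w = math.sqrt(target_area * aspect)
    h = w / aspect
    w = max(multiple, round(w / multiple) * multiple)
    h = max(multiple, round(h / multiple) * multiple)
    return int(w), int(h)

# With this sketch's rounding, a 1536x1024 input maps to a small image for the
# 384 * 384 conditioning (VL) path and a full-resolution one for the 1024 * 1024
# VAE path; the VAE-path sizes are what the step stores under `image_sizes` for RoPE.
print(resize_to_area(1536, 1024, 384 * 384))    # (480, 320)
print(resize_to_area(1536, 1024, 1024 * 1024))  # (1248, 832)
```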
269 | 422 | class QwenImageTextEncoderStep(ModularPipelineBlocks): |
270 | 423 | model_name = "qwenimage" |
271 | 424 |
|
@@ -511,6 +664,74 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): |
511 | 664 | return components, state |
512 | 665 |
|
513 | 666 |
|
| 667 | +class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep): |
| 668 | + model_name = "qwenimage" |
| 669 | + |
| 670 | + @property |
| 671 | + def description(self) -> str: |
| 672 | + return "Text Encoder step that processes the prompt together with one or more reference images to generate text embeddings for guiding image generation.\n" |
| 673 | + |
| 674 | + @property |
| 675 | + def expected_configs(self) -> List[ConfigSpec]: |
| 676 | + return [ |
| 677 | + ConfigSpec( |
| 678 | + name="prompt_template_encode", |
| 679 | + default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", |
| 680 | + ), |
| 681 | + ConfigSpec( |
| 682 | + name="img_template_encode", |
| 683 | + default="Picture {}: <|vision_start|><|image_pad|><|vision_end|>", |
| 684 | + ), |
| 685 | + ConfigSpec(name="prompt_template_encode_start_idx", default=64), |
| 686 | + ] |
| 687 | + |
| 688 | + @staticmethod |
| 689 | + def check_inputs(prompt, negative_prompt): |
| 690 | + if not isinstance(prompt, str) and not isinstance(prompt, list): |
| 691 | + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") |
| 692 | + |
| 693 | + if ( |
| 694 | + negative_prompt is not None |
| 695 | + and not isinstance(negative_prompt, str) |
| 696 | + and not isinstance(negative_prompt, list) |
| 697 | + ): |
| 698 | + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") |
| 699 | + |
| 700 | + @torch.no_grad() |
| 701 | + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): |
| 702 | + block_state = self.get_block_state(state) |
| 703 | + |
| 704 | + self.check_inputs(block_state.prompt, block_state.negative_prompt) |
| 705 | + |
| 706 | + device = components._execution_device |
| 707 | + |
| 708 | + block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus( |
| 709 | + components.text_encoder, |
| 710 | + components.processor, |
| 711 | + prompt=block_state.prompt, |
| 712 | + image=block_state.resized_image, |
| 713 | + prompt_template_encode=components.config.prompt_template_encode, |
| 714 | + img_template_encode=components.config.img_template_encode, |
| 715 | + prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, |
| 716 | + device=device, |
| 717 | + ) |
| 718 | + |
| 719 | + if components.requires_unconditional_embeds: |
| 720 | + negative_prompt = block_state.negative_prompt or " " |
| 721 | + block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus( |
| 722 | + components.text_encoder, |
| 723 | + components.processor, |
| 724 | + prompt=negative_prompt, |
| 725 | + image=block_state.resized_image, |
| 726 | + prompt_template_encode=components.config.prompt_template_encode, |
| 727 | + prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, |
| 728 | + device=device, |
| 729 | + ) |
| 730 | + |
| 731 | + self.set_block_state(state, block_state) |
| 732 | + return components, state |
| 733 | + |
| 734 | + |
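As a sanity check on the template wiring, this is how the configs registered above combine for two reference images before the processor call (plain string formatting, nothing pipeline-specific):

```python
# Defaults taken from expected_configs above.
img_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
base_img_prompt = "".join(img_template.format(i + 1) for i in range(2))
# -> "Picture 1: <|vision_start|><|image_pad|><|vision_end|>"
#    "Picture 2: <|vision_start|><|image_pad|><|vision_end|>"

# The user's instruction is appended after the picture tags, the result is wrapped
# in the system/user chat template from `prompt_template_encode`, and the first 64
# encoded tokens (`prompt_template_encode_start_idx`, the fixed template prefix)
# are dropped from the embeddings returned by get_qwen_prompt_embeds_edit_plus.
```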
514 | 735 | class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks): |
515 | 736 | model_name = "qwenimage" |
516 | 737 |
|
|