```diff
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import inspect
-import math
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
@@ -28,6 +27,7 @@
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import QwenImagePipelineOutput
+from .pipeline_qwen_utils import QwenImageEditPlusPipelineMixin, calculate_dimensions
 
 
 if is_torch_xla_available():
```
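`pipeline_qwen_utils.py` is not part of this diff, so the outline below is only an assumption about its shape, inferred from the two imported names and from the members removed from the pipeline further down; it is not the real module.

```python
# Hypothetical outline of .pipeline_qwen_utils (not shown in this diff). At a
# minimum it has to expose these two names for the new import to resolve.
def calculate_dimensions(target_area, ratio):
    ...  # see the sketch after the next hunk


class QwenImageEditPlusPipelineMixin:
    # Presumed new home of the helpers deleted from the pipeline below:
    # _extract_masked_hidden, _get_qwen_prompt_embeds, encode_prompt,
    # _pack_latents, _unpack_latents, and the guidance_scale / attention_kwargs /
    # num_timesteps / current_timestep / interrupt properties.
    ...
```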
```diff
@@ -155,17 +155,7 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")
 
 
-def calculate_dimensions(target_area, ratio):
-    width = math.sqrt(target_area * ratio)
-    height = width / ratio
-
-    width = round(width / 32) * 32
-    height = round(height / 32) * 32
-
-    return width, height
-
-
-class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageEditPlusPipelineMixin, QwenImageLoraLoaderMixin):
     r"""
     The Qwen-Image-Edit pipeline for image editing.
 
```
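The helper now lives in `pipeline_qwen_utils` and, judging by the call sites later in this diff, returns three values instead of two. A hedged sketch of the relocated function, keeping the deleted body verbatim and adding a third return value that is purely an assumption (here the rounded aspect ratio), plus a worked example of the rounding:

```python
import math


def calculate_dimensions(target_area, ratio):
    # Body identical to the one removed above; only the extra return value is assumed.
    width = math.sqrt(target_area * ratio)
    height = width / ratio

    width = round(width / 32) * 32
    height = round(height / 32) * 32

    return width, height, width / height  # third element: assumption


# Worked example with the default 1024 * 1024 target area used in __call__:
# a 1536x1024 reference image has ratio 1.5, so
#   width  = sqrt(1024 * 1024 * 1.5) ~= 1254 -> rounded to the nearest 32 -> 1248
#   height = 1254 / 1.5              ~=  836 -> rounded to the nearest 32 ->  832
print(calculate_dimensions(1024 * 1024, 1536 / 1024))  # (1248, 832, 1.5)
```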
```diff
@@ -217,114 +207,6 @@ def __init__(
         self.prompt_template_encode_start_idx = 64
         self.default_sample_size = 128
 
-    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
-    def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
-        bool_mask = mask.bool()
-        valid_lengths = bool_mask.sum(dim=1)
-        selected = hidden_states[bool_mask]
-        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
-
-        return split_result
-
-    def _get_qwen_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        image: Optional[torch.Tensor] = None,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        device = device or self._execution_device
-        dtype = dtype or self.text_encoder.dtype
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
-        if isinstance(image, list):
-            base_img_prompt = ""
-            for i, img in enumerate(image):
-                base_img_prompt += img_prompt_template.format(i + 1)
-        elif image is not None:
-            base_img_prompt = img_prompt_template.format(1)
-        else:
-            base_img_prompt = ""
-
-        template = self.prompt_template_encode
-
-        drop_idx = self.prompt_template_encode_start_idx
-        txt = [template.format(base_img_prompt + e) for e in prompt]
-
-        model_inputs = self.processor(
-            text=txt,
-            images=image,
-            padding=True,
-            return_tensors="pt",
-        ).to(device)
-
-        outputs = self.text_encoder(
-            input_ids=model_inputs.input_ids,
-            attention_mask=model_inputs.attention_mask,
-            pixel_values=model_inputs.pixel_values,
-            image_grid_thw=model_inputs.image_grid_thw,
-            output_hidden_states=True,
-        )
-
-        hidden_states = outputs.hidden_states[-1]
-        split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
-        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
-        prompt_embeds = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
-        )
-        encoder_attention_mask = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
-        )
-
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-        return prompt_embeds, encoder_attention_mask
-
-    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
-    def encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        image: Optional[torch.Tensor] = None,
-        device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        prompt_embeds_mask: Optional[torch.Tensor] = None,
-        max_sequence_length: int = 1024,
-    ):
-        r"""
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            image (`torch.Tensor`, *optional*):
-                image to be encoded
-            device: (`torch.device`):
-                torch device
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-        """
-        device = device or self._execution_device
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
-
-        if prompt_embeds is None:
-            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
-
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
-
-        return prompt_embeds, prompt_embeds_mask
-
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
     def check_inputs(
         self,
```
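`_extract_masked_hidden`, `_get_qwen_prompt_embeds`, and `encode_prompt` are removed here as well, presumably onto the mixin. For readers following the prompt path, this is a minimal standalone sketch, built from the deleted bodies, of the core trick: split the batch by attention mask, drop the `prompt_template_encode_start_idx` template tokens from the front of each sequence, then re-pad to a common length. The tensors and the small `drop_idx` are toy values.

```python
import torch

hidden = torch.randn(2, 10, 8)              # (batch, padded seq, hidden dim)
mask = torch.tensor([[1] * 7 + [0] * 3,     # valid lengths 7 and 10
                     [1] * 10])
drop_idx = 3                                # the real pipeline drops 64 template tokens

# Split the padded batch into per-sample sequences of their true lengths.
lengths = mask.bool().sum(dim=1)
split = torch.split(hidden[mask.bool()], lengths.tolist(), dim=0)

# Drop the shared template prefix, then re-pad to the longest remaining sequence.
split = [s[drop_idx:] for s in split]
max_len = max(s.size(0) for s in split)
prompt_embeds = torch.stack(
    [torch.cat([s, s.new_zeros(max_len - s.size(0), s.size(1))]) for s in split]
)
prompt_embeds_mask = torch.stack(
    [torch.cat([torch.ones(s.size(0)), torch.zeros(max_len - s.size(0))]) for s in split]
)
print(prompt_embeds.shape, prompt_embeds_mask.shape)  # torch.Size([2, 7, 8]) torch.Size([2, 7])
```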
```diff
@@ -381,32 +263,6 @@ def check_inputs(
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
 
-    @staticmethod
-    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
-    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
-        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
-        latents = latents.permute(0, 2, 4, 1, 3, 5)
-        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
-        return latents
-
-    @staticmethod
-    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
-    def _unpack_latents(latents, height, width, vae_scale_factor):
-        batch_size, num_patches, channels = latents.shape
-
-        # VAE applies 8x compression on images but we must also account for packing which requires
-        # latent height and width to be divisible by 2.
-        height = 2 * (int(height) // (vae_scale_factor * 2))
-        width = 2 * (int(width) // (vae_scale_factor * 2))
-
-        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
-        latents = latents.permute(0, 3, 1, 4, 2, 5)
-
-        latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
-
-        return latents
-
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
```
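`_pack_latents` and `_unpack_latents` also leave this file. As a reference for reviewers, here is a standalone round trip built from the deleted bodies (with local function names, not the mixin's API) showing that unpacking inverts packing, up to the singleton frame dimension the unpack step inserts:

```python
import torch


def pack_latents(latents, batch_size, num_channels_latents, height, width):
    # Same reshuffle as the removed _pack_latents: each 2x2 spatial patch of the
    # latent grid becomes one token of size num_channels_latents * 4.
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)


def unpack_latents(latents, height, width, vae_scale_factor):
    # Same inverse reshuffle as the removed _unpack_latents; height/width are in
    # pixel space, hence the division by vae_scale_factor * 2.
    batch_size, num_patches, channels = latents.shape
    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))
    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(batch_size, channels // (2 * 2), 1, height, width)


x = torch.randn(1, 16, 64, 64)                        # (B, C, H/8, W/8) latents
packed = pack_latents(x, 1, 16, 64, 64)               # -> (1, 1024, 64)
restored = unpack_latents(packed, 512, 512, vae_scale_factor=8)
assert torch.equal(restored.squeeze(2), x)            # round trip recovers the grid
```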
```diff
@@ -492,26 +348,6 @@ def prepare_latents(
 
         return latents, image_latents
 
-    @property
-    def guidance_scale(self):
-        return self._guidance_scale
-
-    @property
-    def attention_kwargs(self):
-        return self._attention_kwargs
-
-    @property
-    def num_timesteps(self):
-        return self._num_timesteps
-
-    @property
-    def current_timestep(self):
-        return self._current_timestep
-
-    @property
-    def interrupt(self):
-        return self._interrupt
-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -628,7 +464,7 @@ def __call__(
             returning a tuple, the first element is a list with the generated images.
         """
         image_size = image[-1].size if isinstance(image, list) else image.size
-        calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
+        calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
         height = height or calculated_height
         width = width or calculated_width
 
@@ -674,10 +510,10 @@ def __call__(
             vae_images = []
            for img in image:
                 image_width, image_height = img.size
-                condition_width, condition_height = calculate_dimensions(
+                condition_width, condition_height, _ = calculate_dimensions(
                     CONDITION_IMAGE_SIZE, image_width / image_height
                 )
-                vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
+                vae_width, vae_height, _ = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
                 condition_image_sizes.append((condition_width, condition_height))
                 vae_image_sizes.append((vae_width, vae_height))
                 condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
```
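User-facing behaviour should be unchanged by the refactor; only the internals moved. A minimal usage sketch, assuming the pipeline is exported from the top-level `diffusers` namespace like the other QwenImage pipelines and that the checkpoint id below is the one you actually want (it is an illustrative guess, not confirmed by this diff):

```python
import torch
from diffusers import QwenImageEditPlusPipeline
from diffusers.utils import load_image

# Checkpoint id and image paths are placeholders for illustration.
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
).to("cuda")

images = [load_image("person.png"), load_image("scene.png")]
result = pipe(
    image=images,
    prompt="Place the person into the scene, matching its lighting.",
    num_inference_steps=40,
).images[0]
result.save("edited.png")
```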