310 changes: 310 additions & 0 deletions src/diffusers/pipelines/qwenimage/pipeline_qwen_utils.py
@@ -0,0 +1,310 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Union

import torch

from ...utils import deprecate


class QwenImageMixin:
@property
def guidance_scale(self):
return self._guidance_scale

@property
def attention_kwargs(self):
return self._attention_kwargs

@property
def num_timesteps(self):
return self._num_timesteps

@property
def current_timestep(self):
return self._current_timestep

@property
def interrupt(self):
return self._interrupt

def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
deprecate(
"enable_vae_slicing",
"0.40.0",
depr_message,
)
self.vae.enable_slicing()

def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
deprecate(
"disable_vae_slicing",
"0.40.0",
depr_message,
)
self.vae.disable_slicing()

def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
        processing larger images.
"""
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
deprecate(
"enable_vae_tiling",
"0.40.0",
depr_message,
)
self.vae.enable_tiling()

def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
deprecate(
"disable_vae_tiling",
"0.40.0",
depr_message,
)
self.vae.disable_tiling()
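    # A minimal migration sketch for the deprecated wrappers above (assumes `pipe`
    # is an instantiated pipeline built on this mixin):
    #
    #   pipe.vae.enable_slicing()   # replaces pipe.enable_vae_slicing()
    #   pipe.vae.enable_tiling()    # replaces pipe.enable_vae_tiling()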

    def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
        # Keep only the unpadded positions of each sample, then split the flattened
        # selection back into one variable-length tensor per batch entry.
        bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)
selected = hidden_states[bool_mask]
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)

return split_result
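    # For example (hypothetical shapes): given `hidden_states` of shape (2, 4, D)
    # and `mask` rows [1, 1, 1, 0] and [1, 1, 0, 0], `_extract_masked_hidden`
    # returns a tuple of tensors with shapes (3, D) and (2, D).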

@staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        # Rearrange (B, C, H, W) latents into a sequence of flattened 2x2 patches
        # with shape (B, (H // 2) * (W // 2), C * 4).
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

return latents

@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape

# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))

latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)

        # restore channel-first latents with a singleton frame axis: (B, C, 1, H, W)
        latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)

return latents
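    # A minimal round-trip sketch for the two helpers above (shapes are
    # illustrative; `vae_scale_factor` is assumed to be 8):
    #
    #   latents = torch.randn(2, 16, 64, 64)                    # (B, C, H', W') latent grid
    #   packed = QwenImageMixin._pack_latents(latents, 2, 16, 64, 64)
    #   packed.shape                                            # (2, 1024, 64): one token per 2x2 patch
    #   unpacked = QwenImageMixin._unpack_latents(packed, 512, 512, vae_scale_factor=8)
    #   unpacked.shape                                          # (2, 16, 1, 64, 64)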


class QwenImagePipelineMixin(QwenImageMixin):
Collaborator: You can also add additional mixins here, such as LoRA.

def encode_prompt(
Member (author): Only `encode_prompt()` differs, to avoid introducing any breaking change.

self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 1024,
):
r"""

Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
            device (`torch.device`, *optional*):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            prompt_embeds_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for the text embeddings. If not provided, it is generated alongside
                `prompt_embeds`.
            max_sequence_length (`int`, defaults to 1024):
                Maximum sequence length of the prompt embeddings; longer sequences are truncated to this length.
"""
device = device or self._execution_device

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

return prompt_embeds, prompt_embeds_mask
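    # A hedged usage sketch (assumes `pipe` is a pipeline built on this mixin with
    # its tokenizer and text encoder loaded; the prompt is illustrative):
    #
    #   embeds, mask = pipe.encode_prompt(
    #       prompt="a cat wearing a space suit",
    #       num_images_per_prompt=2,
    #       max_sequence_length=1024,
    #   )
    #   embeds.shape[0] == mask.shape[0] == 2  # one row per generated image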

def _get_qwen_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype

prompt = [prompt] if isinstance(prompt, str) else prompt

        # Wrap each prompt in the encode template; the first `drop_idx` tokens are
        # template tokens and are stripped from the encoder output below.
        template = self.prompt_template_encode
        drop_idx = self.prompt_template_encode_start_idx
        txt = [template.format(e) for e in prompt]
txt_tokens = self.tokenizer(
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
).to(device)
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
output_hidden_states=True,
)
hidden_states = encoder_hidden_states.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
encoder_attention_mask = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)

prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

return prompt_embeds, encoder_attention_mask
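    # Shape walk-through (illustrative): if two prompts have 7 and 11 valid tokens
    # after the template prefix is dropped, `split_hidden_states` holds tensors of
    # shape (7, D) and (11, D); both are zero-padded to length 11 and stacked, so
    # `prompt_embeds` is (2, 11, D) and `encoder_attention_mask` is (2, 11) with
    # 7 and 11 ones per row, respectively.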


class QwenImageEditPipelineMixin(QwenImageMixin):
def encode_prompt(
self,
prompt: Union[str, List[str]],
image: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 1024,
):
r"""

Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
image (`torch.Tensor`, *optional*):
image to be encoded
            device (`torch.device`, *optional*):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            prompt_embeds_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for the text embeddings. If not provided, it is generated alongside
                `prompt_embeds`.
            max_sequence_length (`int`, defaults to 1024):
                Maximum sequence length to use.
        """
device = device or self._execution_device

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

return prompt_embeds, prompt_embeds_mask

def _get_qwen_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
image: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype

prompt = [prompt] if isinstance(prompt, str) else prompt

template = self.prompt_template_encode
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(e) for e in prompt]

model_inputs = self.processor(
text=txt,
images=image,
padding=True,
return_tensors="pt",
).to(device)

outputs = self.text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
pixel_values=model_inputs.pixel_values,
image_grid_thw=model_inputs.image_grid_thw,
output_hidden_states=True,
)

hidden_states = outputs.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
encoder_attention_mask = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)

prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

return prompt_embeds, encoder_attention_mask
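    # Usage sketch for the edit variant (illustrative; assumes `pipe` exposes the
    # multimodal `processor` used above and that it accepts a PIL image for `image`):
    #
    #   from PIL import Image
    #   source = Image.open("input.png")  # hypothetical input path
    #   embeds, mask = pipe.encode_prompt(
    #       prompt="turn the sky into a sunset",
    #       image=source,
    #   )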


def calculate_dimensions(target_area, ratio):
    # Solve width * height ~= target_area subject to width / height = ratio.
    width = math.sqrt(target_area * ratio)
    height = width / ratio

    # Snap both dimensions to the nearest multiple of 32 so they remain divisible
    # by the VAE scale factor and the 2x2 latent packing.
    width = round(width / 32) * 32
    height = round(height / 32) * 32

    return width, height, None
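# Worked example (values are illustrative): calculate_dimensions(1024 * 1024, 16 / 9)
# gives width = sqrt(1048576 * 16 / 9) ~= 1365.3 -> round(1365.3 / 32) * 32 = 1376 and
# height = 1365.3 / (16 / 9) = 768.0 -> 768, so the function returns (1376, 768, None).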