Merged
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -492,6 +492,7 @@
"QwenImageImg2ImgPipeline",
"QwenImageInpaintPipeline",
"QwenImagePipeline",
"QwenImageEditPipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
"SanaPAGPipeline",
@@ -1123,6 +1124,7 @@
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
QwenImageEditPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,
QwenImagePipeline,
79 changes: 40 additions & 39 deletions src/diffusers/models/transformers/transformer_qwenimage.py
@@ -13,7 +13,6 @@
# limitations under the License.


import functools
import math
from typing import Any, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -161,17 +160,17 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(1024)
neg_index = torch.arange(1024).flip(0) * -1 - 1
pos_freqs = torch.cat(
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
)
neg_freqs = torch.cat(
self.neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -180,10 +179,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
dim=1,
)
self.rope_cache = {}
self.register_buffer("pos_freqs", pos_freqs, persistent=False)
self.register_buffer("neg_freqs", neg_freqs, persistent=False)

# whether to use scale rope
# DO NOT USE REGISTER_BUFFER HERE: IT WILL CAUSE THE COMPLEX NUMBERS TO LOSE THEIR IMAGINARY PART
Contributor:

cc @sayakpaul for review of rope changes

@naykun We refactored the RoPE logic in #12061 to make it compatible with torch.compile. Would you be able to add the relevant changes for QwenEdit with the current implementation?

Contributor (Author):

In the latest commit, I have refactored the code as described in #12061.
Please kindly review and let me know if it looks good. Local testing confirms that it correctly generates images.
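
For context, a minimal sketch of the failure mode the comment above warns about (hypothetical module, not this PR's code): Module.to(dtype) casts registered buffers, and casting a complex tensor to a real dtype silently discards the imaginary part, while a plain attribute is left untouched.

import torch
import torch.nn as nn

class RopeHolder(nn.Module):
    def __init__(self):
        super().__init__()
        # Complex frequency table, like the one torch.polar builds in rope_params.
        freqs = torch.polar(torch.ones(4, 8), torch.rand(4, 8))  # complex64
        self.register_buffer("freqs_buf", freqs, persistent=False)
        self.freqs_attr = freqs  # plain attribute: Module.to(dtype) ignores it

m = RopeHolder().to(torch.bfloat16)  # e.g. a downstream pipe.to(torch.bfloat16)
print(m.freqs_buf.dtype)   # torch.bfloat16 -- imaginary part discarded (with a warning)
print(m.freqs_attr.dtype)  # torch.complex64 -- still complex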

self.scale_rope = scale_rope

def rope_params(self, index, dim, theta=10000):
@@ -201,47 +198,51 @@ def forward(self, video_fhw, txt_seq_lens, device):
Args:
video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video
txt_seq_lens: [bs] a list of integers representing the length of the text
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)

if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
frame, height, width = video_fhw
rope_key = f"{frame}_{height}_{width}"

if not torch.compiler.is_compiling():
if rope_key not in self.rope_cache:
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
vid_freqs = self.rope_cache[rope_key]
else:
vid_freqs = self._compute_video_freqs(frame, height, width)
vid_freqs = []
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}"

if self.scale_rope:
max_vid_index = max(height // 2, width // 2)
else:
max_vid_index = max(height, width)
if rope_key not in self.rope_cache:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat(
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)

else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
self.rope_cache[rope_key] = freqs.clone().contiguous()
vid_freqs.append(self.rope_cache[rope_key])

if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)

max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)

return vid_freqs, txt_freqs

@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()


class QwenDoubleStreamAttnProcessor2_0:
"""
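As a sanity check on the new list-of-shapes contract, a small usage sketch; the import path, axes_dim values, and sequence lengths are assumptions for illustration, not taken from this diff.

import torch
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedRope

rope = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
# One (frame, height, width) entry per image stream: the target latents plus
# the VAE-encoded source image that the edit pipeline conditions on.
vid_freqs, txt_freqs = rope(
    video_fhw=[(1, 64, 64), (1, 64, 64)],
    txt_seq_lens=[77],
    device=torch.device("cpu"),
)
print(vid_freqs.shape, txt_freqs.shape)  # (2 * 64 * 64, 64) and (77, 64)
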
3 changes: 2 additions & 1 deletion src/diffusers/pipelines/__init__.py
@@ -391,6 +391,7 @@
"QwenImagePipeline",
"QwenImageImg2ImgPipeline",
"QwenImageInpaintPipeline",
"QwenImageEditPipeline",
]
try:
if not is_onnx_available():
@@ -708,7 +709,7 @@
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .qwenimage import QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline
from .qwenimage import QwenImageEditPipeline, QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline
from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/qwenimage/__init__.py
@@ -26,6 +26,7 @@
_import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -35,6 +36,7 @@
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_qwenimage import QwenImagePipeline
from .pipeline_qwenimage_edit import QwenImageEditPipeline
from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
else:
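With the export wiring above in place, a quick smoke-test sketch for reviewers; the checkpoint id, prompt, and file names are assumptions, not part of this PR.

import torch
from diffusers import QwenImageEditPipeline
from diffusers.utils import load_image

pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = load_image("input.png")  # source image to edit
edited = pipe(image=image, prompt="turn the sky into a sunset").images[0]
edited.save("edited.png")
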
5 changes: 4 additions & 1 deletion src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -253,6 +253,9 @@ def encode_prompt(
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
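
A minimal illustration of the truncation being added here; the tensor sizes (and 3584 as the text-encoder hidden dim) are assumptions.

import torch

max_sequence_length = 256
prompt_embeds = torch.randn(1, 300, 3584)                  # (batch, seq_len, dim)
prompt_embeds_mask = torch.ones(1, 300, dtype=torch.long)

# Clip both the embeddings and their mask to the same length before they are
# repeated num_images_per_prompt times.
prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
print(prompt_embeds.shape, prompt_embeds_mask.shape)  # [1, 256, 3584], [1, 256]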
@@ -604,7 +607,7 @@ def __call__(
generator,
latents,
)
img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size

# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
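The extra level of nesting matches the per-sample list the RoPE forward now iterates over, which is what lets the edit pipeline append the reference image's shape; a toy illustration with assumed sizes:

height = width = 1024
vae_scale_factor = 8
batch_size = 2

# Each batch entry is now a list of (frame, h, w) latent shapes rather than a
# single tuple, so edit-style pipelines can append the encoded source image.
img_shapes = [[(1, height // vae_scale_factor // 2, width // vae_scale_factor // 2)]] * batch_size
print(img_shapes)  # [[(1, 64, 64)], [(1, 64, 64)]]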