
Commit 4a9dbd5

sayakpaul and a-r-r-o-w authored

enable compilation in qwen image. (#12061)

* update
* update
* update
* enable compilation in qwen image.
* add tests

---------

Co-authored-by: Aryan <[email protected]>

1 parent 630d27f · commit 4a9dbd5

File tree

3 files changed: +137 -24 lines


src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 31 additions & 24 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import functools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -162,15 +163,15 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         self.axes_dim = axes_dim
         pos_index = torch.arange(1024)
         neg_index = torch.arange(1024).flip(0) * -1 - 1
-        self.pos_freqs = torch.cat(
+        pos_freqs = torch.cat(
             [
                 self.rope_params(pos_index, self.axes_dim[0], self.theta),
                 self.rope_params(pos_index, self.axes_dim[1], self.theta),
                 self.rope_params(pos_index, self.axes_dim[2], self.theta),
             ],
             dim=1,
         )
-        self.neg_freqs = torch.cat(
+        neg_freqs = torch.cat(
             [
                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
                 self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -179,6 +180,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             dim=1,
         )
         self.rope_cache = {}
+        self.register_buffer("pos_freqs", pos_freqs, persistent=False)
+        self.register_buffer("neg_freqs", neg_freqs, persistent=False)
 
         # whether to use scale rope
         self.scale_rope = scale_rope
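The `__init__` change above stores the rotary frequency tables as non-persistent buffers instead of plain tensor attributes, so they follow the module through `.to(device)` and dtype casts without being written to checkpoints; this is what lets the manual device check in `forward` be dropped in the next hunk. A minimal sketch of that behavior (hypothetical `RopeTable` module, not the diffusers class):

# Minimal sketch (hypothetical module, not the diffusers class): a buffer
# registered with persistent=False moves with the module on .to(...) but is
# excluded from the state dict, so no manual device bookkeeping is needed.
import torch
from torch import nn


class RopeTable(nn.Module):
    def __init__(self):
        super().__init__()
        freqs = torch.arange(8, dtype=torch.float32)
        self.register_buffer("freqs", freqs, persistent=False)


table = RopeTable()
print("freqs" in table.state_dict())  # False: not saved to checkpoints
table.to(torch.float64)               # buffer is cast together with the module
print(table.freqs.dtype)              # torch.float64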
@@ -198,33 +201,17 @@ def forward(self, video_fhw, txt_seq_lens, device):
         Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
         txt_length: [bs] a list of 1 integers representing the length of the text
         """
-        if self.pos_freqs.device != device:
-            self.pos_freqs = self.pos_freqs.to(device)
-            self.neg_freqs = self.neg_freqs.to(device)
-
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
         frame, height, width = video_fhw
         rope_key = f"{frame}_{height}_{width}"
 
-        if rope_key not in self.rope_cache:
-            seq_lens = frame * height * width
-            freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
-            if self.scale_rope:
-                freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
-                freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
-                freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
-
-            else:
-                freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
-
-            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
-            self.rope_cache[rope_key] = freqs.clone().contiguous()
-        vid_freqs = self.rope_cache[rope_key]
+        if not torch.compiler.is_compiling():
+            if rope_key not in self.rope_cache:
+                self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
+            vid_freqs = self.rope_cache[rope_key]
+        else:
+            vid_freqs = self._compute_video_freqs(frame, height, width)
 
         if self.scale_rope:
             max_vid_index = max(height // 2, width // 2)
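With the buffers handled by the module itself, `forward` now only branches on `torch.compiler.is_compiling()` (available in recent PyTorch releases): in eager mode the per-resolution Python dict cache is still consulted, while under `torch.compile` the frequencies are recomputed so the traced graph never depends on mutating a dict attribute. A small sketch of the same pattern with hypothetical names:

# Sketch of the eager-cache / compile-recompute pattern (hypothetical class,
# not the diffusers module).
import torch


class FreqProvider:
    def __init__(self):
        self._cache = {}

    def _compute(self, height: int, width: int) -> torch.Tensor:
        # Stand-in for the real frequency construction.
        return torch.outer(
            torch.arange(height, dtype=torch.float32),
            torch.arange(width, dtype=torch.float32),
        )

    def get(self, height: int, width: int) -> torch.Tensor:
        if not torch.compiler.is_compiling():
            key = (height, width)            # eager: memoize per resolution
            if key not in self._cache:
                self._cache[key] = self._compute(height, width)
            return self._cache[key]
        return self._compute(height, width)  # compiling: avoid dict mutation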
@@ -236,6 +223,25 @@ def forward(self, video_fhw, txt_seq_lens, device):
 
         return vid_freqs, txt_freqs
 
+    @functools.lru_cache(maxsize=None)
+    def _compute_video_freqs(self, frame, height, width):
+        seq_lens = frame * height * width
+        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+        freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        if self.scale_rope:
+            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+        else:
+            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+        return freqs.clone().contiguous()
+
 
 class QwenDoubleStreamAttnProcessor2_0:
     """
@@ -482,6 +488,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
     _supports_gradient_checkpointing = True
     _no_split_modules = ["QwenImageTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+    _repeated_blocks = ["QwenImageTransformerBlock"]
 
     @register_to_config
     def __init__(
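`_repeated_blocks` marks `QwenImageTransformerBlock` as the unit to target when only the repeated blocks, rather than the whole model, are wrapped in `torch.compile`. A hedged usage sketch follows; the `Qwen/Qwen-Image` checkpoint id and the `compile_repeated_blocks` helper are assumptions about your diffusers version and are not stated in this diff, while plain `torch.compile` on the whole transformer is the generic alternative:

# Usage sketch. Assumptions: the "Qwen/Qwen-Image" checkpoint id and the
# compile_repeated_blocks() helper (which reads _repeated_blocks) exist in
# your diffusers install; neither is stated in this diff.
import torch
from diffusers import QwenImageTransformer2DModel

transformer = QwenImageTransformer2DModel.from_pretrained(
    "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Generic route: compile the whole transformer.
transformer = torch.compile(transformer)

# Regional route (assumed helper): compile only the repeated blocks.
# transformer.compile_repeated_blocks(fullgraph=True)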

tests/models/test_modeling_common.py

Lines changed: 5 additions & 0 deletions
@@ -1711,6 +1711,11 @@ def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5):
         if not self.model_class._supports_group_offloading:
             pytest.skip("Model does not support group offloading.")
 
+        if self.model_class.__name__ == "QwenImageTransformer2DModel":
+            pytest.skip(
+                "QwenImageTransformer2DModel doesn't support group offloading with disk. Needs to be investigated."
+            )
+
         def _has_generator_arg(model):
             sig = inspect.signature(model.forward)
             params = sig.parameters
New file: tests for QwenImageTransformer2DModel

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import QwenImageTransformer2DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
+    model_class = QwenImageTransformer2DModel
+    main_input_name = "hidden_states"
+    # We override the items here because the transformer under consideration is small.
+    model_split_percents = [0.7, 0.6, 0.6]
+
+    # Skip setting testing with default: AttnProcessor
+    uses_custom_attn_processor = True
+
+    @property
+    def dummy_input(self):
+        return self.prepare_dummy_input()
+
+    @property
+    def input_shape(self):
+        return (16, 16)
+
+    @property
+    def output_shape(self):
+        return (16, 16)
+
+    def prepare_dummy_input(self, height=4, width=4):
+        batch_size = 1
+        num_latent_channels = embedding_dim = 16
+        sequence_length = 7
+        vae_scale_factor = 4
+
+        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
+        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+        orig_height = height * 2 * vae_scale_factor
+        orig_width = width * 2 * vae_scale_factor
+        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
+
+        return {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(),
+        }
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "patch_size": 2,
+            "in_channels": 16,
+            "out_channels": 4,
+            "num_layers": 2,
+            "attention_head_dim": 16,
+            "num_attention_heads": 3,
+            "joint_attention_dim": 16,
+            "guidance_embeds": False,
+            "axes_dims_rope": (8, 4, 4),
+        }
+
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"QwenImageTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+    model_class = QwenImageTransformer2DModel
+
+    def prepare_init_args_and_inputs_for_common(self):
+        return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
+
+    def prepare_dummy_input(self, height, width):
+        return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
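`QwenImageTransformerCompileTests` pulls its tiny init args and dummy inputs from the class above and lets `TorchCompileTesterMixin` drive the actual checks. Roughly, such a check compiles the small model and runs it repeatedly while asserting that no recompilation happens; the sketch below is illustrative only and is not the mixin's implementation:

# Illustrative sketch only, not TorchCompileTesterMixin's code. Build the
# tiny model from the init args above, compile it, and call it twice while
# failing loudly if a recompile is triggered.
import torch


def run_compile_smoke_test(model_cls, init_dict, inputs_dict, device="cuda"):
    model = model_cls(**init_dict).to(device).eval()
    model = torch.compile(model, fullgraph=True)
    with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
        model(**inputs_dict)  # first call compiles and runs
        model(**inputs_dict)  # second call must reuse the compiled graph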
