
Commit 74c91a0

feat(qwen-image): add qwen-image-edit support

1 parent a58a4f6 · commit 74c91a0

File tree: 8 files changed, +938 −43 lines

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -492,6 +492,7 @@
         "QwenImageImg2ImgPipeline",
         "QwenImageInpaintPipeline",
         "QwenImagePipeline",
+        "QwenImageEditPipeline",
         "ReduxImageEncoder",
         "SanaControlNetPipeline",
         "SanaPAGPipeline",
@@ -1123,6 +1124,7 @@
         PixArtAlphaPipeline,
         PixArtSigmaPAGPipeline,
         PixArtSigmaPipeline,
+        QwenImageEditPipeline,
         QwenImageImg2ImgPipeline,
         QwenImageInpaintPipeline,
         QwenImagePipeline,
```
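
Both hunks follow the library's lazy-import convention: the class name is registered in `_import_structure` so it can be loaded on first access, and it is also imported directly under `TYPE_CHECKING` so static tooling sees it. A stripped-down sketch of how that pattern resolves `from diffusers import QwenImageEditPipeline` (a simplification, not diffusers' actual `_LazyModule` implementation):

```python
# Simplified sketch of the lazy-import pattern; diffusers' real _LazyModule
# does this with more bookkeeping.
import importlib
from typing import TYPE_CHECKING

_import_structure = {"pipelines": ["QwenImageEditPipeline", "QwenImagePipeline"]}

if TYPE_CHECKING:
    # Static type checkers see a normal import.
    from .pipelines import QwenImageEditPipeline, QwenImagePipeline
else:
    def __getattr__(name):
        # Called on first attribute access: find which submodule exports the
        # name, import it lazily, and return the attribute.
        for submodule, exported in _import_structure.items():
            if name in exported:
                module = importlib.import_module(f".{submodule}", __name__)
                return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```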

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 40 additions & 39 deletions
```diff
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 
-import functools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -161,17 +160,17 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         super().__init__()
         self.theta = theta
         self.axes_dim = axes_dim
-        pos_index = torch.arange(1024)
-        neg_index = torch.arange(1024).flip(0) * -1 - 1
-        pos_freqs = torch.cat(
+        pos_index = torch.arange(4096)
+        neg_index = torch.arange(4096).flip(0) * -1 - 1
+        self.pos_freqs = torch.cat(
             [
                 self.rope_params(pos_index, self.axes_dim[0], self.theta),
                 self.rope_params(pos_index, self.axes_dim[1], self.theta),
                 self.rope_params(pos_index, self.axes_dim[2], self.theta),
             ],
             dim=1,
         )
-        neg_freqs = torch.cat(
+        self.neg_freqs = torch.cat(
             [
                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
                 self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -180,10 +179,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             dim=1,
         )
         self.rope_cache = {}
-        self.register_buffer("pos_freqs", pos_freqs, persistent=False)
-        self.register_buffer("neg_freqs", neg_freqs, persistent=False)
 
-        # whether to use scale rope
+        # DO NOT USE register_buffer HERE; IT WILL CAUSE COMPLEX NUMBERS TO LOSE THEIR IMAGINARY PART
         self.scale_rope = scale_rope
 
     def rope_params(self, index, dim, theta=10000):
```
```diff
@@ -201,47 +198,51 @@ def forward(self, video_fhw, txt_seq_lens, device):
         Args:
             video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video
             txt_seq_lens: [bs] a list of integers representing the length of each text prompt
         """
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
-        frame, height, width = video_fhw
-        rope_key = f"{frame}_{height}_{width}"
 
-        if not torch.compiler.is_compiling():
-            if rope_key not in self.rope_cache:
-                self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
-            vid_freqs = self.rope_cache[rope_key]
-        else:
-            vid_freqs = self._compute_video_freqs(frame, height, width)
+        vid_freqs = []
+        max_vid_index = 0
+        for idx, fhw in enumerate(video_fhw):
+            frame, height, width = fhw
+            rope_key = f"{idx}_{height}_{width}"
 
-        if self.scale_rope:
-            max_vid_index = max(height // 2, width // 2)
-        else:
-            max_vid_index = max(height, width)
+            if rope_key not in self.rope_cache:
+                seq_lens = frame * height * width
+                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+                freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+                freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+                if self.scale_rope:
+                    freqs_height = torch.cat(
+                        [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
+                    )
+                    freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+                    freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+                    freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+
+                else:
+                    freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+                    freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+                freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+                self.rope_cache[rope_key] = freqs.clone().contiguous()
+            vid_freqs.append(self.rope_cache[rope_key])
+
+            if self.scale_rope:
+                max_vid_index = max(height // 2, width // 2, max_vid_index)
+            else:
+                max_vid_index = max(height, width, max_vid_index)
 
         max_len = max(txt_seq_lens)
         txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+        vid_freqs = torch.cat(vid_freqs, dim=0)
 
         return vid_freqs, txt_freqs
 
-    @functools.lru_cache(maxsize=None)
-    def _compute_video_freqs(self, frame, height, width):
-        seq_lens = frame * height * width
-        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-
-        freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
-        if self.scale_rope:
-            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
-            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
-            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
-            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
-        else:
-            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
-            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
-
-        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
-        return freqs.clone().contiguous()
-
 
 class QwenDoubleStreamAttnProcessor2_0:
     """
```

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -391,6 +391,7 @@
         "QwenImagePipeline",
         "QwenImageImg2ImgPipeline",
         "QwenImageInpaintPipeline",
+        "QwenImageEditPipeline",
     ]
     try:
         if not is_onnx_available():
@@ -708,7 +709,7 @@
         from .paint_by_example import PaintByExamplePipeline
         from .pia import PIAPipeline
         from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
-        from .qwenimage import QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline
+        from .qwenimage import QwenImageEditPipeline, QwenImageInpaintPipeline, QwenImagePipeline
         from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
         from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
         from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
```

src/diffusers/pipelines/qwenimage/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -26,6 +26,7 @@
     _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
     _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
     _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
+    _import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -35,6 +36,7 @@
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
         from .pipeline_qwenimage import QwenImagePipeline
+        from .pipeline_qwenimage_edit import QwenImageEditPipeline
         from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
         from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
 else:
```
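
With the wiring above in place, the new pipeline becomes importable from the package root. A hedged usage sketch: the checkpoint id `Qwen/Qwen-Image-Edit` and the exact call signature are assumptions inferred from the sibling QwenImage pipelines, since this commit only adds the class:

```python
import torch
from diffusers import QwenImageEditPipeline
from diffusers.utils import load_image

# Checkpoint id is an assumption; the commit only wires up the class.
pipe = QwenImageEditPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("input.png")  # the image to edit
edited = pipe(
    image=image,
    prompt="replace the sky with a vivid sunset",
    num_inference_steps=50,
).images[0]
edited.save("edited.png")
```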

src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -253,6 +253,9 @@ def encode_prompt(
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
 
+        prompt_embeds = prompt_embeds[:, :max_sequence_length]
+        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -604,7 +607,7 @@ def __call__(
             generator,
             latents,
         )
-        img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
+        img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
 
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
```
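
Two behavioral tweaks ride along in the base pipeline: prompt embeddings are now explicitly truncated to `max_sequence_length`, and each `img_shapes` batch entry becomes a list of `(frame, height, width)` tuples rather than a bare tuple, matching the per-image loop added to `QwenImageRotaryPosEmbed.forward`. A small sketch of the new layout; values are illustrative, and the assumption is that the edit pipeline appends further entries for reference images:

```python
# Illustrative values; in the pipeline these come from the call arguments.
batch_size = 2
vae_scale_factor = 8
height = width = 1024

latent_h = height // vae_scale_factor // 2  # 64
latent_w = width // vae_scale_factor // 2   # 64

# Old layout: one (frame, height, width) tuple per batch sample.
old_img_shapes = [(1, latent_h, latent_w)] * batch_size

# New layout: a list of tuples per sample, so an edit pipeline can route the
# generated latents plus one or more reference images through the same RoPE.
new_img_shapes = [[(1, latent_h, latent_w)]] * batch_size
```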
