
Commit f644e9b

fix(qwen-image):
- compatible with attention dispatcher
- cond cache support
1 parent 7c42801 commit f644e9b

2 files changed (+41, -189 lines)
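The two bullets map onto the two files below: the attention processor now routes through diffusers' attention dispatcher instead of calling `F.scaled_dot_product_attention` directly, and the pipeline wraps its conditional/unconditional passes in cache contexts. A minimal sketch of the dispatcher call the new processor makes (shapes are illustrative, not taken from the model config):

```python
import torch
from diffusers.models.attention_dispatch import dispatch_attention_fn

batch, seq, heads, head_dim = 1, 24, 4, 32
query = torch.randn(batch, seq, heads, head_dim)  # the dispatcher works on [B, S, H, D] tensors
key = torch.randn(batch, seq, heads, head_dim)
value = torch.randn(batch, seq, heads, head_dim)

# Routes to the configured attention backend (native SDPA unless another backend is selected),
# mirroring the call added to QwenDoubleStreamAttnProcessor2_0 in the diff below.
out = dispatch_attention_fn(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 24, 4, 32]); the processor then merges heads with .flatten(2, 3)
```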

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 17 additions & 130 deletions
@@ -25,10 +25,8 @@
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import (
-    Attention,
-    AttentionProcessor,
-)
+from ..attention_dispatch import dispatch_attention_fn
+from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
@@ -107,7 +105,7 @@ def apply_rotary_emb_qwen(

     Args:
         x (`torch.Tensor`):
-            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
+            Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
         freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

     Returns:
@@ -135,6 +133,7 @@ def apply_rotary_emb_qwen(
         return out
     else:
         x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+        freqs_cis = freqs_cis.unsqueeze(1)
         x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

         return x_out.type_as(x)
@@ -148,7 +147,6 @@ def __init__(self, embedding_dim, pooled_projection_dim):
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

     def forward(self, timestep, hidden_states):
-        # import ipdb; ipdb.set_trace()
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))  # (N, D)

@@ -245,6 +243,8 @@ class QwenDoubleStreamAttnProcessor2_0:
     implements joint attention computation where text and image streams are processed together.
     """

+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
@@ -263,8 +263,6 @@ def __call__(
         if encoder_hidden_states is None:
             raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

-        batch_size = hidden_states.shape[0]
-        seq_img = hidden_states.shape[1]
         seq_txt = encoder_hidden_states.shape[1]

         # Compute QKV for image stream (sample projections)
@@ -277,20 +275,14 @@ def __call__(
         txt_key = attn.add_k_proj(encoder_hidden_states)
         txt_value = attn.add_v_proj(encoder_hidden_states)

-        inner_dim = img_key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
         # Reshape for multi-head attention
-        def reshape_for_heads(tensor, seq_len):
-            return tensor.view(batch_size, seq_len, attn.heads, head_dim).transpose(1, 2)
-
-        img_query = reshape_for_heads(img_query, seq_img)
-        img_key = reshape_for_heads(img_key, seq_img)
-        img_value = reshape_for_heads(img_value, seq_img)
+        img_query = img_query.unflatten(-1, (attn.heads, -1))
+        img_key = img_key.unflatten(-1, (attn.heads, -1))
+        img_value = img_value.unflatten(-1, (attn.heads, -1))

-        txt_query = reshape_for_heads(txt_query, seq_txt)
-        txt_key = reshape_for_heads(txt_key, seq_txt)
-        txt_value = reshape_for_heads(txt_value, seq_txt)
+        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
+        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
+        txt_value = txt_value.unflatten(-1, (attn.heads, -1))

         # Apply QK normalization
         if attn.norm_q is not None:
@@ -307,23 +299,22 @@ def reshape_for_heads(tensor, seq_len):
             img_freqs, txt_freqs = image_rotary_emb
             img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
             img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
-            # import ipdb; ipdb.set_trace()
             txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
             txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)

         # Concatenate for joint attention
         # Order: [text, image]
-        joint_query = torch.cat([txt_query, img_query], dim=2)
-        joint_key = torch.cat([txt_key, img_key], dim=2)
-        joint_value = torch.cat([txt_value, img_value], dim=2)
+        joint_query = torch.cat([txt_query, img_query], dim=1)
+        joint_key = torch.cat([txt_key, img_key], dim=1)
+        joint_value = torch.cat([txt_value, img_value], dim=1)

         # Compute joint attention
-        joint_hidden_states = F.scaled_dot_product_attention(
+        joint_hidden_states = dispatch_attention_fn(
             joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )

         # Reshape back
-        joint_hidden_states = joint_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        joint_hidden_states = joint_hidden_states.flatten(2, 3)
         joint_hidden_states = joint_hidden_states.to(joint_query.dtype)

         # Split attention outputs back
@@ -512,12 +503,8 @@ def __init__(
             embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )

-        # self.txt_norm = nn.RMSNorm(joint_attention_dim, eps=1e-6)
         self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

-        # self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
-        # self.x_embedder = nn.Linear(in_channels, self.inner_dim)
-
         self.img_in = nn.Linear(in_channels, self.inner_dim)
         self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)

@@ -537,106 +524,6 @@ def __init__(

         self.gradient_checkpointing = False

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedQwenAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        raise ValueError("fuse_qkv_projections is currently not supported.")
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-        # self.set_attn_processor(FusedQwenAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
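The net effect of the processor changes above is a layout switch: QKV tensors now stay in [batch, seq, heads, head_dim] instead of being transposed to [batch, heads, seq, head_dim], which is why the joint `torch.cat` moves from dim=2 to dim=1, the attention output is merged with `.flatten(2, 3)`, and the rotary frequencies need an extra broadcast axis. A standalone sketch of that bookkeeping (illustrative sizes, not the model's real dimensions):

```python
import torch

batch, seq_img, seq_txt, heads, head_dim = 1, 16, 8, 4, 32

# Projections are unflattened to [B, S, H, D] rather than reshaped/transposed to [B, H, S, D].
img_q = torch.randn(batch, seq_img, heads * head_dim).unflatten(-1, (heads, -1))
txt_q = torch.randn(batch, seq_txt, heads * head_dim).unflatten(-1, (heads, -1))

# With the sequence axis at dim=1, text and image tokens are concatenated along dim=1 ...
joint_q = torch.cat([txt_q, img_q], dim=1)  # [B, S_txt + S_img, H, D]
joint_out = joint_q.flatten(2, 3)           # ... and flattened back to [B, S, H * D] after attention.

# Rotary embeddings: freqs_cis is [S, D/2] complex, so unsqueeze(1) lets it broadcast over the
# heads axis of the input viewed as [B, S, H, D/2] complex.
freqs_cis = torch.randn(seq_img, head_dim // 2, dtype=torch.complex64)
x_c = torch.view_as_complex(img_q.float().reshape(*img_q.shape[:-1], -1, 2))
x_rot = torch.view_as_real(x_c * freqs_cis.unsqueeze(1)).flatten(3)
print(joint_q.shape, joint_out.shape, x_rot.shape)
```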

src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py

Lines changed: 24 additions & 59 deletions
@@ -22,7 +22,7 @@
     Qwen2Tokenizer,
 )

-from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
@@ -210,7 +210,7 @@ def _get_qwen_prompt_embeds(

         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
-        # import ipdb; ipdb.set_trace()
+
         template = self.prompt_template_encode
         drop_idx = self.prompt_template_encode_start_idx
         txt = [template.format(e) for e in prompt]
@@ -478,21 +478,17 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         negative_prompt: Union[str, List[str]] = None,
-        true_cfg_scale: float = 1.0,
+        true_cfg_scale: float = 4.0,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        num_inference_steps: int = 28,
+        num_inference_steps: int = 50,
         sigmas: Optional[List[float]] = None,
-        guidance_scale: float = 3.5,
+        guidance_scale: float = 1.0,
         num_images_per_prompt: int = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         prompt_embeds_mask: Optional[torch.FloatTensor] = None,
-        ip_adapter_image: Optional[PipelineImageInput] = None,
-        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
-        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
-        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds_mask: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
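With the new defaults, a bare pipeline call runs classic two-pass ("true") CFG at scale 4.0 over 50 steps, with the embedded guidance scale left at 1.0. A minimal usage sketch under those defaults (the checkpoint id, prompt, and availability of a CUDA device are assumptions, not part of this commit):

```python
import torch
from diffusers import QwenImagePipeline

# Assumed checkpoint id for illustration only.
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")

# Passing a negative prompt together with true_cfg_scale > 1.0 makes the pipeline run both a
# conditional and an unconditional transformer pass per step (see the cache_context hunk below).
image = pipe(
    prompt="a watercolor lighthouse at dawn",
    negative_prompt=" ",
    num_inference_steps=50,  # new default
    true_cfg_scale=4.0,      # new default
    guidance_scale=1.0,      # new default
).images[0]
image.save("qwen_image.png")
```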
@@ -699,38 +695,9 @@ def __call__(
         else:
             guidance = None

-        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
-            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
-        ):
-            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
-            negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
-
-        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
-            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
-        ):
-            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
-            ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
-
         if self.joint_attention_kwargs is None:
             self._joint_attention_kwargs = {}

-        image_embeds = None
-        negative_image_embeds = None
-        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
-            image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image,
-                ip_adapter_image_embeds,
-                device,
-                batch_size * num_images_per_prompt,
-            )
-        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
-            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
-                negative_ip_adapter_image,
-                negative_ip_adapter_image_embeds,
-                device,
-                batch_size * num_images_per_prompt,
-            )
-
         # 6. Denoising loop
         self.scheduler.set_begin_index(0)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -739,36 +706,34 @@ def __call__(
                     continue

                 self._current_timestep = t
-                if image_embeds is not None:
-                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    encoder_hidden_states_mask=prompt_embeds_mask,
-                    encoder_hidden_states=prompt_embeds,
-                    img_shapes=img_shapes,
-                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    if negative_image_embeds is not None:
-                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latents,
                         timestep=timestep / 1000,
                         guidance=guidance,
-                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states_mask=prompt_embeds_mask,
+                        encoder_hidden_states=prompt_embeds,
                         img_shapes=img_shapes,
-                        txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
+                        txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if do_true_cfg:
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latents,
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            encoder_hidden_states_mask=negative_prompt_embeds_mask,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            img_shapes=img_shapes,
+                            txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                     cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
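The last context lines show how the two passes are merged when true CFG is active: the conditional and unconditional predictions (now tagged with separate "cond"/"uncond" cache contexts so cache-based speedups can track them independently) are combined with the classic CFG formula before a norm-based rescale. A small standalone sketch of that arithmetic (random tensors, illustrative shapes; the hunk ends at `cond_norm`, so the rescaling step shown last is an assumed, typical norm-preserving continuation rather than a quote from this file):

```python
import torch

true_cfg_scale = 4.0
noise_pred = torch.randn(1, 4096, 64)      # conditional prediction (illustrative shape)
neg_noise_pred = torch.randn(1, 4096, 64)  # unconditional prediction

# Classifier-free guidance on the raw predictions, as in the hunk above:
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

# Per-token norm of the conditional prediction (last visible line of the hunk) ...
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
# ... assumed continuation: rescale the combined prediction to match that norm.
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)
```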
