from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
-from ..attention_processor import (
-    Attention,
-    AttentionProcessor,
-)
+from ..attention_dispatch import dispatch_attention_fn
+from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
@@ -107,7 +105,7 @@ def apply_rotary_emb_qwen(

    Args:
        x (`torch.Tensor`):
-            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
+            Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
@@ -135,6 +133,7 @@ def apply_rotary_emb_qwen(
        return out
    else:
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+        freqs_cis = freqs_cis.unsqueeze(1)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
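Editor's note (not part of the diff): the added `freqs_cis.unsqueeze(1)` is needed because this commit moves the query/key layout from [B, H, S, D] to [B, S, H, D], so the complex frequency tensor has to gain a singleton heads axis to broadcast correctly. A minimal sketch with made-up shapes (the `angles` tensor is purely illustrative):

import torch

B, S, H, D = 2, 16, 4, 8                       # made-up sizes
x = torch.randn(B, S, H, D)                    # query or key in the new layout
angles = torch.randn(S, D // 2)                # hypothetical per-position rotation angles
freqs_cis = torch.polar(torch.ones_like(angles), angles)   # complex tensor, [S, D/2]

x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # [B, S, H, D/2]
freqs_cis = freqs_cis.unsqueeze(1)             # [S, 1, D/2], broadcasts over the heads axis
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)                # back to [B, S, H, D]
assert x_out.shape == (B, S, H, D)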
@@ -148,7 +147,6 @@ def __init__(self, embedding_dim, pooled_projection_dim):
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep, hidden_states):
-        # import ipdb; ipdb.set_trace()
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))  # (N, D)

@@ -245,6 +243,8 @@ class QwenDoubleStreamAttnProcessor2_0:
    implements joint attention computation where text and image streams are processed together.
    """

+    _attention_backend = None
+
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
@@ -263,8 +263,6 @@ def __call__(
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

-        batch_size = hidden_states.shape[0]
-        seq_img = hidden_states.shape[1]
        seq_txt = encoder_hidden_states.shape[1]

        # Compute QKV for image stream (sample projections)
@@ -277,20 +275,14 @@ def __call__(
        txt_key = attn.add_k_proj(encoder_hidden_states)
        txt_value = attn.add_v_proj(encoder_hidden_states)

-        inner_dim = img_key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
        # Reshape for multi-head attention
-        def reshape_for_heads(tensor, seq_len):
-            return tensor.view(batch_size, seq_len, attn.heads, head_dim).transpose(1, 2)
-
-        img_query = reshape_for_heads(img_query, seq_img)
-        img_key = reshape_for_heads(img_key, seq_img)
-        img_value = reshape_for_heads(img_value, seq_img)
+        img_query = img_query.unflatten(-1, (attn.heads, -1))
+        img_key = img_key.unflatten(-1, (attn.heads, -1))
+        img_value = img_value.unflatten(-1, (attn.heads, -1))

-        txt_query = reshape_for_heads(txt_query, seq_txt)
-        txt_key = reshape_for_heads(txt_key, seq_txt)
-        txt_value = reshape_for_heads(txt_value, seq_txt)
+        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
+        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
+        txt_value = txt_value.unflatten(-1, (attn.heads, -1))

        # Apply QK normalization
        if attn.norm_q is not None:
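Editor's note (illustrative only, made-up sizes): `unflatten(-1, (attn.heads, -1))` splits the fused projection into heads while keeping the sequence axis in position 1, i.e. [B, S, H*D] -> [B, S, H, D]; the removed `reshape_for_heads` helper produced [B, H, S, D] instead, which is why the concatenation and reshape further down change dimensions as well:

import torch

B, S, heads, head_dim = 2, 16, 4, 8            # made-up sizes
proj = torch.randn(B, S, heads * head_dim)     # output of a q/k/v projection

new_layout = proj.unflatten(-1, (heads, -1))                   # [B, S, H, D]
old_layout = proj.view(B, S, heads, head_dim).transpose(1, 2)  # [B, H, S, D] (removed helper)
assert torch.equal(new_layout, old_layout.transpose(1, 2))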
@@ -307,23 +299,22 @@ def reshape_for_heads(tensor, seq_len):
            img_freqs, txt_freqs = image_rotary_emb
            img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
            img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
-            # import ipdb; ipdb.set_trace()
            txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
            txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)

        # Concatenate for joint attention
        # Order: [text, image]
-        joint_query = torch.cat([txt_query, img_query], dim=2)
-        joint_key = torch.cat([txt_key, img_key], dim=2)
-        joint_value = torch.cat([txt_value, img_value], dim=2)
+        joint_query = torch.cat([txt_query, img_query], dim=1)
+        joint_key = torch.cat([txt_key, img_key], dim=1)
+        joint_value = torch.cat([txt_value, img_value], dim=1)

        # Compute joint attention
-        joint_hidden_states = F.scaled_dot_product_attention(
+        joint_hidden_states = dispatch_attention_fn(
            joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Reshape back
-        joint_hidden_states = joint_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        joint_hidden_states = joint_hidden_states.flatten(2, 3)
        joint_hidden_states = joint_hidden_states.to(joint_query.dtype)

        # Split attention outputs back
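Editor's note (a sketch, assuming the configured backend computes standard scaled-dot-product attention over [B, S, H, D] inputs; shapes below are made up): the joint attention call plus the `flatten(2, 3)` merge-back should match this plain-SDPA computation:

import torch
import torch.nn.functional as F

B, S_txt, S_img, H, D = 2, 7, 16, 4, 8          # made-up sizes
q = torch.randn(B, S_txt + S_img, H, D)         # joint query, [text, image] along dim=1
k = torch.randn(B, S_txt + S_img, H, D)
v = torch.randn(B, S_txt + S_img, H, D)

# Plain SDPA expects [B, H, S, D], so transpose in and back out.
out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
out = out.transpose(1, 2)                       # [B, S, H, D]
out = out.flatten(2, 3)                         # merge heads -> [B, S, H*D]

txt_out, img_out = out[:, :S_txt], out[:, S_txt:]   # split the streams back
assert txt_out.shape == (B, S_txt, H * D) and img_out.shape == (B, S_img, H * D)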
@@ -512,12 +503,8 @@ def __init__(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )

-        # self.txt_norm = nn.RMSNorm(joint_attention_dim, eps=1e-6)
        self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

-        # self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
-        # self.x_embedder = nn.Linear(in_channels, self.inner_dim)
-
        self.img_in = nn.Linear(in_channels, self.inner_dim)
        self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)

@@ -537,106 +524,6 @@ def __init__(

        self.gradient_checkpointing = False

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedQwenAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        raise ValueError("fuse_qkv_projections is currently not supported.")
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-        # self.set_attn_processor(FusedQwenAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
    def forward(
        self,
        hidden_states: torch.Tensor,