@@ -276,10 +276,16 @@ def __init__(
276276         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
277277         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
278278         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
279+         # If torch_xla is available, we use the pallas flash attention kernel to improve performance.
279280         if processor is None:
280-             processor = (
281-                 AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
282-             )
281+             if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
282+                 if is_torch_xla_available():
283+                     processor = XLAFlashAttnProcessor2_0()
284+                 else:
285+                     processor = AttnProcessor2_0()
286+             else:
287+                 processor = AttnProcessor()
288+
283289         self.set_processor(processor)
284290
285291     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
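
With this default-selection change in place, a quick way to sanity-check which processor an Attention block picks up is sketched below. This is a hedged usage sketch, not part of the diff: it assumes an environment where torch_xla is installed (e.g. a TPU host) and that XLAFlashAttnProcessor2_0 is importable from diffusers.models.attention_processor as added here.

# Hedged usage sketch (assumes torch_xla is installed and importable):
from diffusers.models.attention_processor import Attention, AttnProcessor2_0, XLAFlashAttnProcessor2_0

attn = Attention(query_dim=64, heads=8)  # no processor passed, so the default-selection branch above runs
# With torch_xla present (and PyTorch 2.x with scaled_dot_product_attention), the XLA processor is chosen.
print(type(attn.processor))  # expected: XLAFlashAttnProcessor2_0; otherwise AttnProcessor2_0 / AttnProcessor
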
@@ -2644,6 +2650,102 @@ def __init__(self):
26442650         if not hasattr(F, "scaled_dot_product_attention"):
26452651             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
26462652
2653+     def __call__(
2654+         self,
2655+         attn: Attention,
2656+         hidden_states: torch.Tensor,
2657+         encoder_hidden_states: Optional[torch.Tensor] = None,
2658+         attention_mask: Optional[torch.Tensor] = None,
2659+         temb: Optional[torch.Tensor] = None,
2660+         *args,
2661+         **kwargs,
2662+     ) -> torch.Tensor:
2663+         if len(args) > 0 or kwargs.get("scale", None) is not None:
2664+             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2665+             deprecate("scale", "1.0.0", deprecation_message)
2666+
2667+         residual = hidden_states
2668+         if attn.spatial_norm is not None:
2669+             hidden_states = attn.spatial_norm(hidden_states, temb)
2670+
2671+         input_ndim = hidden_states.ndim
2672+
2673+         if input_ndim == 4:
2674+             batch_size, channel, height, width = hidden_states.shape
2675+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2676+
2677+         batch_size, sequence_length, _ = (
2678+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2679+         )
2680+
2681+         if attention_mask is not None:
2682+             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2683+             # scaled_dot_product_attention expects attention_mask shape to be
2684+             # (batch, heads, source_length, target_length)
2685+             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
2686+
2687+         if attn.group_norm is not None:
2688+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
2689+
2690+         query = attn.to_q(hidden_states)
2691+
2692+         if encoder_hidden_states is None:
2693+             encoder_hidden_states = hidden_states
2694+         elif attn.norm_cross:
2695+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2696+
2697+         key = attn.to_k(encoder_hidden_states)
2698+         value = attn.to_v(encoder_hidden_states)
2699+
2700+         inner_dim = key.shape[-1]
2701+         head_dim = inner_dim // attn.heads
2702+
2703+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2704+
2705+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2706+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2707+
2708+         if attn.norm_q is not None:
2709+             query = attn.norm_q(query)
2710+         if attn.norm_k is not None:
2711+             key = attn.norm_k(key)
2712+
2713+         # the output of sdp = (batch, num_heads, seq_len, head_dim)
2714+         # TODO: add support for attn.scale when we move to Torch 2.1
2715+         hidden_states = F.scaled_dot_product_attention(
2716+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2717+         )
2718+
2719+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2720+         hidden_states = hidden_states.to(query.dtype)
2721+
2722+         # linear proj
2723+         hidden_states = attn.to_out[0](hidden_states)
2724+         # dropout
2725+         hidden_states = attn.to_out[1](hidden_states)
2726+
2727+         if input_ndim == 4:
2728+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2729+
2730+         if attn.residual_connection:
2731+             hidden_states = hidden_states + residual
2732+
2733+         hidden_states = hidden_states / attn.rescale_output_factor
2734+
2735+         return hidden_states
2736+
2737+
2738+ class XLAFlashAttnProcessor2_0:
2739+     r"""
2740+     Processor for implementing scaled dot-product attention with the pallas flash attention kernel
2741+     (enabled by default when torch_xla is available).
2742+     """
2743+
2744+     def __init__(self):
2745+         if not hasattr(F, "scaled_dot_product_attention"):
2746+             raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2747+         if not is_torch_xla_available():
2748+             raise ImportError("XLAFlashAttnProcessor2_0 requires the torch_xla package.")
2749+
26472749     def __call__(
26482750         self,
26492751         attn: Attention,
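
The hunk above is truncated before the body of XLAFlashAttnProcessor2_0.__call__. For orientation only, here is a minimal hedged sketch of how the pallas kernel could replace the F.scaled_dot_product_attention call: it assumes the flash_attention wrapper from torch_xla.experimental.custom_kernel with (batch, heads, seq_len, head_dim) inputs, and it is not the exact implementation in this commit.

# Hedged sketch only; not the body of this commit's XLAFlashAttnProcessor2_0.__call__.
import math

import torch
from torch_xla.experimental.custom_kernel import flash_attention  # pallas flash attention wrapper (assumed API)


def xla_flash_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Assumption: the pallas kernel does not apply the 1/sqrt(head_dim) softmax scaling itself,
    # so the query is scaled up front to match F.scaled_dot_product_attention's default behaviour.
    query = query * (1.0 / math.sqrt(query.shape[-1]))
    # causal=False mirrors the is_causal=False used with F.scaled_dot_product_attention above.
    return flash_attention(query, key, value, causal=False)
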