@@ -297,7 +297,7 @@ def __init__(
         self.set_processor(processor)
 
     def set_use_xla_flash_attention(
-        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None, **kwargs
     ) -> None:
         r"""
         Set whether to use xla flash attention from `torch_xla` or not.
@@ -316,7 +316,10 @@ def set_use_xla_flash_attention(
             elif is_spmd() and is_torch_xla_version("<", "2.4"):
                 raise ImportError("flash attention pallas kernel using SPMD is supported from torch_xla version 2.4")
             else:
-                processor = XLAFlashAttnProcessor2_0(partition_spec)
+                if len(kwargs) > 0 and kwargs.get("is_flux", None):
+                    processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
+                else:
+                    processor = XLAFlashAttnProcessor2_0(partition_spec)
         else:
             processor = (
                 AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
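A minimal usage sketch of the new `is_flux` keyword (hedged: `attn` is a stand-in for a diffusers `Attention` module inside a Flux transformer block running under `torch_xla`; the kwarg only changes which XLA processor `set_use_xla_flash_attention` selects):

    # Assumes torch_xla >= 2.3 is available; `attn` is a hypothetical Attention instance.
    attn.set_use_xla_flash_attention(True, partition_spec=None, is_flux=True)  # -> XLAFluxFlashAttnProcessor2_0
    attn.set_use_xla_flash_attention(True, partition_spec=None)                # -> XLAFlashAttnProcessor2_0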
@@ -2318,11 +2321,7 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        if XLA_AVAILABLE:
-            query /= math.sqrt(head_dim)
-            hidden_states = flash_attention(query, key, value, causal=False)
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -2523,12 +2522,8 @@ def __call__(
 
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
-
-        if XLA_AVAILABLE:
-            query /= math.sqrt(head_dim)
-            hidden_states = flash_attention(query, key, value)
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -3430,6 +3425,106 @@ def __call__(
         return hidden_states
 
 
+class XLAFluxFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+    """
+
+    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "XLAFluxFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+        self.partition_spec = partition_spec
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        query /= math.sqrt(head_dim)
+        hidden_states = flash_attention(query, key, value, causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class MochiVaeAttnProcessor2_0:
     r"""
     Attention processor used in Mochi VAE.
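For context, a minimal hedged sketch of wiring the new processor up by hand (only `XLAFluxFlashAttnProcessor2_0` and `set_processor` come from the diff above; `attn` is assumed to be a diffusers `Attention` module from a Flux transformer block running on an XLA device):

    # Hypothetical setup: `attn` is a Flux Attention module under torch_xla.
    processor = XLAFluxFlashAttnProcessor2_0(partition_spec=None)
    attn.set_processor(processor)

Note that the processor divides `query` by `math.sqrt(head_dim)` before calling the pallas `flash_attention` kernel, mirroring both the removed XLA branches above and the scaling that `F.scaled_dot_product_attention` applies internally.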