@@ -297,7 +297,7 @@ def __init__(
         self.set_processor(processor)
 
     def set_use_xla_flash_attention(
-        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None, **kwargs
     ) -> None:
         r"""
         Set whether to use xla flash attention from `torch_xla` or not.
@@ -316,7 +316,10 @@ def set_use_xla_flash_attention(
             elif is_spmd() and is_torch_xla_version("<", "2.4"):
                 raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
             else:
-                processor = XLAFlashAttnProcessor2_0(partition_spec)
+                if len(kwargs) > 0 and kwargs.get("is_flux", None):
+                    processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
+                else:
+                    processor = XLAFlashAttnProcessor2_0(partition_spec)
         else:
             processor = (
                 AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
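
With this hook in place, a Flux attention module can opt into the new processor through the same entry point. A minimal usage sketch, not part of this commit; `attn` stands for an assumed diffusers `Attention` module inside a Flux transformer block on an XLA device, and the partition spec is only an example SPMD sharding:

# Route a Flux Attention module to XLAFluxFlashAttnProcessor2_0 via the new **kwargs.
attn.set_use_xla_flash_attention(
    True,
    partition_spec=("data", None, None, None),  # illustrative; use None when not running under SPMD
    is_flux=True,  # dispatches to XLAFluxFlashAttnProcessor2_0 instead of XLAFlashAttnProcessor2_0
)
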
@@ -2318,11 +2321,7 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        if XLA_AVAILABLE:
-            query /= math.sqrt(head_dim)
-            hidden_states = flash_attention(query, key, value, causal=False)
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
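
Both paths compute the same attention; the removed XLA branch differed only in kernel choice and in pre-scaling the query, because the pallas `flash_attention` kernel, as it is called here, applies no softmax scale of its own, while `F.scaled_dot_product_attention` scales by `1/sqrt(head_dim)` internally. A rough equivalence sketch, assuming `q`, `k`, `v` are `[batch, heads, seq_len, head_dim]` tensors on an XLA device:

import math

import torch.nn.functional as F
from torch_xla.experimental.custom_kernel import flash_attention  # XLA-only import

def sdpa_attention(q, k, v):
    # CPU/GPU path: PyTorch applies the 1/sqrt(head_dim) softmax scale internally.
    return F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

def xla_flash_attention(q, k, v):
    # XLA path: pre-scale the query, mirroring `query /= math.sqrt(head_dim)` in the removed branch.
    return flash_attention(q / math.sqrt(q.shape[-1]), k, v, causal=False)
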
@@ -2523,12 +2522,8 @@ def __call__(
 
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
-
-        if XLA_AVAILABLE:
-            query /= math.sqrt(head_dim)
-            hidden_states = flash_attention(query, key, value)
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -3430,6 +3425,106 @@ def __call__(
         return hidden_states
 
 
+class XLAFluxFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+    """
+
+    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+        self.partition_spec = partition_spec
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        query /= math.sqrt(head_dim)
+        hidden_states = flash_attention(query, key, value, causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class MochiVaeAttnProcessor2_0:
     r"""
     Attention processor used in Mochi VAE.
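
For completeness, a hedged sketch of attaching the new processor directly instead of going through `set_use_xla_flash_attention`; `transformer` stands for an assumed `FluxTransformer2DModel` instance and is not part of this commit:

from diffusers.models.attention_processor import Attention, XLAFluxFlashAttnProcessor2_0

# Replace the processor on every Attention module of an (assumed) Flux transformer;
# set_processor is the same Attention method used in __init__ above.
for module in transformer.modules():
    if isinstance(module, Attention):
        module.set_processor(XLAFluxFlashAttnProcessor2_0(partition_spec=None))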