@@ -997,6 +997,8 @@ def forward(
 class MochiAttnProcessor2_0:
     """Attention processor used in Mochi."""
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
@@ -1074,7 +1076,9 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin):
             valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
             valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
 
-            attn_output = dispatch_attention_fn(valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False)
+            attn_output = dispatch_attention_fn(
+                valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False, backend=self._attention_backend
+            )
             valid_sequence_length = attn_output.size(2)
             attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
             attn_outputs.append(attn_output)
@@ -2274,6 +2278,8 @@ def __call__(
 class FluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -2339,7 +2345,13 @@ def __call__(
             key = apply_rotary_emb(key, image_rotary_emb)
 
         hidden_states = dispatch_attention_fn(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -2366,6 +2378,8 @@ def __call__(
 class FluxAttnProcessor2_0_NPU:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
@@ -2448,7 +2462,9 @@ def __call__(
                 inner_precise=0,
             )[0]
         else:
-            hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
+            hidden_states = dispatch_attention_fn(
+                query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend
+            )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -2472,6 +2488,8 @@ def __call__(
 class FusedFluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
@@ -2542,7 +2560,9 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = dispatch_attention_fn(
+            query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend
+        )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -2567,6 +2587,8 @@ def __call__(
 class FusedFluxAttnProcessor2_0_NPU:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
@@ -2653,7 +2675,9 @@ def __call__(
                 inner_precise=0,
             )[0]
         else:
-            hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
+            hidden_states = dispatch_attention_fn(
+                query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend
+            )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -2678,6 +2702,8 @@ def __call__(
 class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
     """Flux Attention processor for IP-Adapter."""
 
+    _attention_backend = None
+
     def __init__(
         self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
     ):
@@ -2775,7 +2801,9 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = dispatch_attention_fn(
+            query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -2806,7 +2834,13 @@ def __call__(
                 # the output of sdp = (batch, num_heads, seq_len, head_dim)
                 # TODO: add support for attn.scale when we move to Torch 2.1
                 current_ip_hidden_states = dispatch_attention_fn(
-                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                    ip_query,
+                    ip_key,
+                    ip_value,
+                    attn_mask=None,
+                    dropout_p=0.0,
+                    is_causal=False,
+                    backend=self._attention_backend,
                 )
                 current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
                     batch_size, -1, attn.heads * head_dim
@@ -2825,6 +2859,8 @@ class CogVideoXAttnProcessor2_0:
     query and key vectors, but does not include spatial normalization.
     """
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -2872,7 +2908,13 @@ def __call__(
                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
         hidden_states = dispatch_attention_fn(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -2894,6 +2936,8 @@ class FusedCogVideoXAttnProcessor2_0:
     query and key vectors, but does not include spatial normalization.
     """
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -2943,7 +2987,13 @@ def __call__(
                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
         hidden_states = dispatch_attention_fn(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -3129,9 +3179,10 @@ class AttnProcessorNPU:
     Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
     fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
     not significant.
-
     """
 
+    _attention_backend = None
+
     def __init__(self):
         if not is_torch_npu_available():
             raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
@@ -3216,7 +3267,13 @@ def __call__(
         else:
             # TODO: add support for attn.scale when we move to Torch 2.1
             hidden_states = dispatch_attention_fn(
-                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+                query,
+                key,
+                value,
+                attn_mask=attention_mask,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
             )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -3243,6 +3300,8 @@ class AttnProcessor2_0:
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
     """
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -3310,7 +3369,13 @@ def __call__(
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = dispatch_attention_fn(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -3553,6 +3618,8 @@ class MochiVaeAttnProcessor2_0:
     Attention processor used in Mochi VAE.
     """
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -3614,7 +3681,13 @@ def __call__(
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = dispatch_attention_fn(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=attn.is_causal,
+            backend=self._attention_backend,
        )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
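Every hunk in this diff follows one pattern: the processor class gains a class-level `_attention_backend` default of `None`, and each call to `dispatch_attention_fn` forwards it through the `backend` keyword, so an unset backend keeps the dispatcher's default behavior. The following is a minimal, self-contained sketch of that pattern only; `toy_dispatch` and the `"flash"` identifier are illustrative stand-ins, not diffusers APIs.

def toy_dispatch(query, key, value, *, dropout_p=0.0, is_causal=False, backend=None):
    # A real dispatcher would select an attention kernel based on `backend`;
    # here we only report which backend was requested.
    return f"attention computed with backend={backend!r}"


class ToyAttnProcessor:
    # Class-level default: no explicit backend, let the dispatcher decide.
    _attention_backend = None

    def __call__(self, query, key, value):
        # Forward the per-instance (or class default) backend to the dispatcher,
        # mirroring the `backend=self._attention_backend` calls in the diff.
        return toy_dispatch(
            query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend
        )


processor = ToyAttnProcessor()
print(processor(None, None, None))      # backend=None -> dispatcher default

processor._attention_backend = "flash"  # hypothetical backend identifier
print(processor(None, None, None))      # backend='flash'

Because the attribute lives on the class with a `None` default, existing processors keep working unchanged, while callers can opt into a specific backend per instance without touching `__call__` signatures.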