@@ -1778,28 +1778,7 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        if is_torch_npu_available():
-            if query.dtype in (torch.float16, torch.bfloat16):
-                hidden_states = torch_npu.npu_fusion_attention(
-                    query,
-                    key,
-                    value,
-                    attn.heads,
-                    input_layout="BNSD",
-                    pse=None,
-                    scale=1.0 / math.sqrt(query.shape[-1]),
-                    pre_tockens=65536,
-                    next_tockens=65536,
-                    keep_prob=1.0,
-                    sync=False,
-                    inner_precise=0,
-                )[0]
-            else:
-                hidden_states = F.scaled_dot_product_attention(
-                    query, key, value, dropout_p=0.0, is_causal=False
-                )
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

@@ -1819,6 +1798,110 @@ def __call__(
         else:
             return hidden_states

+class FluxAttnProcessor2_0_NPU:
+    """Flux attention processor (PyTorch 2.0) that runs torch_npu fused attention on Ascend NPU devices."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch_npu. To use it, please upgrade PyTorch to 2.0 and install torch_npu.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, dropout_p=0.0, is_causal=False
+            )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states

 class FusedFluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
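With this change, `FluxAttnProcessor2_0` always uses `F.scaled_dot_product_attention`, and the Ascend NPU fused-attention path moves into the new `FluxAttnProcessor2_0_NPU` class, so NPU users now opt in explicitly. Below is a minimal selection sketch; it assumes the new class is exported from `diffusers.models.attention_processor` and that the Flux transformer exposes `set_attn_processor()` like other diffusers models, neither of which is shown in this diff.

```python
# Minimal sketch: pick the NPU processor only when torch_npu is available,
# mirroring the branch that this commit removes from FluxAttnProcessor2_0.
from diffusers.models.attention_processor import (
    FluxAttnProcessor2_0,
    FluxAttnProcessor2_0_NPU,
)
from diffusers.utils import is_torch_npu_available


def pick_flux_attn_processor():
    # Inside the NPU processor, npu_fusion_attention is only used for
    # fp16/bf16 inputs; other dtypes still fall back to SDPA.
    if is_torch_npu_available():
        return FluxAttnProcessor2_0_NPU()
    return FluxAttnProcessor2_0()


# Assumed usage (set_attn_processor is not part of this diff):
# transformer.set_attn_processor(pick_flux_attn_processor())
```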
@@ -1893,28 +1976,7 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        if is_torch_npu_available():
-            if query.dtype in (torch.float16, torch.bfloat16):
-                hidden_states = torch_npu.npu_fusion_attention(
-                    query,
-                    key,
-                    value,
-                    attn.heads,
-                    input_layout="BNSD",
-                    pse=None,
-                    scale=1.0 / math.sqrt(query.shape[-1]),
-                    pre_tockens=65536,
-                    next_tockens=65536,
-                    keep_prob=1.0,
-                    sync=False,
-                    inner_precise=0,
-                )[0]
-            else:
-                hidden_states = F.scaled_dot_product_attention(
-                    query, key, value, dropout_p=0.0, is_causal=False
-                )
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

@@ -1934,6 +1996,117 @@ def __call__(
         else:
             return hidden_states

+class FusedFluxAttnProcessor2_0_NPU:
+    """Flux attention processor (PyTorch 2.0) for fused QKV projections, running torch_npu fused attention on Ascend NPU devices."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch_npu. To use it, please upgrade PyTorch to 2.0 and install torch_npu."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, dropout_p=0.0, is_causal=False
+            )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states

 class CogVideoXAttnProcessor2_0:
     r"""
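`FusedFluxAttnProcessor2_0_NPU` reads from `attn.to_qkv` and `attn.to_added_qkv`, so it only works after the QKV projections have been fused. The sketch below assumes `FluxTransformer2DModel` provides `fuse_qkv_projections()` and `set_attn_processor()`; neither call appears in this diff.

```python
# Sketch only: the fuse/set calls below are assumptions about the Flux
# transformer's API and are not shown in this diff.
import torch

from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FusedFluxAttnProcessor2_0_NPU

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.fuse_qkv_projections()  # assumed to create the fused to_qkv / to_added_qkv projections
transformer.set_attn_processor(FusedFluxAttnProcessor2_0_NPU())

# bf16 (or fp16) activations take the torch_npu.npu_fusion_attention path;
# any other dtype falls back to F.scaled_dot_product_attention.
```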