@@ -1778,7 +1778,28 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        if is_torch_npu_available():
+            if query.dtype in (torch.float16, torch.bfloat16):
+                hidden_states = torch_npu.npu_fusion_attention(
+                    query,
+                    key,
+                    value,
+                    attn.heads,
+                    input_layout="BNSD",
+                    pse=None,
+                    scale=1.0 / math.sqrt(query.shape[-1]),
+                    pre_tockens=65536,
+                    next_tockens=65536,
+                    keep_prob=1.0,
+                    sync=False,
+                    inner_precise=0,
+                )[0]
+            else:
+                hidden_states = F.scaled_dot_product_attention(
+                    query, key, value, dropout_p=0.0, is_causal=False
+                )
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
17841805
@@ -1872,7 +1893,28 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        if is_torch_npu_available():
+            if query.dtype in (torch.float16, torch.bfloat16):
+                hidden_states = torch_npu.npu_fusion_attention(
+                    query,
+                    key,
+                    value,
+                    attn.heads,
+                    input_layout="BNSD",
+                    pse=None,
+                    scale=1.0 / math.sqrt(query.shape[-1]),
+                    pre_tockens=65536,
+                    next_tockens=65536,
+                    keep_prob=1.0,
+                    sync=False,
+                    inner_precise=0,
+                )[0]
+            else:
+                hidden_states = F.scaled_dot_product_attention(
+                    query, key, value, dropout_p=0.0, is_causal=False
+                )
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
18781920
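For context, a minimal standalone sketch of the dispatch pattern both hunks add: route to Ascend NPU fused attention when `torch_npu` is importable and the inputs are half precision, and fall back to PyTorch's `F.scaled_dot_product_attention` otherwise. The `attention_dispatch` helper, the import guard, and the demo tensors are illustrative assumptions and not part of the diffusers API; the `npu_fusion_attention` keyword arguments simply mirror the ones in the diff.

```python
import math

import torch
import torch.nn.functional as F

try:
    import torch_npu  # Ascend NPU extension; absent on CPU/GPU-only machines
except ImportError:
    torch_npu = None


def attention_dispatch(query, key, value, num_heads):
    """Illustrative dispatch: fused NPU attention for fp16/bf16, SDPA otherwise.

    Tensors are expected in BNSD layout: (batch, num_heads, seq_len, head_dim).
    """
    if torch_npu is not None and query.dtype in (torch.float16, torch.bfloat16):
        # npu_fusion_attention returns a tuple; the attention output is element 0.
        return torch_npu.npu_fusion_attention(
            query,
            key,
            value,
            num_heads,
            input_layout="BNSD",
            pse=None,
            scale=1.0 / math.sqrt(query.shape[-1]),
            pre_tockens=65536,  # parameter spelling follows the torch_npu API
            next_tockens=65536,
            keep_prob=1.0,      # keep_prob=1.0 means no attention dropout
            sync=False,
            inner_precise=0,
        )[0]
    # Non-NPU (or fp32) path: PyTorch's built-in scaled dot-product attention.
    return F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)


if __name__ == "__main__":
    q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
    out = attention_dispatch(q, k, v, num_heads=8)
    print(out.shape)  # torch.Size([1, 8, 16, 64])
```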