@@ -297,6 +297,28 @@ def _sfdp_replacement_7(query, key, value, attention_mask, scaling, dropout):
297297     )
298298
299299
300+ # with causal_mask, no division, does not cast to fp32 for softmax
301+ def _sfdp_pattern_8(query, key, value, attention_mask, scaling, dropout):
302+     attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
303+     attn_weights = attn_weights + attention_mask
304+     attn_weights = F.softmax(attn_weights, dim=-1)
305+     attn_weights = F.dropout(attn_weights, p=dropout, training=False)
306+     attn_output = torch.matmul(attn_weights, value)
307+     return attn_output
308+
309+
310+ def _sfdp_replacement_8(query, key, value, attention_mask, scaling, dropout):
311+     return torch.ops.auto_deploy.torch_attention_sdpa.default(
312+         query,
313+         key,
314+         value,
315+         attn_mask=None,
316+         dropout_p=dropout,
317+         is_causal=True,
318+         scale=scaling,
319+     )
320+
321+
300322 def _get_sfdp_patterns() -> List[Dict[str, Any]]:
301323     bs, seq_len, n_heads, hidden_size = 8, 16, 8, 512
302324     head_dim = hidden_size // n_heads
@@ -315,6 +337,7 @@ def causal_mask():
315337         (_sfdp_pattern_5, _sfdp_replacement_5, False, 0.874321, 0.89734),
316338         (_sfdp_pattern_6, _sfdp_replacement_6, True, 0.634743, 0.6849734),
317339         (_sfdp_pattern_7, _sfdp_replacement_7, True, 0.34743, 0.849734),
340+         (_sfdp_pattern_8, _sfdp_replacement_8, True, 0.2234743, 0.95849734),
318341     ]
319342
320343     patterns = []
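
For reference, a minimal standalone sketch (not part of the patch) of the equivalence that the new pattern/replacement pair relies on: with an additive causal mask, the eager computation in _sfdp_pattern_8 should agree with a fused causal SDPA call. The standard F.scaled_dot_product_attention stands in here for the project-specific torch_attention_sdpa op, and the tensor shapes are illustrative assumptions.

import torch
import torch.nn.functional as F

# Illustrative shapes (assumed, not taken from the patch).
bs, n_heads, seq_len, head_dim = 2, 4, 16, 32
q, k, v = (torch.randn(bs, n_heads, seq_len, head_dim) for _ in range(3))
scaling = head_dim ** -0.5

# Additive causal mask: 0 on/below the diagonal, -inf above it.
attention_mask = torch.triu(
    torch.full((seq_len, seq_len), float("-inf")), diagonal=1
)

# Eager path, mirroring _sfdp_pattern_8 (dropout is a no-op with training=False).
attn_weights = torch.matmul(q, k.transpose(2, 3)) * scaling
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
eager_out = torch.matmul(attn_weights, v)

# Fused path analogous to _sfdp_replacement_8; standard PyTorch SDPA used for illustration.
sdpa_out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, scale=scaling
)

print(torch.allclose(eager_out, sdpa_out, atol=1e-4))  # expected: True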