@@ -1242,6 +1242,72 @@ def outer_reduce(x):
        self.assertEqual(outer_reduce(a), out)
        self.assertTrue("for roffset" not in code)

+    def test_scaled_dot_product_efficient_attention_backward(self):
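+        # Compile a small self-attention module and check that forward and
+        # backward through scaled_dot_product_attention run on CUDA.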
+        from torch import nn, Tensor
+
+        class SelfAttention(nn.Module):
+            def __init__(
+                self,
+                num_attention_heads: int = 12,
+                hidden_size: int = 768,
+                attention_probs_dropout_prob: float = 0.1,
+            ):
+                super().__init__()
+
+                self.num_attention_heads = num_attention_heads
+                self.attention_head_size = hidden_size // num_attention_heads
+
+                self.query = nn.Linear(hidden_size, hidden_size)
+                self.key = nn.Linear(hidden_size, hidden_size)
+                self.value = nn.Linear(hidden_size, hidden_size)
+
+                self.dropout_prob = attention_probs_dropout_prob
+
+            def transpose_for_scores(self, x: Tensor) -> Tensor:
+                new_x_shape = x.size()[:-1] + (
+                    self.num_attention_heads,
+                    self.attention_head_size,
+                )
+                return x.view(new_x_shape).permute(0, 2, 1, 3)
+
+            def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+                query_layer = self.transpose_for_scores(self.query(hidden_states))
+                key_layer = self.transpose_for_scores(self.key(hidden_states))
+                value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    query_layer,
+                    key_layer,
+                    value_layer,
+                    attn_mask=attention_mask,
+                    dropout_p=self.dropout_prob if self.training else 0.0,
+                    is_causal=False,
+                )
+                return attn_output
+
+        device = torch.device("cuda")
+        num_attention_heads = 8
+        hidden_size = 512
+        attention_probs_dropout_prob = 0.0
+        model = SelfAttention(
+            num_attention_heads=num_attention_heads,
+            hidden_size=hidden_size,
+            attention_probs_dropout_prob=attention_probs_dropout_prob,
+        ).to(device)
+
+        model = torch.compile(model)
+
+        # runs without failure
+        batch_size = 8
+        length = 1
+        inputs_embeds = torch.randn(batch_size, length, hidden_size, device=device)
+        attention_mask = torch.ones(batch_size, 1, length, length, device=device)
+        attn_output = model(hidden_states=inputs_embeds, attention_mask=attention_mask)[
+            0
+        ]
+        loss = attn_output.mean()
+        loss.backward()
+
    def test_non_contiguous_unaligned_input_indices(self):
        from torch._inductor.compile_fx import remove_unaligned_input_idxs
