@@ -183,7 +183,7 @@ def scaled_dot_product_attention_backward(
         key_states,  # [batch, num_kv_heads, key_len, head_dim]
         value_states,  # [batch, num_kv_heads, key_len, head_dim]
         attn_mask=causal_mask,
-        scale=scaling,
+        softmax_scale=scaling,
         # is_causal=is_causal if query_len == key_len else False,
         enable_gqa=True
     )
@@ -262,7 +262,7 @@ def dynamic_mask_attention_backward_cuda(
         attn_mask=attn_mask,  # mask: [batch, num_kv_heads, query_len, key_len]
         attn_bias=attn_bias,  # bias: [batch, num_kv_heads, query_len, key_len]
         is_causal=is_causal,  # causal masking
-        scale=scaling,  # scaling factor
+        softmax_scale=scaling,  # scaling factor
         softcap=0.0,
         deterministic=False,
         return_attn_probs=False
@@ -351,7 +351,7 @@ def dynamic_mask_attention_backward_triton(
         attn_mask=attn_mask,  # mask: [batch, num_heads, seqlen_q, seqlen_k]
         attn_bias=attn_bias,  # bias: [batch, num_heads, seqlen_q, seqlen_k]
         is_causal=is_causal,  # causal masking
-        scale=scaling  # scaling factor
+        softmax_scale=scaling  # scaling factor
     )
 
     torch.cuda.synchronize()
@@ -433,7 +433,7 @@ def dynamic_mask_attention_backward_flex(
         attn_mask=attn_mask,  # attn_mask: [batch, num_heads, query_len, key_len]
         attn_bias=attn_bias,  # attn_bias: [batch, num_heads, query_len, key_len]
         is_causal=is_causal,  # is_causal: whether to apply causal masking
-        scale=scaling  # scaling factor
+        softmax_scale=scaling  # scaling factor
     )
 
     torch.cuda.synchronize()
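For readers skimming the diff: the change is purely a keyword rename at the call sites (`scale` → `softmax_scale`); the value passed remains the usual softmax scaling factor. Below is a minimal, illustrative sketch of such a call site under that renamed keyword. It assumes a flash-attention-style keyword API; `run_backward` and the `attn_func` parameter are hypothetical stand-ins for the benchmark helpers, not names from this repository.

```python
import math

def run_backward(attn_func, query_states, key_states, value_states,
                 attn_mask=None, attn_bias=None, is_causal=True):
    """Run one forward + backward pass through an attention kernel.

    `attn_func` stands in for whichever backend the benchmark wraps
    (CUDA, Triton, or flex); its import path is not visible in this diff,
    so it is passed in rather than imported.
    """
    head_dim = query_states.shape[-1]
    scaling = 1.0 / math.sqrt(head_dim)  # conventional 1/sqrt(d_k) factor

    out = attn_func(
        query_states, key_states, value_states,
        attn_mask=attn_mask,
        attn_bias=attn_bias,
        is_causal=is_causal,
        softmax_scale=scaling,  # keyword renamed from `scale` in this change
    )
    out.sum().backward()  # exercise the backward pass being timed
    return out
```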