Commit 1ce2be0

Fix parameter name in flash_fn call
1 parent: f7a7bbb

1 file changed: 1 addition, 1 deletion

flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ def _flash_dynamic_mask_attention_forward(
         attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
 
     out = flash_fn(
-        query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
+        query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, softmax_scale=softmax_scale, is_causal=is_causal
     )
 
     return out[0] if isinstance(out, tuple) else out
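
Why the one-word change matters: a Python keyword argument must match the parameter name in the callee's signature exactly, so passing scale= to a function whose parameter is named softmax_scale fails with a TypeError at call time. The sketch below reproduces the failure and the fix with a hypothetical stand-in for flash_fn; it assumes only the keyword names visible at this call site, not the real flash_dmattn kernel signature or semantics.

import torch

# Hypothetical stand-in for flash_fn: it mirrors only the keyword names
# visible at the fixed call site, not the real flash_dmattn kernel.
def flash_fn(query, key, value, attn_mask=None, attn_bias=None,
             softmax_scale=None, is_causal=False):
    # Default to 1/sqrt(head_dim) when no explicit scale is given.
    scale = softmax_scale if softmax_scale is not None else query.shape[-1] ** -0.5
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    if attn_bias is not None:
        scores = scores + attn_bias
    if attn_mask is not None:
        scores = scores.masked_fill(~attn_mask, torch.finfo(scores.dtype).min)
    if is_causal:
        q_len, k_len = scores.shape[-2:]
        causal = torch.ones(q_len, k_len, device=scores.device).tril().bool()
        scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
    return torch.matmul(scores.softmax(dim=-1), value)

q = k = v = torch.randn(1, 2, 4, 8)

# Before the fix: the keyword does not exist in the signature.
try:
    flash_fn(q, k, v, scale=0.125, is_causal=True)
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'scale'

# After the fix: the keyword matches the parameter name.
out = flash_fn(q, k, v, softmax_scale=0.125, is_causal=True)
print(out.shape)  # torch.Size([1, 2, 4, 8])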
