diff --git a/backends/mlu/kernels/flash_attn_kernel.cc b/backends/mlu/kernels/flash_attn_kernel.cc
index d7ea8648577..31d2018001b 100644
--- a/backends/mlu/kernels/flash_attn_kernel.cc
+++ b/backends/mlu/kernels/flash_attn_kernel.cc
@@ -311,6 +311,13 @@ void FlashAttnUnpaddedGradKernel(
   auto compute_dtype = CNNL_DTYPE_FLOAT;
   auto prefer = CNNL_ACTIVATION_HIGH_PRECISION;
   auto attn_mask_mode = causal ? CNNL_ATTN_MASK_CAUSAL : CNNL_ATTN_MASK_NONE;
+  if (attn_mask_mode == CNNL_ATTN_MASK_NONE) {
+    int32_t max_seq = std::max(max_seqlen_q, max_seqlen_k);
+    cnnlSetFlashAttentionSlidingWindowSize(desc_, max_seq, max_seq, 1);
+  } else if (attn_mask_mode == CNNL_ATTN_MASK_CAUSAL) {
+    int32_t max_seq = std::max(max_seqlen_q, max_seqlen_k);
+    cnnlSetFlashAttentionSlidingWindowSize(desc_, max_seq, 0, 1);
+  }
   cnnlSetFlashAttentionBackwardDescriptor(desc_,
                                           compute_dtype,
                                           prefer,
diff --git a/backends/mlu/kernels/funcs/mlu_baseop.cc b/backends/mlu/kernels/funcs/mlu_baseop.cc
index 2279fb39bdd..d7703412bd5 100644
--- a/backends/mlu/kernels/funcs/mlu_baseop.cc
+++ b/backends/mlu/kernels/funcs/mlu_baseop.cc
@@ -2920,37 +2920,41 @@ NormalizeDesc::~NormalizeDesc() {
   workspace.Resize({static_cast<int64_t>(workspace_size)});
   void* workspace_ptr = ctx.Alloc(&workspace, DataType::INT8, workspace_size);
   PADDLE_ENFORCE_MLU_SUCCESS(
-      cnnlFlashAttentionBackward(handle,
-                                 flash_atten_desc,
-                                 diff_out_desc,
-                                 diff_out,
-                                 q_desc,
-                                 q,
-                                 k_desc,
-                                 k,
-                                 v_desc,
-                                 v,
-                                 fwd_out_desc,
-                                 out,
-                                 softmax_lse_desc,
-                                 softmax_lse,
-                                 csq_desc,
-                                 cu_seqlens_q,
-                                 csk_desc,
-                                 cu_seqlens_k,
-                                 rng_state,
-                                 workspace_ptr,
-                                 workspace_size,
-                                 diff_query_desc,
-                                 dq,
-                                 diff_key_desc,
-                                 dk,
-                                 diff_value_desc,
-                                 dv,
-                                 /*dropout_mask_desc = */ nullptr,
-                                 /*dropout_mask = */ nullptr,
-                                 /*softmax_d_desc = */ nullptr,
-                                 /*softmax_d = */ nullptr));
+      cnnlFlashAttentionBackward_v2(handle,
+                                    flash_atten_desc,
+                                    diff_out_desc,
+                                    diff_out,
+                                    q_desc,
+                                    q,
+                                    k_desc,
+                                    k,
+                                    v_desc,
+                                    v,
+                                    fwd_out_desc,
+                                    out,
+                                    softmax_lse_desc,
+                                    softmax_lse,
+                                    csq_desc,
+                                    cu_seqlens_q,
+                                    csk_desc,
+                                    cu_seqlens_k,
+                                    /*alibi_slope_desc*/ nullptr,
+                                    /*alibi_slope*/ nullptr,
+                                    /*additive_attn_mask_desc*/ nullptr,
+                                    /*additive_attn_mask*/ nullptr,
+                                    rng_state,
+                                    workspace_ptr,
+                                    workspace_size,
+                                    diff_query_desc,
+                                    dq,
+                                    diff_key_desc,
+                                    dk,
+                                    diff_value_desc,
+                                    dv,
+                                    /*dropout_mask_desc = */ nullptr,
+                                    /*dropout_mask = */ nullptr,
+                                    /*softmax_d_desc = */ nullptr,
+                                    /*softmax_d = */ nullptr));
   PADDLE_ENFORCE_MLU_SUCCESS(
       cnnlDestroyFlashAttentionDescriptor(flash_atten_desc));
 }
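
Note on the first hunk: the sliding window is configured before cnnlSetFlashAttentionBackwardDescriptor runs, so the backward kernel always receives an explicit window. The standalone sketch below restates the branch logic; the helper name ChooseSlidingWindow is hypothetical, and reading the trailing arguments of cnnlSetFlashAttentionSlidingWindowSize as (left window, right window, granularity) is an assumption inferred from this patch, not from CNNL documentation.

// Hedged sketch, not part of the patch: mirrors the two branches added to
// FlashAttnUnpaddedGradKernel. ChooseSlidingWindow is a hypothetical helper;
// the (left, right) window semantics are assumed from the patch itself.
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct WindowSize {
  int32_t left;   // how far back a query position may attend
  int32_t right;  // how far forward a query position may attend
};

WindowSize ChooseSlidingWindow(bool causal,
                               int32_t max_seqlen_q,
                               int32_t max_seqlen_k) {
  const int32_t max_seq = std::max(max_seqlen_q, max_seqlen_k);
  // A causal mask keeps full look-back but forbids look-ahead (right = 0);
  // with no mask the window spans the whole sequence in both directions.
  return causal ? WindowSize{max_seq, 0} : WindowSize{max_seq, max_seq};
}

int main() {
  const WindowSize w = ChooseSlidingWindow(/*causal=*/true, 128, 256);
  std::printf("left=%d right=%d\n",
              static_cast<int>(w.left),
              static_cast<int>(w.right));  // prints left=256 right=0
  return 0;
}

On the second hunk, cnnlFlashAttentionBackward_v2 differs from the original call only by the four extra pointer parameters threaded in after cu_seqlens_k (ALiBi slopes and an additive attention mask, each as a descriptor plus data pointer); passing nullptr for all four should leave the kernel's behavior matching the old cnnlFlashAttentionBackward path.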