Skip to content

Commit 5dbecdc

Browse files
authored
[XPU] Update XHPC to 20251014 and add some dim check in FlashAttnKernel of xpu. (PaddlePaddle#75872)
1 parent fd95aba commit 5dbecdc

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

cmake/external/xpu.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ set(XPU_FFT_LIB_NAME "libcufft.so")
3434
add_compile_definitions(XPUAPI_NOT_INCLUDE_DEPRECATED)
3535

3636
if(NOT DEFINED XPU_XHPC_BASE_DATE)
37-
set(XPU_XHPC_BASE_DATE "dev/20251010")
37+
set(XPU_XHPC_BASE_DATE "dev/20251014")
3838
endif()
39-
set(XPU_XCCL_BASE_VERSION "3.0.3.3") # For XRE5
39+
set(XPU_XCCL_BASE_VERSION "3.0.3.4") # For XRE5
4040
if(NOT DEFINED XPU_XFT_BASE_VERSION)
4141
set(XPU_XFT_BASE_VERSION "20250507/xpu3")
4242
endif()

paddle/phi/kernels/xpu/flash_attn_kernel.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,21 @@ void FlashAttnKernel(const Context& dev_ctx,
462462
common::errors::InvalidArgument(
463463
"flash_attn receive input with dim "
464464
"[batch_size, seq_len, num_heads, head_dim]"));
465+
PADDLE_ENFORCE_EQ(k.dims().size(),
466+
4,
467+
common::errors::InvalidArgument(
468+
"flash_attn receive input with dim "
469+
"[batch_size, seq_len, num_heads, head_dim]"));
470+
PADDLE_ENFORCE_EQ(v.dims().size(),
471+
4,
472+
common::errors::InvalidArgument(
473+
"flash_attn receive input with dim "
474+
"[batch_size, seq_len, num_heads, head_dim]"));
475+
PADDLE_ENFORCE_EQ(out->dims().size(),
476+
4,
477+
common::errors::InvalidArgument(
478+
"flash_attn receive input with dim "
479+
"[batch_size, seq_len, num_heads, head_dim]"));
465480

466481
const int64_t batch_size = dims[0];
467482
const int64_t seqlen_q = dims[1];

0 commit comments

Comments
 (0)