Skip to content

Commit 6cab4e7

Browse files
authored
[Fix] Fixed FlashMaskV3 arch>=90 compilation bug (PaddlePaddle#76227)
1 parent 4074f96 commit 6cab4e7

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,13 +1383,14 @@ void FlashMaskV2GradBaseKernel(
13831383
params_handle,
13841384
head_size); // We don't support hdim_v being
13851385
// different from hdim_qk for now
1386+
DenseTensor tile_count_semaphore;
13861387
if (arch >= 90) {
1387-
DenseTensor tile_count_semaphore =
1388-
phi::Full<int32_t, Context>(dev_ctx, {1}, 0);
1388+
tile_count_semaphore = phi::Full<int32_t, Context>(dev_ctx, {1}, 0);
13891389
phi::dynload::flashmaskv2_bwd_params_set_tile_count_semaphore(
1390-
tile_count_semaphore.data<int>());
1390+
params_handle, tile_count_semaphore.data<int>());
13911391
} else {
1392-
phi::dynload::flashmaskv2_bwd_params_set_tile_count_semaphore(nullptr);
1392+
phi::dynload::flashmaskv2_bwd_params_set_tile_count_semaphore(params_handle,
1393+
nullptr);
13931394
}
13941395

13951396
DenseTensor dq_semaphore = phi::Empty<int32_t>(

0 commit comments

Comments
 (0)