From 58a5a52952ab525532a91c5f5052f25a83fb7840 Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Fri, 8 Aug 2025 18:14:48 +0000
Subject: [PATCH 1/2] fix

---
 include/flashinfer/attention/persistent.cuh | 31 +++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/include/flashinfer/attention/persistent.cuh b/include/flashinfer/attention/persistent.cuh
index 68a692960..28129ed89 100644
--- a/include/flashinfer/attention/persistent.cuh
+++ b/include/flashinfer/attention/persistent.cuh
@@ -138,6 +138,37 @@ __device__ __forceinline__ void write_o_(float (*o_frag)[KTraits::NUM_MMA_D_VO][
   }
 }
 
+template <typename KTraits>
+__device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_VO][8],
+                                            typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2]) {
+  using AttentionVariant = typename KTraits::AttentionVariant;
+  if constexpr (AttentionVariant::use_softmax) {
+    float d_rcp[KTraits::NUM_MMA_Q][2];
+    // compute reciprocal of d
+#pragma unroll
+    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
+#pragma unroll
+      for (uint32_t j = 0; j < 2; ++j) {
+        d_rcp[mma_q][j] = (m[mma_q][j] != typename KTraits::DTypeQKAccum(-math::inf))
+                              ? math::ptx_rcp(d[mma_q][j])
+                              : 0.f;
+      }
+    }
+
+#pragma unroll
+    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
+#pragma unroll
+      for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
+#pragma unroll
+        for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
+          o_frag[mma_q][mma_d][reg_id] =
+              o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id % 4) / 2];
+        }
+      }
+    }
+  }
+}
+
 template <typename KTraits_>
 struct BlockBatchPagedAttentionPersistent {
   using KTraits = KTraits_;

From 2b853bf829e9061e42c2d04bd9f60e7c6e982e35 Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Fri, 8 Aug 2025 20:55:36 +0000
Subject: [PATCH 2/2] use bitwise

---
 include/flashinfer/attention/persistent.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/include/flashinfer/attention/persistent.cuh b/include/flashinfer/attention/persistent.cuh
index 28129ed89..fae896aa1 100644
--- a/include/flashinfer/attention/persistent.cuh
+++ b/include/flashinfer/attention/persistent.cuh
@@ -162,7 +162,7 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_V
 #pragma unroll
         for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
           o_frag[mma_q][mma_d][reg_id] =
-              o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id % 4) / 2];
+              o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id >> 1) & 1];
         }
       }
     }
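
Context for PATCH 1/2 (not part of the patch itself): normalize_d applies the deferred softmax normalization, scaling each accumulated output row by the reciprocal of its denominator d, and leaving rows whose running max m is still -inf (no key was ever processed) at zero so that 0 * (1/0) never produces NaN. A scalar sketch of that per-row logic follows; the names (normalize_rows, o, m, d, rows, cols) are illustrative and not taken from the kernel.

    #include <cmath>
    #include <cstddef>

    // Scalar analogue of normalize_d: scale each row of o by 1/d[i], except
    // rows whose running max never left -inf, which are kept at zero.
    void normalize_rows(float* o, const float* m, const float* d,
                        std::size_t rows, std::size_t cols) {
      for (std::size_t i = 0; i < rows; ++i) {
        const float d_rcp = (m[i] == -INFINITY) ? 0.0f : 1.0f / d[i];
        for (std::size_t j = 0; j < cols; ++j) {
          o[i * cols + j] *= d_rcp;
        }
      }
    }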
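
Context for PATCH 2/2 (not part of the patch itself): the only change is the index into d_rcp, from (reg_id % 4) / 2 to (reg_id >> 1) & 1. Both expressions map reg_id {0,1,4,5} to 0 and {2,3,6,7} to 1, so the selected d_rcp entry is unchanged; the bitwise form simply avoids the integer divide and modulo. A minimal host-side check of the equivalence, with no FlashInfer dependencies:

    #include <cstdio>

    int main() {
      // Compare the old and new index expressions for all 8 accumulator registers.
      for (unsigned reg_id = 0; reg_id < 8; ++reg_id) {
        const unsigned old_idx = (reg_id % 4) / 2;   // before PATCH 2/2
        const unsigned new_idx = (reg_id >> 1) & 1;  // after PATCH 2/2
        std::printf("reg_id=%u  old=%u  new=%u  %s\n", reg_id, old_idx, new_idx,
                    old_idx == new_idx ? "match" : "MISMATCH");
      }
      return 0;
    }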