diff --git a/include/flashinfer/attention/persistent.cuh b/include/flashinfer/attention/persistent.cuh
index 68a692960..fae896aa1 100644
--- a/include/flashinfer/attention/persistent.cuh
+++ b/include/flashinfer/attention/persistent.cuh
@@ -138,6 +138,37 @@ __device__ __forceinline__ void write_o_(float (*o_frag)[KTraits::NUM_MMA_D_VO][
   }
 }
 
+template <typename KTraits>
+__device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_VO][8],
+                                            typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2]) {
+  using AttentionVariant = typename KTraits::AttentionVariant;
+  if constexpr (AttentionVariant::use_softmax) {
+    float d_rcp[KTraits::NUM_MMA_Q][2];
+    // compute reciprocal of d
+#pragma unroll
+    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
+#pragma unroll
+      for (uint32_t j = 0; j < 2; ++j) {
+        d_rcp[mma_q][j] = (m[mma_q][j] != typename KTraits::DTypeQKAccum(-math::inf))
+                              ? math::ptx_rcp(d[mma_q][j])
+                              : 0.f;
+      }
+    }
+
+#pragma unroll
+    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
+#pragma unroll
+      for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
+#pragma unroll
+        for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
+          o_frag[mma_q][mma_d][reg_id] =
+              o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id >> 1) & 1];
+        }
+      }
+    }
+  }
+}
+
 template <typename KTraits_>
 struct BlockBatchPagedAttentionPersistent {
   using KTraits = KTraits_;
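
Note (illustration only, not part of the patch): below is a minimal scalar C++ sketch of the per-row math that normalize_d performs, under assumed, hypothetical names: num_rows and head_dim describe a flattened [num_rows, head_dim] output buffer standing in for the per-thread register fragments. Each softmax row of the output is scaled by 1/d[row]; rows whose running max m[row] is still -inf (no valid key contributed) are zeroed instead of divided, mirroring the d_rcp guard in the patch. The (reg_id >> 1) & 1 index in the kernel appears to select which of the fragment's two rows a given accumulator register belongs to in the MMA layout, so each register picks up the reciprocal of its own row's denominator.

    // Scalar reference sketch (assumption-based) of the normalization step:
    // o[row, :] *= 1/d[row], except rows with m[row] == -inf, which become 0.
    #include <cmath>
    #include <cstdint>

    inline void normalize_reference(float* o, const float* m, const float* d,
                                    uint32_t num_rows, uint32_t head_dim) {
      for (uint32_t row = 0; row < num_rows; ++row) {
        // A row that never saw a valid key keeps m at -inf; define its output as 0
        // rather than dividing by a zero (or garbage) denominator.
        float scale = (m[row] == -INFINITY) ? 0.f : 1.f / d[row];
        for (uint32_t i = 0; i < head_dim; ++i) {
          o[row * head_dim + i] *= scale;
        }
      }
    }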