diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 818d34ccca..0164bbc109 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -152,6 +152,7 @@ void trtllm_paged_attention_launcher( size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes + // semaphores live in the first 8MB of the workspace buffer; layout: counter | scratch runner_params.multiCtasKvScratchPtr = reinterpret_cast( static_cast(workspace_buffer) + num_semaphores * sizeof(uint32_t)); runner_params.multiCtasKvCounterPtr = reinterpret_cast(workspace_buffer); @@ -380,13 +381,13 @@ void trtllm_ragged_attention_launcher( size_t max_num_qo_heads = 256; size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes + // semaphores live in the first 8MB of the workspace buffer; layout: counter | softmax | scratch + runner_params.multiCtasKvCounterPtr = reinterpret_cast(workspace_buffer); + runner_params.softmaxStatsPtr = reinterpret_cast(static_cast(workspace_buffer) + + num_semaphores * sizeof(uint32_t)); runner_params.multiCtasKvScratchPtr = reinterpret_cast( static_cast(workspace_buffer) + num_semaphores * sizeof(uint32_t) + sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ); - runner_params.multiCtasKvCounterPtr = - reinterpret_cast(static_cast(workspace_buffer) + - sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ); - runner_params.softmaxStatsPtr = reinterpret_cast(workspace_buffer); auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params); if (!foundKernels) {