Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions csrc/trtllm_fmha_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ void trtllm_paged_attention_launcher(
size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB
size_t num_semaphores =
round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes
// semaphores be at the first 8MB of workspace buffer: counter | scratch
runner_params.multiCtasKvScratchPtr = reinterpret_cast<void*>(
static_cast<char*>(workspace_buffer) + num_semaphores * sizeof(uint32_t));
runner_params.multiCtasKvCounterPtr = reinterpret_cast<int32_t*>(workspace_buffer);
Expand Down Expand Up @@ -380,13 +381,13 @@ void trtllm_ragged_attention_launcher(
size_t max_num_qo_heads = 256;
size_t num_semaphores =
round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes
// semaphores be at the first 8MB of workspace buffer: counter | softmax | scratch
runner_params.multiCtasKvCounterPtr = reinterpret_cast<int32_t*>(workspace_buffer);
runner_params.softmaxStatsPtr = reinterpret_cast<float2*>(static_cast<char*>(workspace_buffer) +
num_semaphores * sizeof(uint32_t));
runner_params.multiCtasKvScratchPtr = reinterpret_cast<void*>(
static_cast<char*>(workspace_buffer) + num_semaphores * sizeof(uint32_t) +
sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ);
runner_params.multiCtasKvCounterPtr =
reinterpret_cast<int32_t*>(static_cast<char*>(workspace_buffer) +
sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ);
runner_params.softmaxStatsPtr = reinterpret_cast<float2*>(workspace_buffer);

auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params);
if (!foundKernels) {
Expand Down