From 905faadeec701fad3469251964c46b23109dbb8a Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Tue, 26 Aug 2025 12:54:32 -0400 Subject: [PATCH 1/6] init --- csrc/trtllm_fmha_kernel_launcher.cu | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 818d34ccca..35c733a92b 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -36,8 +36,11 @@ enum class TllmPagedAttentionMode { ForGen, }; -#include -#include +// 128MB: max workspace buffer size for trtllm-gen fixed at python api side +constexpr size_t kMaxWorkspaceBufferSize = 128 * 1024 * 1024; + +// #include +// #include class TllmGenFmhaRunnerCache { public: @@ -152,8 +155,9 @@ void trtllm_paged_attention_launcher( size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes - runner_params.multiCtasKvScratchPtr = reinterpret_cast( - static_cast(workspace_buffer) + num_semaphores * sizeof(uint32_t)); + runner_params.multiCtasKvScratchPtr = + reinterpret_cast(static_cast(workspace_buffer) + kMaxWorkspaceBufferSize - + num_semaphores * sizeof(uint32_t)); runner_params.multiCtasKvCounterPtr = reinterpret_cast(workspace_buffer); } @@ -380,9 +384,10 @@ void trtllm_ragged_attention_launcher( size_t max_num_qo_heads = 256; size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes - runner_params.multiCtasKvScratchPtr = reinterpret_cast( - static_cast(workspace_buffer) + num_semaphores * sizeof(uint32_t) + - sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ); + // semaphores be at the last 8MB of 128 MB workspace buffer + runner_params.multiCtasKvScratchPtr = + reinterpret_cast(static_cast(workspace_buffer) + kMaxWorkspaceBufferSize - + num_semaphores * sizeof(uint32_t)); runner_params.multiCtasKvCounterPtr = reinterpret_cast(static_cast(workspace_buffer) + sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ); From 0141f58caa6dd5e014a52d016cee78fa3d155e80 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Tue, 26 Aug 2025 12:57:54 -0400 Subject: [PATCH 2/6] upd --- csrc/trtllm_fmha_kernel_launcher.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 35c733a92b..98a1f1a9f4 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -39,8 +39,8 @@ enum class TllmPagedAttentionMode { // 128MB: max workspace buffer size for trtllm-gen fixed at python api side constexpr size_t kMaxWorkspaceBufferSize = 128 * 1024 * 1024; -// #include -// #include +#include +#include class TllmGenFmhaRunnerCache { public: From 15ec60e714360a107fa4a5f0a1379fed6615f4e5 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Tue, 26 Aug 2025 16:04:45 -0400 Subject: [PATCH 3/6] upd --- csrc/trtllm_fmha_kernel_launcher.cu | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 98a1f1a9f4..48000c1fab 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -155,10 +155,11 @@ void trtllm_paged_attention_launcher( size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes - runner_params.multiCtasKvScratchPtr = - reinterpret_cast(static_cast(workspace_buffer) + kMaxWorkspaceBufferSize - - num_semaphores * sizeof(uint32_t)); - runner_params.multiCtasKvCounterPtr = reinterpret_cast(workspace_buffer); + // semaphores be at the last 8MB of 128 MB, workspace buffer: counter | scratch + runner_params.multiCtasKvScratchPtr = reinterpret_cast(workspace_buffer); + runner_params.multiCtasKvCounterPtr = + reinterpret_cast(static_cast(workspace_buffer) + kMaxWorkspaceBufferSize - + num_semaphores * sizeof(uint32_t)); } auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params); @@ -384,13 +385,13 @@ void trtllm_ragged_attention_launcher( size_t max_num_qo_heads = 256; size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes - // semaphores be at the last 8MB of 128 MB workspace buffer + // semaphores be at the last 8MB of 128 MB, workspace buffer: softmax | counter | scratch runner_params.multiCtasKvScratchPtr = - reinterpret_cast(static_cast(workspace_buffer) + kMaxWorkspaceBufferSize - - num_semaphores * sizeof(uint32_t)); + reinterpret_cast(static_cast(workspace_buffer) + + sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ); runner_params.multiCtasKvCounterPtr = - reinterpret_cast(static_cast(workspace_buffer) + - sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ); + reinterpret_cast(static_cast(workspace_buffer) + kMaxWorkspaceBufferSize - + num_semaphores * sizeof(uint32_t)); runner_params.softmaxStatsPtr = reinterpret_cast(workspace_buffer); auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params); From 76f27ae1791fdef5d92ee098dcff13881df14d9b Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Tue, 26 Aug 2025 16:06:50 -0400 Subject: [PATCH 4/6] upd comment --- csrc/trtllm_fmha_kernel_launcher.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 48000c1fab..8fc31c88b8 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -155,7 +155,7 @@ void trtllm_paged_attention_launcher( size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes - // semaphores be at the last 8MB of 128 MB, workspace buffer: counter | scratch + // semaphores be at the last 8MB of 128 MB, workspace buffer: scratch | counter runner_params.multiCtasKvScratchPtr = reinterpret_cast(workspace_buffer); runner_params.multiCtasKvCounterPtr = reinterpret_cast(static_cast(workspace_buffer) + kMaxWorkspaceBufferSize - @@ -385,7 +385,7 @@ void trtllm_ragged_attention_launcher( size_t max_num_qo_heads = 256; size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes - // semaphores be at the last 8MB of 128 MB, workspace buffer: softmax | counter | scratch + // semaphores be at the last 8MB of 128 MB, workspace buffer: softmax | scratch | counter runner_params.multiCtasKvScratchPtr = reinterpret_cast(static_cast(workspace_buffer) + sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ); From c8e4dbddd5defc8e12d98ce4f0a08ed4d2464f9b Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Tue, 26 Aug 2025 16:31:14 -0400 Subject: [PATCH 5/6] upd --- csrc/trtllm_fmha_kernel_launcher.cu | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 8fc31c88b8..d55efaa90e 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -155,11 +155,10 @@ void trtllm_paged_attention_launcher( size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes - // semaphores be at the last 8MB of 128 MB, workspace buffer: scratch | counter - runner_params.multiCtasKvScratchPtr = reinterpret_cast(workspace_buffer); - runner_params.multiCtasKvCounterPtr = - reinterpret_cast(static_cast(workspace_buffer) + kMaxWorkspaceBufferSize - - num_semaphores * sizeof(uint32_t)); + // semaphores be at the first 8MB of 128 MB, workspace buffer: counter | scratch + runner_params.multiCtasKvScratchPtr = reinterpret_cast( + static_cast(workspace_buffer) + num_semaphores * sizeof(uint32_t)); + runner_params.multiCtasKvCounterPtr = reinterpret_cast(workspace_buffer); } auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params); @@ -385,14 +384,13 @@ void trtllm_ragged_attention_launcher( size_t max_num_qo_heads = 256; size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes - // semaphores be at the last 8MB of 128 MB, workspace buffer: softmax | scratch | counter - runner_params.multiCtasKvScratchPtr = - reinterpret_cast(static_cast(workspace_buffer) + - sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ); - runner_params.multiCtasKvCounterPtr = - reinterpret_cast(static_cast(workspace_buffer) + kMaxWorkspaceBufferSize - - num_semaphores * sizeof(uint32_t)); - runner_params.softmaxStatsPtr = reinterpret_cast(workspace_buffer); + // semaphores be at the first 8MB of 128 MB, workspace buffer: counter | softmax | scratch + runner_params.multiCtasKvCounterPtr = reinterpret_cast(workspace_buffer); + runner_params.softmaxStatsPtr = reinterpret_cast(static_cast(workspace_buffer) + + num_semaphores * sizeof(uint32_t)); + runner_params.multiCtasKvScratchPtr = reinterpret_cast( + static_cast(workspace_buffer) + num_semaphores * sizeof(uint32_t) + + sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ); auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params); if (!foundKernels) { From 46bbe8841fbcdab7e1d140d2e0d4012baa9dabec Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Tue, 26 Aug 2025 16:32:23 -0400 Subject: [PATCH 6/6] cleanup --- csrc/trtllm_fmha_kernel_launcher.cu | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index d55efaa90e..0164bbc109 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -36,9 +36,6 @@ enum class TllmPagedAttentionMode { ForGen, }; -// 128MB: max workspace buffer size for trtllm-gen fixed at python api side -constexpr size_t kMaxWorkspaceBufferSize = 128 * 1024 * 1024; - #include #include @@ -155,7 +152,7 @@ void trtllm_paged_attention_launcher( size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes - // semaphores be at the first 8MB of 128 MB, workspace buffer: counter | scratch + // semaphores be at the first 8MB of workspace buffer: counter | scratch runner_params.multiCtasKvScratchPtr = reinterpret_cast( static_cast(workspace_buffer) + num_semaphores * sizeof(uint32_t)); runner_params.multiCtasKvCounterPtr = reinterpret_cast(workspace_buffer); @@ -384,7 +381,7 @@ void trtllm_ragged_attention_launcher( size_t max_num_qo_heads = 256; size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes - // semaphores be at the first 8MB of 128 MB, workspace buffer: counter | softmax | scratch + // semaphores be at the first 8MB of workspace buffer: counter | softmax | scratch runner_params.multiCtasKvCounterPtr = reinterpret_cast(workspace_buffer); runner_params.softmaxStatsPtr = reinterpret_cast(static_cast(workspace_buffer) + num_semaphores * sizeof(uint32_t));