
Commit 920c82a

Remove cudaMalloc/Free in GDN prefill kernel (#2415)
## 📌 Description

In the GDN prefill kernel, the line of code below causes a redundant cudaMalloc/cudaFree on every kernel execution, which hurts runtime performance. The workspace buffer it allocates is used for the TMA store output.

https://github.com/flashinfer-ai/flashinfer/blob/a49b45336e56e4615eae102cf29d5110293d9130/csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh#L132

This PR replaces it with a workspace buffer created by torch with the same size (number of SMs × 128 B), removing the redundant cudaMalloc/cudaFree calls.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Refactor**
  * Launchers and kernels now accept an external per-SM workspace buffer; internal workspace allocation removed.
  * Native launchers and Python prefill functions updated to accept, validate, and forward the workspace buffer.
  * Runtime checks added for the provided workspace; call sites updated to construct and pass a per-SM workspace where needed.
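For context, a minimal sketch of the allocation strategy described above, written against the public PyTorch API only. The actual implementation uses FlashInfer's cached-buffer helper shown in the `flashinfer/gdn_prefill.py` diff further down; the snippet below is illustrative, not the code added by this PR.

```python
import torch

device = torch.device("cuda")

# One 128 B TMA-store tensormap slot per SM, as described in this PR.
sm_count = torch.cuda.get_device_properties(device).multi_processor_count
workspace_buffer = torch.empty(sm_count * 128, dtype=torch.uint8, device=device)

# The buffer is then forwarded to the prefill launch instead of letting the
# kernel cudaMalloc/cudaFree its own scratch space on every call.
```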
1 parent c8d76d3 commit 920c82a

File tree

5 files changed (+49, -36 lines)


csrc/flat/prefill/prefill_kernel.hpp

Lines changed: 4 additions & 3 deletions
@@ -32,8 +32,9 @@ void launch_delta_rule_prefill_kernel(cudaStream_t stream, TO* output, TState* o
                                       TQKV const* q, TQKV const* k, TQKV const* v,
                                       TState const* input_state, float const* alpha,
                                       float const* beta, int64_t const* cu_seqlens,
-                                      int32_t num_seqs, int32_t num_q_heads, int32_t num_k_heads,
-                                      int32_t num_v_heads, int32_t num_o_heads, int32_t head_size,
-                                      int64_t total_seqlen, float scale, int32_t sm_count = 0);
+                                      uint8_t* workspace_buffer, int32_t num_seqs,
+                                      int32_t num_q_heads, int32_t num_k_heads, int32_t num_v_heads,
+                                      int32_t num_o_heads, int32_t head_size, int64_t total_seqlen,
+                                      float scale, int32_t sm_count = 0);
 
 }  // namespace flat

csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu

Lines changed: 15 additions & 13 deletions
@@ -27,19 +27,20 @@ void launch_delta_rule_prefill_kernel(cudaStream_t stream, TO* output, TState* o
                                       TQKV const* q, TQKV const* k, TQKV const* v,
                                       TState const* input_state, float const* alpha,
                                       float const* beta, int64_t const* cu_seqlens,
-                                      int32_t num_seqs, int32_t num_q_heads, int32_t num_k_heads,
-                                      int32_t num_v_heads, int32_t num_o_heads, int32_t head_size,
-                                      int64_t total_seqlen, float scale, int32_t sm_count) {
+                                      uint8_t* workspace_buffer, int32_t num_seqs,
+                                      int32_t num_q_heads, int32_t num_k_heads, int32_t num_v_heads,
+                                      int32_t num_o_heads, int32_t head_size, int64_t total_seqlen,
+                                      float scale, int32_t sm_count) {
   bool is_gva = num_v_heads > num_q_heads;
   bool needs_beta = beta != nullptr;
   bool needs_alpha = alpha != nullptr;
   bool init_state = input_state != nullptr;
 
-#define LAUNCH(is_gva, needs_beta, needs_alpha, init_state)                                     \
-  launch_delta_rule_prefill_kernel_gbai<is_gva, needs_beta, needs_alpha, init_state, ArchTag>(  \
-      stream, output, output_state, q, k, v, input_state, alpha, beta, cu_seqlens, num_seqs,    \
-      num_q_heads, num_k_heads, num_v_heads, num_o_heads, head_size, total_seqlen, scale,       \
-      sm_count);
+#define LAUNCH(is_gva, needs_beta, needs_alpha, init_state)                                        \
+  launch_delta_rule_prefill_kernel_gbai<is_gva, needs_beta, needs_alpha, init_state, ArchTag>(     \
+      stream, output, output_state, q, k, v, input_state, alpha, beta, cu_seqlens,                 \
+      workspace_buffer, num_seqs, num_q_heads, num_k_heads, num_v_heads, num_o_heads, head_size,   \
+      total_seqlen, scale, sm_count);
 
   if (init_state) {
     if (is_gva && needs_beta && needs_alpha) {
@@ -89,15 +90,16 @@ void launch_delta_rule_prefill_kernel(cudaStream_t stream, TO* output, TState* o
 template void launch_delta_rule_prefill_kernel<cutlass::arch::Sm90, half, half, float>(
     cudaStream_t stream, half* output, float* state, half const* q, half const* k, half const* v,
     float const* input_state, float const* alpha, float const* beta, int64_t const* cu_seqlens,
-    int32_t num_seqs, int32_t num_q_heads, int32_t num_k_heads, int32_t num_v_heads,
-    int32_t num_o_heads, int32_t head_size, int64_t total_seqlen, float scale, int32_t sm_count);
+    uint8_t* workspace_buffer, int32_t num_seqs, int32_t num_q_heads, int32_t num_k_heads,
+    int32_t num_v_heads, int32_t num_o_heads, int32_t head_size, int64_t total_seqlen, float scale,
+    int32_t sm_count);
 
 template void
 launch_delta_rule_prefill_kernel<cutlass::arch::Sm90, nv_bfloat16, nv_bfloat16, float>(
     cudaStream_t stream, nv_bfloat16* output, float* state, nv_bfloat16 const* q,
     nv_bfloat16 const* k, nv_bfloat16 const* v, float const* input_state, float const* alpha,
-    float const* beta, int64_t const* cu_seqlens, int32_t num_seqs, int32_t num_q_heads,
-    int32_t num_k_heads, int32_t num_v_heads, int32_t num_o_heads, int32_t head_size,
-    int64_t total_seqlen, float scale, int32_t sm_count);
+    float const* beta, int64_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs,
+    int32_t num_q_heads, int32_t num_k_heads, int32_t num_v_heads, int32_t num_o_heads,
+    int32_t head_size, int64_t total_seqlen, float scale, int32_t sm_count);
 
 }  // namespace flat

csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh

Lines changed: 7 additions & 12 deletions
@@ -31,14 +31,12 @@ using namespace cute;
 
 template <bool IsGVA, bool NeedsBeta, bool NeedsAlpha, bool InitStateFromInput, typename ArchTag,
           typename TO, typename TQKV, typename TState>
-void launch_delta_rule_prefill_kernel_gbai(cudaStream_t stream, TO* output, TState* output_state,
-                                           TQKV const* q, TQKV const* k, TQKV const* v,
-                                           TState const* input_state, float const* alpha,
-                                           float const* beta, int64_t const* cu_seqlens,
-                                           int32_t num_seqs, int32_t num_q_heads,
-                                           int32_t num_k_heads, int32_t num_v_heads,
-                                           int32_t num_o_heads, int32_t head_size,
-                                           int64_t total_seqlen, float scale, int32_t sm_count) {
+void launch_delta_rule_prefill_kernel_gbai(
+    cudaStream_t stream, TO* output, TState* output_state, TQKV const* q, TQKV const* k,
+    TQKV const* v, TState const* input_state, float const* alpha, float const* beta,
+    int64_t const* cu_seqlens, uint8_t* workspace_buffer, int32_t num_seqs, int32_t num_q_heads,
+    int32_t num_k_heads, int32_t num_v_heads, int32_t num_o_heads, int32_t head_size,
+    int64_t total_seqlen, float scale, int32_t sm_count) {
 #if defined(FLAT_SM90A_ENABLED)
   constexpr bool HopperSupported = true;
 #else
@@ -128,16 +126,13 @@ void launch_delta_rule_prefill_kernel_gbai(cudaStream_t stream, TO* output, TSta
       },  // clang-format on
       .hw_info = hw_info};
 
-  size_t workspace_size = op.get_workspace_size(arguments);
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
-
   cutlass::Status status;
   status = op.can_implement(arguments);
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("can_implement failed");
   }
 
-  status = op.initialize(arguments, workspace.get(), stream);
+  status = op.initialize(arguments, workspace_buffer, stream);
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("initialize failed");
   }

csrc/gdn_prefill_launcher.cu

Lines changed: 12 additions & 8 deletions
@@ -35,10 +35,10 @@ namespace flashinfer {
 
 void gdn_prefill_launcher(void* output, void* output_state, void* q, void* k, void* v,
                           void* input_state, void* alpha, void* beta, int64_t* cu_seqlens,
-                          int64_t num_seqs, int64_t num_q_heads, int64_t num_k_heads,
-                          int64_t num_v_heads, int64_t num_o_heads, int64_t head_size,
-                          int64_t packed_seq, float scale, int64_t sm_count, DLDataType dtype,
-                          cudaStream_t stream) {
+                          uint8_t* workspace_buffer, int64_t num_seqs, int64_t num_q_heads,
+                          int64_t num_k_heads, int64_t num_v_heads, int64_t num_o_heads,
+                          int64_t head_size, int64_t packed_seq, float scale, int64_t sm_count,
+                          DLDataType dtype, cudaStream_t stream) {
   DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(dtype, DType, [&] {
     int dev_id;
     cudaGetDevice(&dev_id);
@@ -51,8 +51,8 @@ void gdn_prefill_launcher(void* output, void* output_state, void* q, void* k, vo
         stream, static_cast<DType*>(output), static_cast<float*>(output_state),
         static_cast<DType const*>(q), static_cast<DType const*>(k), static_cast<DType const*>(v),
         static_cast<float const*>(input_state), static_cast<float const*>(alpha),
-        static_cast<float const*>(beta), cu_seqlens, num_seqs, num_q_heads, num_k_heads,
-        num_v_heads, num_o_heads, head_size, packed_seq, scale, sm_count);
+        static_cast<float const*>(beta), cu_seqlens, workspace_buffer, num_seqs, num_q_heads,
+        num_k_heads, num_v_heads, num_o_heads, head_size, packed_seq, scale, sm_count);
       return true;
     } else {
       std::ostringstream err_msg;
@@ -70,7 +70,8 @@ void gdn_prefill_launcher(void* output, void* output_state, void* q, void* k, vo
 
 void gdn_prefill(TensorView output, TensorView output_state, TensorView q, TensorView k,
                  TensorView v, TensorView cu_seqlens, Optional<TensorView> input_state,
-                 Optional<TensorView> alpha, Optional<TensorView> beta, double scale) {
+                 Optional<TensorView> alpha, Optional<TensorView> beta, double scale,
+                 TensorView workspace_buffer) {
   int64_t num_seqs = cu_seqlens.size(0) - 1;
   int64_t packed_seq = q.size(0);
   int64_t head_size = q.size(2);
@@ -109,13 +110,15 @@ void gdn_prefill(TensorView output, TensorView output_state, TensorView q, Tenso
   CHECK_INPUT(k);
   CHECK_INPUT(v);
   CHECK_INPUT(cu_seqlens);
+  CHECK_INPUT(workspace_buffer);
 
   TVM_FFI_ICHECK(output.dtype() == dl_float16 || output.dtype() == dl_bfloat16);
   TVM_FFI_ICHECK_EQ(output_state.dtype(), dl_float32);
   TVM_FFI_ICHECK_EQ(output.dtype(), q.dtype());
   TVM_FFI_ICHECK_EQ(output.dtype(), k.dtype());
   TVM_FFI_ICHECK_EQ(output.dtype(), v.dtype());
   TVM_FFI_ICHECK_EQ(cu_seqlens.dtype(), dl_int64);
+  TVM_FFI_ICHECK_EQ(workspace_buffer.dtype(), dl_uint8);
 
   TVM_FFI_ICHECK_EQ(packed_seq, k.size(0));
   TVM_FFI_ICHECK_EQ(packed_seq, v.size(0));
@@ -164,7 +167,8 @@ void gdn_prefill(TensorView output, TensorView output_state, TensorView q, Tenso
 
   gdn_prefill_launcher(output.data_ptr(), output_state.data_ptr(), q.data_ptr(), k.data_ptr(),
                        v.data_ptr(), input_state_ptr, alpha_ptr, beta_ptr,
-                       static_cast<int64_t*>(cu_seqlens.data_ptr()), num_seqs, num_q_heads,
+                       static_cast<int64_t*>(cu_seqlens.data_ptr()),
+                       static_cast<uint8_t*>(workspace_buffer.data_ptr()), num_seqs, num_q_heads,
                        num_k_heads, num_v_heads, num_o_heads, head_size, packed_seq,
                        static_cast<float>(scale), sm_count, q.dtype(), stream);
 }

flashinfer/gdn_prefill.py

Lines changed: 11 additions & 0 deletions
@@ -24,6 +24,8 @@
 from .utils import (
     register_custom_op,
     register_fake_op,
+    get_device_sm_count,
+    _get_cache_buf,
 )
 
 
@@ -45,6 +47,7 @@ def gdn_prefill(
     g: Optional[torch.Tensor],
     beta: Optional[torch.Tensor],
     scale: float,
+    workspace_buffer: torch.Tensor,
 ) -> None:
     module.gdn_prefill(
         output,
@@ -57,6 +60,7 @@ def gdn_prefill(
         g,
         beta,
         scale,
+        workspace_buffer,
     )
 
 @register_fake_op("flashinfer::gdn_prefill")
@@ -71,6 +75,7 @@ def _fake_gdn_prefill(
     g: Optional[torch.Tensor],
     beta: Optional[torch.Tensor],
     scale: float,
+    workspace_buffer: torch.Tensor,
 ) -> None:
     pass
 
@@ -183,6 +188,11 @@ def chunk_gated_delta_rule(
         device=q.device,
     )
 
+    # Prepare workspace buffer for TMA Store in kernel
+    # 128B tensormap for each SM on Hopper architecture
+    workspace_size = get_device_sm_count(q.device) * 128
+    workspace_buffer = _get_cache_buf("gdn_prefill_workspace", workspace_size, q.device)
+
     get_gdn_prefill_module().gdn_prefill(
         output,
         output_state,
@@ -194,6 +204,7 @@ def chunk_gated_delta_rule(
         g,
         beta,
         scale if scale is not None else 0.0,
+        workspace_buffer,
     )
 
     if output_final_state:
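The `_get_cache_buf` helper imported above is what keeps the workspace alive across calls, which is the mechanism that removes the per-call allocation. A rough sketch of that caching pattern follows; the names and details here are hypothetical, as the real utility lives in `flashinfer.utils` and may differ.

```python
from typing import Dict, Tuple

import torch

# Hypothetical re-implementation of the caching pattern used for the workspace.
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}


def get_cache_buf(name: str, bytes_needed: int, device: torch.device) -> torch.Tensor:
    key = (name, device)
    buf = _cache_buf.get(key)
    if buf is None or buf.numel() < bytes_needed:
        # Allocated once via the PyTorch caching allocator, so repeated
        # gdn_prefill calls do not pay for cudaMalloc/cudaFree each time.
        buf = torch.empty(bytes_needed, dtype=torch.uint8, device=device)
        _cache_buf[key] = buf
    return buf
```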

0 commit comments
