Commit c261e97

cudnn: Add native cudnn_decode for improved cudnn decode performance (#1283)
## 📌 Description

This PR integrates cudnn decode by calling the cudnn kernels directly instead of through the cubin path. It also enables nvidia-cudnn-frontend on all platforms.

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
1 parent 1b83168 commit c261e97
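For orientation, here is a minimal usage sketch of the entry point this commit reroutes. The function name `cudnn_batch_decode_with_kv_cache` and its keyword parameters come from the diff below; the re-export as `flashinfer.cudnn_batch_decode_with_kv_cache`, the shapes, the contiguous block-table layout, and the 128 MB workspace size are illustrative assumptions, not taken from this page:

```python
import math
import torch
import flashinfer  # assumes the function is re-exported at the top level

batch_size, num_qo_heads, num_kv_heads, head_dim = 8, 32, 8, 128
page_size, max_seq_len = 16, 1024
pages_per_seq = max_seq_len // page_size
num_pages = batch_size * pages_per_seq

device = torch.device("cuda")
# decode: one query token per sequence -> q is (batch, heads, head_dim)
q = torch.randn(batch_size, num_qo_heads, head_dim, device=device, dtype=torch.bfloat16)
# paged KV cache: (num_pages, kv_heads, page_size, head_dim)
k_cache = torch.randn(
    num_pages, num_kv_heads, page_size, head_dim, device=device, dtype=torch.bfloat16
)
v_cache = torch.randn_like(k_cache)
# toy layout: each sequence owns a contiguous run of pages
block_tables = torch.arange(num_pages, device=device, dtype=torch.int32).reshape(
    batch_size, pages_per_seq
)
actual_seq_lens_kv = torch.randint(
    1, max_seq_len + 1, (batch_size, 1, 1, 1), device=device, dtype=torch.int32
)
workspace_buffer = torch.empty(
    128 * 1024 * 1024, device=device, dtype=torch.uint8
)  # size is a generous guess

out = flashinfer.cudnn_batch_decode_with_kv_cache(
    q,
    k_cache,
    v_cache,
    scale=1.0 / math.sqrt(head_dim),
    workspace_buffer=workspace_buffer,
    max_sequence_kv=max_seq_len,
    actual_seq_lens_kv=actual_seq_lens_kv,
    block_tables=block_tables,
    is_cuda_graph_compatible=False,
)
print(out.shape)  # (batch_size, num_qo_heads, head_dim)
```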

File tree

3 files changed: +303 −27 lines

flashinfer/cudnn/decode.py

Lines changed: 289 additions & 19 deletions

```diff
@@ -1,10 +1,260 @@
 import functools
+from enum import Enum
 from typing import Optional
 
 import torch
 
 from ..jit import get_cudnn_fmha_gen_module
 
+try:
+    import cudnn
+
+    CUDNN_AVAILABLE = True
+except ImportError:
+    cudnn = None
+    CUDNN_AVAILABLE = False
+
+# Global cudnn handle. need to make it per device in future
+_cudnn_handle = None
+
+
+def _create_cudnn_handle(stream: torch.cuda.Stream):
+    global _cudnn_handle
+    if _cudnn_handle is None:
+        _cudnn_handle = cudnn.create_handle()
+    cudnn.set_stream(_cudnn_handle, stream.cuda_stream)
+    return _cudnn_handle
+
+
+# Tensor ids
+class UIDs(Enum):
+    RESERVED_INVALID_UID = 0
+
+    Q_UID = 1  # Query tensor
+    K_UID = 2  # Key cache tensor
+    V_UID = 3  # Value cache tensor
+
+    ACTUAL_SEQ_LENS_Q_UID = 100  # Actual sequence lengths for query tensor
+    ACTUAL_SEQ_LENS_KV_UID = 101  # Actual sequence lengths for key/value tensor
+
+    BLOCK_TABLES_UID = 200  # Block tables tensor
+    BLOCK_TABLES_K_UID = 201  # Block tables tensor for key
+    BLOCK_TABLES_V_UID = 202  # Block tables tensor for value
+
+    RAGGED_Q_UID = 50  # Ragged query tensor
+    RAGGED_O_UID = 51  # Ragged output tensor
+    RAGGED_STATS_UID = 52  # Ragged stats tensor
+
+    O_UID = 1000  # Output tensor
+    STATS_UID = 1001  # Stats tensor
+
+
+def _sdpa_decode_key_fn(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    scale: float,
+    *,
+    max_sequence_kv: int,
+    block_size: Optional[int] = 1,
+    actual_seq_lens_q: Optional[torch.Tensor] = None,
+    actual_seq_lens_kv: Optional[torch.Tensor] = None,
+    block_tables: Optional[torch.Tensor] = None,
+    batch_offsets_q: Optional[torch.Tensor] = None,
+    batch_offsets_o: Optional[torch.Tensor] = None,
+):
+    return (
+        "decode",
+        max_sequence_kv,
+        tuple(q.shape),
+        tuple(k_cache.shape),
+    )
+
+
+if CUDNN_AVAILABLE:
+
+    @cudnn.jit(heur_modes=[cudnn.heur_mode.A])
+    @cudnn.graph_cache(key_fn=_sdpa_decode_key_fn)
+    def _build_decode_graph(
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        scale: float,
+        *,
+        max_sequence_kv: int,
+        block_size: Optional[int] = 1,
+        actual_seq_lens_q: Optional[torch.Tensor] = None,
+        actual_seq_lens_kv: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
+        batch_offsets_q: Optional[torch.Tensor] = None,
+        batch_offsets_o: Optional[torch.Tensor] = None,
+    ):
+        handle = _create_cudnn_handle(torch.cuda.current_stream())
+
+        # WAR: override batch offsets for now, as it leads to a poor performance
+        batch_offsets_q = None
+        batch_offsets_o = None
+
+        with cudnn.graph(handle) as (g, _):
+            if q.dim() == 3:
+                s_qo = 1
+                b, h_qo, d_qk = q.shape[0], q.shape[1], q.shape[2]
+            elif q.dim() == 4:
+                b, h_qo, s_qo, d_qk = (
+                    q.shape[0],
+                    q.shape[1],
+                    q.shape[2],
+                    q.shape[3],
+                )
+            else:
+                raise ValueError(f"q must have 3 or 4 dimensions, got {q.dim()}")
+
+            assert s_qo == 1, "q must have a sequence length of 1"
+            assert k_cache.dim() == 4, "k_cache must have 4 dimensions"
+
+            h_kv = k_cache.shape[1]
+            s_kv = max_sequence_kv
+            d_vo = v_cache.shape[3]
+
+            cudnn_q = g.tensor(
+                name="q",
+                dim=(b, h_qo, s_qo, d_qk),
+                stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1),
+                data_type=cudnn.data_type.BFLOAT16,
+            )
+            if batch_offsets_q is not None:
+                ragged_q = g.tensor_like(batch_offsets_q)
+                ragged_q.set_uid(UIDs.RAGGED_Q_UID.value)
+                cudnn_q.set_ragged_offset(ragged_q)
+
+            cudnn_k_cache = g.tensor_like(k_cache)
+            cudnn_v_cache = g.tensor_like(v_cache)
+
+            cudnn_q.set_uid(UIDs.Q_UID.value)
+            cudnn_k_cache.set_uid(UIDs.K_UID.value)
+            cudnn_v_cache.set_uid(UIDs.V_UID.value)
+
+            if block_tables is not None:
+                nd_block_tables = block_tables.reshape(
+                    block_tables.shape[0], 1, block_tables.shape[1], 1
+                )
+                cudnn_k_block_tables = g.tensor_like(nd_block_tables)
+                cudnn_k_block_tables.set_uid(UIDs.BLOCK_TABLES_K_UID.value)
+
+                cudnn_v_block_tables = g.tensor_like(nd_block_tables)
+                cudnn_v_block_tables.set_uid(UIDs.BLOCK_TABLES_V_UID.value)
+
+            if actual_seq_lens_q is not None:
+                cudnn_actual_seq_lens_q = g.tensor_like(actual_seq_lens_q)
+                cudnn_actual_seq_lens_q.set_uid(UIDs.ACTUAL_SEQ_LENS_Q_UID.value)
+
+            if actual_seq_lens_kv is not None:
+                cudnn_actual_seq_lens_kv = g.tensor_like(actual_seq_lens_kv)
+                cudnn_actual_seq_lens_kv.set_uid(UIDs.ACTUAL_SEQ_LENS_KV_UID.value)
+                cudnn_actual_seq_lens_kv.set_is_pass_by_value(False)
+
+            padding_mask = actual_seq_lens_kv is not None
+
+            O, _ = g.sdpa(
+                name="sdpa",
+                q=cudnn_q,
+                k=cudnn_k_cache,
+                v=cudnn_v_cache,
+                seq_len_q=(
+                    cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None
+                ),
+                seq_len_kv=(
+                    cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None
+                ),
+                use_padding_mask=padding_mask,
+                is_inference=True,
+                attn_scale=scale,
+                paged_attention_k_table=cudnn_k_block_tables,
+                paged_attention_v_table=cudnn_v_block_tables,
+                paged_attention_max_seq_len_kv=max_sequence_kv,
+                compute_data_type=cudnn.data_type.FLOAT,
+            )
+
+            if batch_offsets_o is not None:
+                ragged_o = g.tensor_like(batch_offsets_o)
+                ragged_o.set_uid(UIDs.RAGGED_O_UID.value)
+                O.set_ragged_offset(ragged_o)
+
+            O.set_uid(UIDs.O_UID.value).set_output(True).set_dim(
+                [b, h_qo, s_qo, d_vo]
+            ).set_stride([d_vo * h_qo, d_vo, d_vo * h_qo, 1]).set_data_type(
+                cudnn.data_type.BFLOAT16
+            )
+
+            tensors_to_return = [cudnn_q, cudnn_k_cache, cudnn_v_cache, O]
+
+            if actual_seq_lens_q is not None:
+                tensors_to_return.append(cudnn_actual_seq_lens_q)
+            if actual_seq_lens_kv is not None:
+                tensors_to_return.append(cudnn_actual_seq_lens_kv)
+
+        return g, tensors_to_return
+
+
+def _batch_decode_with_kv_cache(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    scale: float,
+    workspace_buffer: torch.Tensor,
+    *,
+    max_sequence_kv: int,
+    actual_seq_lens_q: Optional[torch.Tensor] = None,
+    actual_seq_lens_kv: Optional[torch.Tensor] = None,
+    block_tables: Optional[torch.Tensor] = None,
+    block_size: Optional[int] = 1,
+    batch_offsets_q: Optional[torch.Tensor] = None,
+    batch_offsets_o: Optional[torch.Tensor] = None,
+    batch_offsets_k: Optional[torch.Tensor] = None,
+    batch_offsets_v: Optional[torch.Tensor] = None,
+    out: torch.Tensor,
+) -> torch.Tensor:
+
+    graph, tensors = _build_decode_graph(
+        q=q,
+        k_cache=k_cache,
+        v_cache=v_cache,
+        scale=scale,
+        max_sequence_kv=max_sequence_kv,
+        actual_seq_lens_q=actual_seq_lens_q,
+        actual_seq_lens_kv=actual_seq_lens_kv,
+        block_tables=block_tables,
+        block_size=block_size,
+        batch_offsets_q=batch_offsets_q if batch_offsets_q is not None else None,
+        batch_offsets_o=batch_offsets_q if batch_offsets_q is not None else None,
+    )
+
+    handle_ = _create_cudnn_handle(torch.cuda.current_stream())
+
+    var_map = {
+        UIDs.Q_UID.value: q,
+        UIDs.K_UID.value: k_cache,
+        UIDs.V_UID.value: v_cache,
+        UIDs.O_UID.value: out,
+    }
+    if actual_seq_lens_q is not None:
+        var_map[UIDs.ACTUAL_SEQ_LENS_Q_UID.value] = actual_seq_lens_q
+    if actual_seq_lens_kv is not None:
+        var_map[UIDs.ACTUAL_SEQ_LENS_KV_UID.value] = actual_seq_lens_kv
+
+    if batch_offsets_q is not None:
+        var_map[UIDs.RAGGED_Q_UID.value] = batch_offsets_q
+    if batch_offsets_o is not None:
+        var_map[UIDs.RAGGED_O_UID.value] = batch_offsets_o
+
+    if block_tables is not None:
+        var_map[UIDs.BLOCK_TABLES_K_UID.value] = block_tables
+        var_map[UIDs.BLOCK_TABLES_V_UID.value] = block_tables
+
+    graph.execute(var_map, workspace=workspace_buffer, handle=handle_)
+
+    return out
+
 
 def cudnn_batch_decode_with_kv_cache(
     q: torch.Tensor,
@@ -37,7 +287,6 @@ def cudnn_batch_decode_with_kv_cache(
         is_cuda_graph_compatible: Whether the decode operation is compatible with CUDA graph
         batch_offsets: Optional batch offsets tensor of shape (batch_size,) on GPU
         out: Optional pre-allocated output tensor
-        lse: Optional pre-allocated tensor for log-sum-exp values if return_lse is True else returns None
         batch_offsets_q: Optional batch offsets for query tensor of shape (batch_size,) on GPU
         batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU
        batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU
@@ -53,30 +302,51 @@ def cudnn_batch_decode_with_kv_cache(
     """
 
     bs = q.shape[0]
-    s_q = 1
     h_qo = q.shape[1]
     d_vo = v_cache.shape[3]
 
     if out is None:
         out = torch.empty(bs, h_qo, d_vo, device=q.device, dtype=q.dtype)
 
-    actual_seq_lens_kv_gpu = actual_seq_lens_kv.to(q.device, non_blocking=True)
+    if not CUDNN_AVAILABLE:
+        actual_seq_lens_kv_gpu = actual_seq_lens_kv.to(q.device, non_blocking=True)
 
-    run_func = get_cudnn_fmha_gen_module().decode
-    run_func(
-        max_sequence_kv,
-        q,
-        k_cache,
-        v_cache,
-        scale,
-        workspace_buffer,
-        actual_seq_lens_kv,
-        actual_seq_lens_kv_gpu,
-        block_tables,
-        out,
-        batch_offsets_q,
-        batch_offsets_o,
-        is_cuda_graph_compatible,
-    )
+        run_func = get_cudnn_fmha_gen_module().decode
+        run_func(
+            max_sequence_kv,
+            q,
+            k_cache,
+            v_cache,
+            scale,
+            workspace_buffer,
+            actual_seq_lens_kv,
+            actual_seq_lens_kv_gpu,
+            block_tables,
+            out,
+            batch_offsets_q,
+            batch_offsets_o,
+            is_cuda_graph_compatible,
+        )
+    else:
+        actual_seq_lens_q = torch.ones(
+            (bs, 1, 1, 1), device=q.device, dtype=torch.int32
+        )
+        block_size = k_cache.shape[2]
+
+        _batch_decode_with_kv_cache(
+            q=q,
+            k_cache=k_cache,
+            v_cache=v_cache,
+            scale=scale,
+            workspace_buffer=workspace_buffer,
+            max_sequence_kv=max_sequence_kv,
+            actual_seq_lens_q=actual_seq_lens_q,
+            actual_seq_lens_kv=actual_seq_lens_kv,
+            block_tables=block_tables,
+            batch_offsets_q=batch_offsets_q,
+            batch_offsets_o=batch_offsets_o,
+            block_size=block_size,
+            out=out,
+        )
 
     return out
```
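One detail worth noting in the new path: `_sdpa_decode_key_fn` keys the `cudnn.graph_cache` only on the op name, `max_sequence_kv`, and the q/k-cache shapes, so steady-state decode steps with stable shapes reuse a prebuilt graph instead of re-tracing. A small illustration of that key function (a sketch: it imports a private helper by the path shown in this diff, and the tensors exist only to supply shapes):

```python
import torch
from flashinfer.cudnn.decode import _sdpa_decode_key_fn  # private helper per this diff

q = torch.empty(8, 32, 128, dtype=torch.bfloat16)
k_cache = torch.empty(512, 8, 16, 128, dtype=torch.bfloat16)
v_cache = torch.empty_like(k_cache)

key_a = _sdpa_decode_key_fn(q, k_cache, v_cache, 1.0, max_sequence_kv=1024)
key_b = _sdpa_decode_key_fn(q.clone(), k_cache, v_cache, 0.5, max_sequence_kv=1024)
# tensor contents (and even the attn scale) do not affect the key;
# only the shapes and max_sequence_kv do
assert key_a == key_b
```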

setup.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -62,7 +62,7 @@ def generate_build_meta(aot_build_meta: dict) -> None:
     "einops",
     "nvidia-nvshmem-cu12",
     "nvidia-cudnn-cu12",
-    'nvidia-cudnn-frontend; platform_machine == "x86_64" or platform_machine == "AMD64"',
+    "nvidia-cudnn-frontend>=1.13.0",
 ]
 generate_build_meta({})
```
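With the platform marker dropped and the version pinned to >=1.13.0, the cudnn python frontend is installed on every platform, and the guarded import in `decode.py` decides the dispatch at runtime. A quick post-install sanity check (a sketch; the version attribute is accessed defensively via `getattr` since it is not guaranteed by this diff):

```python
# check which decode path flashinfer will take on this machine
try:
    import cudnn

    print("cudnn frontend available:", getattr(cudnn, "__version__", "unknown"))
    # the native python-graph decode path will be used
except ImportError:
    print("cudnn frontend missing; falling back to the cubin decode path")
```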

tests/test_cudnn_decode.py

Lines changed: 13 additions & 7 deletions

```diff
@@ -7,12 +7,12 @@
 import flashinfer
 
 
-@pytest.mark.parametrize("batch_size", [4, 8, 17, 64])
-@pytest.mark.parametrize("s_kv", [8, 40, 1024])
-@pytest.mark.parametrize("page_size", [1, 8])
-@pytest.mark.parametrize("num_kv_heads", [4])
-@pytest.mark.parametrize("num_qo_heads", [4, 32])
-@pytest.mark.parametrize("is_cuda_graph_compatible", [False, True])
+@pytest.mark.parametrize("batch_size", [8, 16, 32])
+@pytest.mark.parametrize("s_kv", [512, 8192])
+@pytest.mark.parametrize("page_size", [16])
+@pytest.mark.parametrize("num_kv_heads", [8])
+@pytest.mark.parametrize("num_qo_heads", [32])
+@pytest.mark.parametrize("is_cuda_graph_compatible", [True, False])
 def test_cudnn_decode(
     batch_size,
     s_kv,
@@ -79,7 +79,11 @@ def test_cudnn_decode(
 
     # Actual sequence lengths (should be randomized across batches)
     actual_seq_lens_kv = torch.randint(
-        0, s_kv, (batch_size, 1, 1, 1), dtype=torch.int32
+        0, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
+    )
+
+    ragged_q = torch.arange(0, batch_size + 1, device=device) * (
+        num_qo_heads * head_dim
     )
 
     workspace_buffer_size = math.ceil(
@@ -106,6 +110,8 @@ def test_cudnn_decode(
         actual_seq_lens_kv=actual_seq_lens_kv,
         block_tables=block_tables,
         is_cuda_graph_compatible=is_cuda_graph_compatible,
+        batch_offsets_q=ragged_q,
+        batch_offsets_o=ragged_q,
     )
 
     actual_seq_lens_kv_device = actual_seq_lens_kv.to(device)
```
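The new `ragged_q` tensor in the test is the prefix-sum layout that ragged (packed) query/output offsets expect: with one query token per sequence, entry i marks where sequence i's data begins in the flattened buffer. A small standalone illustration of the same arithmetic:

```python
import torch

batch_size, num_qo_heads, head_dim = 4, 32, 128
# each sequence contributes exactly num_qo_heads * head_dim elements,
# so the offsets form a uniform prefix sum with batch_size + 1 entries
ragged_q = torch.arange(0, batch_size + 1) * (num_qo_heads * head_dim)
print(ragged_q)  # tensor([    0,  4096,  8192, 12288, 16384])
```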
