diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py
index 893bba2db..5fddf86e7 100644
--- a/benchmarks/routines/attention.py
+++ b/benchmarks/routines/attention.py
@@ -798,6 +798,7 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
             .int()
             .to(device)
         )
+        # Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
         kv_indptr = (
             torch.cat(
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index ea140fc90..1bde04489 100644
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -22,9 +22,6 @@
 import torch
 
-from .cudnn import (
-    cudnn_batch_prefill_with_kv_cache as cudnn_batch_prefill_with_kv_cache,
-)
 from .jit import (
     gen_batch_prefill_module,
     gen_customize_batch_prefill_module,
@@ -36,6 +33,7 @@
     setup_metainfo_loader,
     trtllm_gen_fmha_module,
 )
+from .cudnn import cudnn_batch_prefill_with_kv_cache
 from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens
 from .quantization import packbits, segment_packbits
 from .utils import (
@@ -1368,7 +1366,7 @@ def __init__(
             mask will be used in attention computation.
 
         backend : str
-            The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``.
+            The implementation backend, could be ``auto``, ``fa2``, ``fa3`` or ``cudnn``. Defaults to ``auto``.
             If set to ``auto``, the wrapper will automatically choose the backend based on the
             device architecture and kernel availability.
 
@@ -1392,6 +1390,9 @@ def __init__(
         self._jit_module = None
 
         self._kv_layout = kv_layout
+        if backend == "cudnn":
+            assert kv_layout == "NHD", "CUDNN backend only supports NHD layout"
+
         self._float_workspace_buffer = float_workspace_buffer
         self.device = float_workspace_buffer.device
         self._vector_sparse_indptr_buffer: Optional[torch.Tensor] = None
@@ -1456,6 +1457,11 @@ def __init__(
         self._mask_indptr_buf = mask_indptr_buf
         self._max_total_num_rows = None
         self._backend = backend
+        self._plan_info = None
+        self._cached_module = None
+        self._seq_lens_kv = None
+        self._seq_lens_q = None
+        self._block_tables = None
 
     @property
     def is_cuda_graph_enabled(self) -> bool:
@@ -1514,7 +1520,10 @@ def plan(
         token_pos_in_items_len: int = 0,
         max_item_len_ptr: Optional[torch.Tensor] = None,
         seq_lens: Optional[torch.Tensor] = None,
+        seq_lens_q: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        max_token_per_sequence: Optional[int] = None,
+        max_sequence_kv: Optional[int] = None,
     ) -> None:
         r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.
 
@@ -1605,8 +1614,15 @@ def plan(
             a uint16 vector contains the max token length of all items for each prompt
         seq_lens: Optional[torch.Tensor]
             A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``.
+        seq_lens_q: Optional[torch.Tensor]
+            A uint32 1D tensor indicating the q sequence length of each prompt. shape: ``[batch_size]``.
+            If not provided, it will be set to the same value as ``seq_lens``.
         block_tables: Optional[torch.Tensor]
             A uint32 2D tensor indicating the block table of each prompt. shape: ``[batch_size, max_num_blocks_per_seq]``.
+        max_token_per_sequence: Optional[int]
+            Required for the cudnn backend. Scalar maximum number of query tokens of any sequence.
+        max_sequence_kv: Optional[int]
+            Required for the cudnn backend. Scalar maximum kv sequence length of any sequence in the kv cache.
 
         Note
         ----
@@ -1655,22 +1671,28 @@ def plan(
         self._max_item_len_ptr = max_item_len_ptr
 
         # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
-        qo_indptr_host = qo_indptr.to("cpu")
-        paged_kv_indptr_host = paged_kv_indptr.to("cpu")
-        paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
-        if seq_lens is None:
-            kv_lens_arr_host = get_seq_lens(
-                paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
-            )
+        if max_token_per_sequence is not None:
+            self._max_q_len = max_token_per_sequence
         else:
-            kv_lens_arr_host = seq_lens.cpu()
-        self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
-            kv_lens_arr_host, non_blocking=non_blocking
-        )
-        self._max_q_len = max(qo_indptr_host).item()
-        self._max_kv_len = max(kv_lens_arr_host).item()
+            qo_indptr_host = qo_indptr.to("cpu")
+            self._max_q_len = max(qo_indptr_host).item()
+            total_num_rows = qo_indptr_host[-1]
 
-        total_num_rows = qo_indptr_host[-1]
+        if max_sequence_kv is not None:
+            self._max_kv_len = max_sequence_kv
+        else:
+            paged_kv_indptr_host = paged_kv_indptr.to("cpu")
+            paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
+            if seq_lens is None:
+                kv_lens_arr_host = get_seq_lens(
+                    paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
+                )
+            else:
+                kv_lens_arr_host = seq_lens.cpu().flatten()
+            self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
+                kv_lens_arr_host, non_blocking=non_blocking
+            )
+            self._max_kv_len = max(kv_lens_arr_host).item()
 
         if self.is_cuda_graph_enabled:
             if self._max_total_num_rows is None:
@@ -1759,23 +1781,23 @@ def plan(
                 q_data_type,
                 kv_data_type,
             )
+        if self._backend != "cudnn":
+            get_module_args = (
+                q_data_type,
+                kv_data_type,
+                q_data_type,
+                paged_kv_indptr.dtype,
+                head_dim_qk,
+                head_dim_vo,
+                PosEncodingMode[pos_encoding_mode].value,
+                window_left >= 0,  # use_sliding_window
+                logits_soft_cap > 0,  # use_logits_soft_cap
+                use_fp16_qk_reduction,
+            )
 
-        get_module_args = (
-            q_data_type,
-            kv_data_type,
-            q_data_type,
-            paged_kv_indptr.dtype,
-            head_dim_qk,
-            head_dim_vo,
-            PosEncodingMode[pos_encoding_mode].value,
-            window_left >= 0,  # use_sliding_window
-            logits_soft_cap > 0,  # use_logits_soft_cap
-            use_fp16_qk_reduction,
-        )
-
-        self._cached_module = get_batch_prefill_module(
-            self._backend, *get_module_args
-        )
+            self._cached_module = get_batch_prefill_module(
+                self._backend, *get_module_args
+            )
 
         if self._backend == "fa3" or self._backend == "trtllm-gen":
             if page_size != 1:
@@ -1793,7 +1815,7 @@ def plan(
                 ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking)
                 paged_kv_indptr_host = vector_sparse_indptr_host
 
-        self._block_tables: Optional[torch.Tensor] = block_tables
+        self._block_tables = block_tables
         if self._backend == "trtllm-gen":
             assert self._kv_layout == "HND"
             assert logits_soft_cap == 0.0
@@ -1811,28 +1833,32 @@ def plan(
             block_id = paged_kv_indptr_host[0]
             for i in range(batch_size):
                 num_blocks_needed = blocks_per_seq[i]
+                assert self._block_tables is not None, (
+                    "block_tables is not initialized"
+                )
                 self._block_tables[i, :num_blocks_needed] = paged_kv_indices[
                     block_id : block_id + num_blocks_needed
                 ]
                 block_id += num_blocks_needed
 
-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            paged_kv_indptr_host,
-            kv_lens_arr_host,
-            self._max_total_num_rows or total_num_rows,
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            head_dim_qk,
-            head_dim_vo,
-            causal,
-        )
+        if self._cached_module is not None:
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                paged_kv_indptr_host,
+                kv_lens_arr_host,
+                self._max_total_num_rows or total_num_rows,
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                head_dim_qk,
+                head_dim_vo,
+                causal,
+            )
 
         self._causal = causal
         self._pos_encoding_mode = pos_encoding_mode
@@ -1842,6 +1868,8 @@ def plan(
         self._sm_scale = sm_scale
         self._rope_scale = rope_scale
         self._rope_theta = rope_theta
+        self._seq_lens_kv = seq_lens
+        self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens
 
     begin_forward = plan
 
@@ -2042,62 +2070,90 @@ def run(
         sparse_indices = self._paged_kv_indices_buf
         sparse_indptr = self._paged_kv_indptr_buf
 
-        run_args = [
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._plan_info,
-            q,
-            k_cache,
-            v_cache,
-            self._qo_indptr_buf,
-            sparse_indptr,
-            sparse_indices,
-            self._paged_kv_last_page_len_buf,
-            out,
-            lse,
-            mask_mode,
-            TensorLayout[self._kv_layout].value,
-            window_left,
-            enable_pdl,
-        ]
-        if self._jit_module is not None:
-            run_args.extend(list(args))
+        if self._backend == "cudnn":
+            if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1:
+                self._seq_lens_q = self._seq_lens_q.reshape(self._batch_size, 1, 1, 1)
+
+            if self._seq_lens_kv is not None and self._seq_lens_kv.dim() == 1:
+                self._seq_lens_kv = self._seq_lens_kv.reshape(self._batch_size, 1, 1, 1)
+
+            cudnn_batch_prefill_with_kv_cache(
+                q,
+                k_cache,  # Need to be changed
+                v_cache,  # Need to be changed
+                self._sm_scale,
+                self._float_workspace_buffer,
+                actual_seq_lens_q=self._seq_lens_q,
+                actual_seq_lens_kv=self._seq_lens_kv,
+                max_token_per_sequence=self._max_q_len,
+                max_sequence_kv=self._max_kv_len,
+                block_tables=self._block_tables,
+                causal=self._causal,
+                return_lse=return_lse,
+                batch_offsets_q=self._qo_indptr_buf,
+                batch_offsets_o=self._qo_indptr_buf,
+                out=out,
+                lse=lse,
+            )
         else:
-            run_args += [
-                self._custom_mask_buf,
-                self._mask_indptr_buf,
-                _get_cache_alibi_slopes_buf(q.shape[1], q.device),
-                self._prefix_len_ptr,
-                self._token_pos_in_items_ptr,
-                self._max_item_len_ptr,
-                logits_soft_cap,
-                sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
-                rope_scale,
-                rope_theta,
-                self._token_pos_in_items_len,
-                self._num_qo_heads,
-                self._num_kv_heads,
-                self._block_tables,
-                self._kv_lens_buffer,
-                page_size,
-                self._max_q_len,
-                self._max_kv_len,
-                self._batch_size,
+            assert self._plan_info is not None, "plan info is not initialized"
+            run_args = [
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k_cache,
+                v_cache,
                 self._qo_indptr_buf,
-                self._vector_sparse_indptr_buffer,
-                sinks,
+                sparse_indptr,
+                sparse_indices,
+                self._paged_kv_last_page_len_buf,
+                out,
+                lse,
+                mask_mode,
+                TensorLayout[self._kv_layout].value,
+                window_left,
+                enable_pdl,
             ]
-
-        self._cached_module.paged_run(*run_args)
-        if v_scale is not None:
-            # TODO(Zihao): fused into kernel
-            if is_float8(out):
-                out = (out.to(torch.float32) * v_scale).to(out.dtype)
+            if self._jit_module is not None:
+                run_args.extend(list(args))
             else:
-                out *= v_scale
+                run_args += [
+                    self._custom_mask_buf,
+                    self._mask_indptr_buf,
+                    _get_cache_alibi_slopes_buf(q.shape[1], q.device),
+                    self._prefix_len_ptr,
+                    self._token_pos_in_items_ptr,
+                    self._max_item_len_ptr,
+                    logits_soft_cap,
+                    sm_scale,
+                    None,  # scale_q, not supported yet
+                    None,  # scale_k
+                    None,  # scale_v
+                    rope_scale,
+                    rope_theta,
+                    self._token_pos_in_items_len,
+                    self._num_qo_heads,
+                    self._num_kv_heads,
+                    self._block_tables,
+                    self._kv_lens_buffer,
+                    page_size,
+                    self._max_q_len,
+                    self._max_kv_len,
+                    self._batch_size,
+                    self._qo_indptr_buf,
+                    self._vector_sparse_indptr_buffer,
+                    sinks,
+                ]
+
+            assert self._cached_module is not None, "cached module is not initialized"
+            self._cached_module.paged_run(*run_args)
+            if v_scale is not None:
+                # TODO(Zihao): fused into kernel
+                if is_float8(out):
+                    out = (out.to(torch.float32) * v_scale).to(out.dtype)
+                else:
+                    out *= v_scale
         return (out, lse) if return_lse else out
 
     run_return_lse = functools.partialmethod(run, return_lse=True)
@@ -2351,6 +2407,7 @@ def __init__(
         self._mask_indptr_buf = mask_indptr_buf
         self._max_total_num_rows = None
         self._backend = backend
+        self._cached_module = None
 
     @property
     def is_cuda_graph_enabled(self) -> bool:
@@ -2621,6 +2678,7 @@ def plan(
             )
             self._max_qo_len = torch.max(qo_indptr[1:] - qo_indptr[:-1]).item()
         else:
+            assert self._cached_module is not None, "cached module is not initialized"
             self._plan_info = self._cached_module.plan(
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
@@ -2845,6 +2903,7 @@ def run(
                 self._token_pos_in_items_len,
             ]
 
+        assert self._cached_module is not None, "cached module is not initialized"
         self._cached_module.ragged_run(*run_args)
 
         return (out, lse) if return_lse else out
diff --git a/tests/test_cudnn_prefill.py b/tests/test_cudnn_prefill.py
index 6b54cd275..a4db63409 100644
--- a/tests/test_cudnn_prefill.py
+++ b/tests/test_cudnn_prefill.py
@@ -80,46 +80,6 @@ def test_cudnn_prefill(
         (2 * page_size * num_kv_heads * head_dim, head_dim, num_kv_heads * head_dim, 1),
     )
 
-    # Now initialize the page tables
-    block_tables = torch.tensor(
-        [
-            [k + i * num_pages_per_seq for k in range(num_pages_per_seq)]
-            for i in range(batch_size)
-        ],
-        dtype=torch.int,
-        device=device,
-    )
-
-    # Initialize scale
-    scale = float(1.0 / (head_dim**0.5))
-
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
-
-    output, lse = flashinfer.prefill.cudnn_batch_prefill_with_kv_cache(
-        q,
-        k_cache,
-        v_cache,
-        scale,
-        workspace_buffer,
-        max_token_per_sequence=s_qo,
-        max_sequence_kv=s_kv,
-        actual_seq_lens_q=actual_seq_lens_q,
-        actual_seq_lens_kv=actual_seq_lens_kv,
-        block_tables=block_tables,
-        causal=causal,
-        return_lse=return_lse,
-        is_cuda_graph_compatible=is_cuda_graph_compatible,
-        batch_offsets_q=q_indptr,
-        batch_offsets_o=q_indptr,
-    )
-
-    qo_indptr = torch.cat(
-        [
-            torch.tensor([0], device=device),
-            torch.cumsum(actual_seq_lens_q.view(-1), dim=0),
-        ]
-    ).int()
-
     kv_indptr = torch.cat(
         [
             torch.tensor([0], device=device),
@@ -148,6 +108,53 @@
         actual_seq_lens_kv.flatten() % page_size,
     ).int()
 
+    # Now initialize the page tables
+    block_tables = torch.tensor(
+        [
+            [k + i * num_pages_per_seq for k in range(num_pages_per_seq)]
+            for i in range(batch_size)
+        ],
+        dtype=torch.int,
+        device=device,
+    )
+
+    # Initialize scale
+    scale = float(1.0 / (head_dim**0.5))
+
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
+
+    wrapper_cudnn = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        workspace_buffer, "NHD", backend="cudnn"
+    )
+    wrapper_cudnn.plan(
+        q_indptr,
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        pos_encoding_mode="NONE",
+        causal=causal,
+        q_data_type=torch.bfloat16,
+        seq_lens=actual_seq_lens_kv,
+        seq_lens_q=actual_seq_lens_q,
+        sm_scale=scale,
+        max_token_per_sequence=s_qo,
+        max_sequence_kv=s_kv,
+        block_tables=block_tables,
+    )
+
+    output = wrapper_cudnn.run(q, (k_cache, v_cache))
+
+    qo_indptr = torch.cat(
+        [
+            torch.tensor([0], device=device),
+            torch.cumsum(actual_seq_lens_q.view(-1), dim=0),
+        ]
+    ).int()
+
     # Workspace buffer
     workspace_buffer_ref = torch.empty(
         128 * 1024 * 1024, dtype=torch.int8, device=device
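
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the ``cudnn`` backend wired up above could be exercised through ``BatchPrefillWithPagedKVCacheWrapper``, mirroring the updated test. The problem sizes, contiguous page layout, and random tensor contents are illustrative assumptions, and it assumes a CUDA device with a cuDNN-enabled FlashInfer build.

import torch
import flashinfer

# Illustrative problem sizes (assumptions for this sketch only).
batch_size, page_size, num_pages_per_seq = 2, 16, 4
num_qo_heads, num_kv_heads, head_dim = 8, 8, 128
s_qo = s_kv = page_size * num_pages_per_seq
device = "cuda:0"

# Every sequence is assumed to be full length here.
actual_seq_lens_q = torch.full((batch_size,), s_qo, dtype=torch.int32, device=device)
actual_seq_lens_kv = torch.full((batch_size,), s_kv, dtype=torch.int32, device=device)

# Packed query over all sequences; paged KV cache in NHD layout.
q = torch.randn(batch_size * s_qo, num_qo_heads, head_dim, dtype=torch.bfloat16, device=device)
total_pages = batch_size * num_pages_per_seq
k_cache = torch.randn(total_pages, page_size, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device)
v_cache = torch.randn_like(k_cache)

# Contiguous pages per sequence, so the indptr/indices/block-table bookkeeping is trivial.
q_indptr = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * s_qo
kv_indptr = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * num_pages_per_seq
kv_indices = torch.arange(total_pages, device=device, dtype=torch.int32)
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device=device)
block_tables = kv_indices.reshape(batch_size, num_pages_per_seq)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD", backend="cudnn")
wrapper.plan(
    q_indptr,
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    causal=True,
    q_data_type=torch.bfloat16,
    seq_lens=actual_seq_lens_kv,
    seq_lens_q=actual_seq_lens_q,
    sm_scale=1.0 / (head_dim**0.5),
    max_token_per_sequence=s_qo,
    max_sequence_kv=s_kv,
    block_tables=block_tables,
)
out = wrapper.run(q, (k_cache, v_cache))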