diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index ed988f8fe..e9aa979d8 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -1282,7 +1282,7 @@ def run(
             self._block_tables,
             self._kv_lens_buffer,
             page_size,
-            self._max_kv_len,
+            # self._max_kv_len,
             sinks,
         ]
@@ -1809,7 +1809,7 @@ def _paged_run(
         workspace_buffer: torch.Tensor,
         block_tables: torch.Tensor,
         seq_lens: torch.Tensor,
-        max_seq_len: int,
+        # max_seq_len: int,
         bmm1_scale: float,  # todo(Yingyi): add dynamic scale tensor later
         bmm2_scale: float,
         window_left: int = -1,
@@ -1831,7 +1831,7 @@ def _paged_run(
             workspace_buffer,
             block_tables,
             seq_lens,
-            max_seq_len,
+            block_tables.shape[1] * k_cache.shape[3],  # max_seq_len
             bmm1_scale,
             bmm2_scale,
             -1,  # o_sf_scale
@@ -1898,7 +1898,7 @@ def paged_run(
         block_tables: Optional[torch.Tensor] = None,
         kv_lens_buffer: Optional[torch.Tensor] = None,
         page_size: Optional[int] = None,
-        max_kv_len: Optional[int] = None,
+        # max_kv_len: Optional[int] = None,
         sinks: Optional[torch.Tensor] = None,
     ) -> None:
         assert maybe_lse is None
@@ -1908,7 +1908,7 @@ def paged_run(
         assert block_tables is not None
         assert kv_lens_buffer is not None
         assert page_size is not None
-        assert max_kv_len is not None
+        # assert max_kv_len is not None
         o = module._paged_run(
             q.contiguous(),  # NOTE(Siyuan): without contiguous, the result is incorrect
             paged_k_cache,
@@ -1916,7 +1916,7 @@ def paged_run(
             int_workspace_buffer,
             block_tables,
             kv_lens_buffer,
-            max_kv_len,
+            # max_kv_len,
             sm_scale,
             1.0,  # NOTE(Siyuan): update this to expose bmm2 scale
             window_left,
@@ -1980,7 +1980,7 @@ def trtllm_batch_decode_with_kv_cache(
     workspace_buffer: torch.Tensor,
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
-    max_seq_len: int,
+    # max_seq_len: int,
     bmm1_scale: float,
     bmm2_scale: float,  # todo(Yingyi): add dynamic scale tensor later
     window_left: int = -1,
@@ -2009,9 +2009,6 @@ def trtllm_batch_decode_with_kv_cache(
     seq_lens : torch.Tensor
         A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``
-    max_seq_len : int
-        max sequence length for kv_cache
-
     bmm1_scale : float
         fused scale for bmm1 input.
@@ -2117,7 +2114,7 @@ def trtllm_batch_decode_with_kv_cache(
         workspace_buffer,
         block_tables,
         seq_lens,
-        max_seq_len,
+        block_tables.shape[1] * k_cache.shape[3],  # max_seq_len
         bmm1_scale,
         bmm2_scale,
         o_sf_scale or -1.0,
@@ -2186,7 +2183,6 @@ def trtllm_batch_decode_with_kv_cache_mla(
     qk_rope_head_dim: int,
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
-    max_seq_len: int,
     out: Optional[torch.Tensor] = None,
     bmm1_scale: Optional[float] = 1.0,
     bmm2_scale: Optional[float] = 1.0,
@@ -2204,7 +2200,6 @@ def trtllm_batch_decode_with_kv_cache_mla(
         qk_rope_head_dim: qk_rope_head_dim, must be 64
         block_tables: page_table of kv cache, [batch_size, num_pages]
         seq_lens: query_len
-        max_seq_len: max sequence length for kv_cache
         out: output tensor, if not provided, will be allocated internally
         bmm1_scale: fused scale for mla bmm1 input.
         bmm2_scale: fused scale for mla bmm2 input.
@@ -2261,6 +2256,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
         query.device,
         "out",
     )
+    max_seq_len = block_tables.shape[1] * block_size

     if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
         # dynamic scale factors
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index ea140fc90..01cfc099a 100644
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -3063,7 +3063,6 @@ def trtllm_batch_context_with_kv_cache(
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
     max_q_len: int,
-    max_kv_len: int,
     bmm1_scale: float,
     bmm2_scale: float,
     batch_size: int,
@@ -3092,8 +3091,6 @@ def trtllm_batch_context_with_kv_cache(
         A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``
     max_q_len : int
         max sequence length for query
-    max_kv_len : int
-        max sequence length for kv_cache
     bmm1_scale : float
         fused scale for bmm1 input.
     bmm2_scale : float
@@ -3126,6 +3123,7 @@ def trtllm_batch_context_with_kv_cache(
     if isinstance(kv_cache, tuple):
         k_cache, v_cache = kv_cache
+        page_size = k_cache.shape[2]
     else:
         if kv_cache.shape[1] == 1:
             k_cache, v_cache = kv_cache, kv_cache
@@ -3136,6 +3134,7 @@ def trtllm_batch_context_with_kv_cache(
             # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...])
             # it doesn't change underlying storage
             k_cache, v_cache = kv_cache.unbind(dim=1)
+            page_size = k_cache.shape[3]

     run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_context
     sm_count = get_device_sm_count(query.device)
@@ -3189,6 +3188,8 @@ def trtllm_batch_context_with_kv_cache(
     else:
         raise ValueError(f"Invalid out_dtype: {out_dtype}")

+    num_pages = block_tables.shape[1]
+    max_kv_len = num_pages * page_size
     run_func(
         out,
         out_scale_factor,
diff --git a/tests/test_trtllm_gen_context.py b/tests/test_trtllm_gen_context.py
index 0040b26a0..1bf132835 100644
--- a/tests/test_trtllm_gen_context.py
+++ b/tests/test_trtllm_gen_context.py
@@ -191,7 +191,6 @@ def test_trtllm_batch_context_wrapper(
         block_tables=block_tables,
         seq_lens=seq_lens,
         max_q_len=qo_len,
-        max_kv_len=kv_len,
         bmm1_scale=q_scale / math.sqrt(head_dim),
         bmm2_scale=1,
         batch_size=batch_size,
@@ -381,7 +380,6 @@ def test_trtllm_batch_prefill(
         block_tables,
         seq_lens_gpu,
         max_q_len,
-        max_seq_len,
         q_scale * k_scale * sm_scale,  # bmm1_scale
         v_scale / o_scale,  # bmm2_scale
         batch_size,
diff --git a/tests/test_trtllm_gen_decode.py b/tests/test_trtllm_gen_decode.py
index 48479b258..d1a8fdf20 100644
--- a/tests/test_trtllm_gen_decode.py
+++ b/tests/test_trtllm_gen_decode.py
@@ -274,7 +274,7 @@ def test_trtllm_batch_decode_fmha(
         workspace_buffer,
         block_tables,
         seq_lens_gpu,
-        max_seq_len,
+        # max_seq_len,
         q_scale * k_scale * sm_scale,  # bmm1_scale
         v_scale / o_scale,  # bmm2_scale
         window_left,  # window_left
@@ -433,7 +433,7 @@ def test_trtllm_batch_decode_mla(
     # Sequence lengths and block tables
     seq_lens = [torch.randint(1, MAX_SEQ_LEN, (1,)).item() for _ in range(batch_size)]
     seq_lens[-1] = MAX_SEQ_LEN
-    max_seq_len = max(seq_lens)
+    # max_seq_len = max(seq_lens)
     seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device)

     blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size
@@ -497,7 +497,6 @@ def test_trtllm_batch_decode_mla(
         qk_rope_head_dim=qk_rope_head_dim,
         block_tables=block_tables,
         seq_lens=seq_lens_tensor,
-        max_seq_len=max_seq_len,
         bmm1_scale=scale / ((128 + 64) ** 0.5),
         bmm2_scale=1.0,
         bmm1_scale_log2_tensor=bmm1_log2_scale_tensor,
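
Every call site in this patch changes the same way: instead of threading a caller-supplied max_seq_len / max_kv_len through the API, the kernel entry points reconstruct an upper bound from the page-table geometry (block_tables.shape[1] pages per sequence, times the page size read from the KV-cache layout). A minimal sketch of that identity follows; the helper name and the toy shapes are illustrative only and are not part of the diff:

import torch

def infer_max_seq_len(block_tables: torch.Tensor, page_size: int) -> int:
    # Each sequence owns at most block_tables.shape[1] pages, so no
    # sequence can exceed num_pages_per_seq * page_size tokens.
    num_pages_per_seq = block_tables.shape[1]
    return num_pages_per_seq * page_size

# Toy check: 4 sequences, each with room for 6 pages of 16 tokens.
block_tables = torch.zeros(4, 6, dtype=torch.int32)
seq_lens = torch.tensor([3, 40, 96, 77], dtype=torch.int32)
bound = infer_max_seq_len(block_tables, page_size=16)
assert bound == 96
assert int(seq_lens.max()) <= bound

Note that the derived value is an upper bound rather than the exact maximum: when the widest page table is only partially filled, it overestimates max(seq_lens). That is always safe for sizing kernel work, though whether the slack costs anything depends on how the kernel uses the value.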