From 1ea7c32322a3c0169a3feacfa74993aeacf2b19e Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Wed, 6 Aug 2025 21:34:18 -0400 Subject: [PATCH 1/8] init --- flashinfer/decode.py | 7 +++---- flashinfer/prefill.py | 6 +++--- tests/test_trtllm_gen_context.py | 2 -- tests/test_trtllm_gen_decode.py | 1 - 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 441fa0d2a..2fedab2c9 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1982,7 +1982,6 @@ def trtllm_batch_decode_with_kv_cache( workspace_buffer: torch.Tensor, block_tables: torch.Tensor, seq_lens: torch.Tensor, - max_seq_len: int, bmm1_scale: float, bmm2_scale: float, # todo(Yingyi): add dynamic scale tensor later window_left: int = -1, @@ -2011,9 +2010,6 @@ def trtllm_batch_decode_with_kv_cache( seq_lens : torch.Tensor A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]`` - max_seq_len : int - max sequence length for kv_cache - bmm1_scale : float fused scale for bmm1 input. @@ -2110,6 +2106,9 @@ def trtllm_batch_decode_with_kv_cache( else: raise ValueError(f"Invalid out_dtype: {out_dtype}") + page_size = k_cache.shape[3] + num_pages = block_tables.shape[1] + max_seq_len = num_pages * page_size run_func( out, out_scale_factor, diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 42db478e6..278bfe7c5 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -3052,7 +3052,6 @@ def trtllm_batch_context_with_kv_cache( block_tables: torch.Tensor, seq_lens: torch.Tensor, max_q_len: int, - max_kv_len: int, bmm1_scale: float, bmm2_scale: float, batch_size: int, @@ -3081,8 +3080,6 @@ def trtllm_batch_context_with_kv_cache( A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]`` max_q_len : int max sequence length for query - max_kv_len : int - max sequence length for kv_cache bmm1_scale : float fused scale for bmm1 input. bmm2_scale : float @@ -3178,6 +3175,9 @@ def trtllm_batch_context_with_kv_cache( else: raise ValueError(f"Invalid out_dtype: {out_dtype}") + page_size = k_cache.shape[3] + num_pages = block_tables.shape[1] + max_kv_len = num_pages * page_size run_func( out, out_scale_factor, diff --git a/tests/test_trtllm_gen_context.py b/tests/test_trtllm_gen_context.py index de908b001..f647f358a 100644 --- a/tests/test_trtllm_gen_context.py +++ b/tests/test_trtllm_gen_context.py @@ -191,7 +191,6 @@ def test_trtllm_batch_context_wrapper( block_tables=block_tables, seq_lens=seq_lens, max_q_len=qo_len, - max_kv_len=kv_len, bmm1_scale=q_scale / math.sqrt(head_dim), bmm2_scale=1, batch_size=batch_size, @@ -383,7 +382,6 @@ def test_trtllm_batch_prefill( block_tables, seq_lens_gpu, max_q_len, - max_seq_len, q_scale * k_scale * sm_scale, # bmm1_scale v_scale / o_scale, # bmm2_scale batch_size, diff --git a/tests/test_trtllm_gen_decode.py b/tests/test_trtllm_gen_decode.py index 46f08c779..05157ee63 100644 --- a/tests/test_trtllm_gen_decode.py +++ b/tests/test_trtllm_gen_decode.py @@ -276,7 +276,6 @@ def test_trtllm_batch_decode_fmha( workspace_buffer, block_tables, seq_lens_gpu, - max_seq_len, q_scale * k_scale * sm_scale, # bmm1_scale v_scale / o_scale, # bmm2_scale window_left, # window_left From 8743781ff843fe5549d6c73c5186e44d393c4ba6 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Sun, 10 Aug 2025 04:09:25 -0400 Subject: [PATCH 2/8] upd mla interface --- flashinfer/decode.py | 3 +-- tests/test_trtllm_gen_decode.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 2fedab2c9..8359e5d23 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -2187,7 +2187,6 @@ def trtllm_batch_decode_with_kv_cache_mla( qk_rope_head_dim: int, block_tables: torch.Tensor, seq_lens: torch.Tensor, - max_seq_len: int, out: Optional[torch.Tensor] = None, bmm1_scale: Optional[float] = 1.0, bmm2_scale: Optional[float] = 1.0, @@ -2205,7 +2204,6 @@ def trtllm_batch_decode_with_kv_cache_mla( qk_rope_head_dim: qk_rope_head_dim, must be 64 block_tables: page_table of kv cache, [batch_size, num_pages] seq_lens: query_len - max_seq_len: max sequence length for kv_cache out: output tensor, if not provided, will be allocated internally bmm1_scale: fused scale for mla bmm1 input. bmm2_scale: fused scale for mla bmm2 input. @@ -2262,6 +2260,7 @@ def trtllm_batch_decode_with_kv_cache_mla( query.device, "out", ) + max_seq_len = block_tables.shape[1] * block_size if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None: # dynamic scale factors diff --git a/tests/test_trtllm_gen_decode.py b/tests/test_trtllm_gen_decode.py index 05157ee63..efa2bf4c7 100644 --- a/tests/test_trtllm_gen_decode.py +++ b/tests/test_trtllm_gen_decode.py @@ -494,7 +494,6 @@ def test_trtllm_batch_decode_mla( qk_rope_head_dim=qk_rope_head_dim, block_tables=block_tables, seq_lens=seq_lens_tensor, - max_seq_len=max_seq_len, bmm1_scale=scale / ((128 + 64) ** 0.5), bmm2_scale=1.0, bmm1_scale_log2_tensor=bmm1_log2_scale_tensor, From bb695872a5cddb2b240152e6d76784b91f5427c6 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Sun, 10 Aug 2025 04:35:51 -0400 Subject: [PATCH 3/8] upd decode --- flashinfer/decode.py | 7 +------ tests/test_trtllm_gen_decode.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 8359e5d23..cc2112e2d 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1284,7 +1284,6 @@ def run( self._block_tables, self._kv_lens_buffer, page_size, - self._max_kv_len, sinks, ] @@ -1799,7 +1798,6 @@ def _paged_run( workspace_buffer: torch.Tensor, block_tables: torch.Tensor, seq_lens: torch.Tensor, - max_seq_len: int, bmm1_scale: float, # todo(Yingyi): add dynamic scale tensor later bmm2_scale: float, window_left: int = -1, @@ -1821,7 +1819,7 @@ def _paged_run( workspace_buffer, block_tables, seq_lens, - max_seq_len, + block_tables.shape[1] * k_cache.shape[3], # max_seq_len bmm1_scale, bmm2_scale, -1, # o_sf_scale @@ -1900,7 +1898,6 @@ def paged_run( block_tables: Optional[torch.Tensor] = None, kv_lens_buffer: Optional[torch.Tensor] = None, page_size: Optional[int] = None, - max_kv_len: Optional[int] = None, sinks: Optional[torch.Tensor] = None, ) -> None: assert maybe_lse is None @@ -1910,7 +1907,6 @@ def paged_run( assert block_tables is not None assert kv_lens_buffer is not None assert page_size is not None - assert max_kv_len is not None o = module._paged_run( q.contiguous(), # NOTE(Siyuan): without contiguous, the result is incorrect paged_k_cache, @@ -1918,7 +1914,6 @@ def paged_run( int_workspace_buffer, block_tables, kv_lens_buffer, - max_kv_len, sm_scale, 1.0, # NOTE(Siyuan): update this to expose bmm2 scale window_left, diff --git a/tests/test_trtllm_gen_decode.py b/tests/test_trtllm_gen_decode.py index efa2bf4c7..127f6fb77 100644 --- a/tests/test_trtllm_gen_decode.py +++ b/tests/test_trtllm_gen_decode.py @@ -580,3 +580,17 @@ def test_trtllm_batch_decode_mla( print("output:", output) print("o_ref:", o_ref) raise e + + +if __name__ == "__main__": + test_trtllm_batch_decode_fmha( + kv_layout="HND", + batch_size=4, + page_size=32, + num_kv_heads=2, + head_grp_size=4, + window_left=-1, + q_dtype="half", + o_dtype="half", + kv_cache_dtype="half", + ) From 53c578351621724998e8678121408b1463af6ce1 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Sun, 10 Aug 2025 04:43:01 -0400 Subject: [PATCH 4/8] fmt --- tests/test_trtllm_gen_decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_trtllm_gen_decode.py b/tests/test_trtllm_gen_decode.py index 8ed45d77e..b450eb178 100644 --- a/tests/test_trtllm_gen_decode.py +++ b/tests/test_trtllm_gen_decode.py @@ -492,7 +492,7 @@ def test_trtllm_batch_decode_mla( qk_rope_head_dim=qk_rope_head_dim, block_tables=block_tables, seq_lens=seq_lens_tensor, - bmm1_scale=scale / ((128 + 64) ** 0.5), + bmm1_scale=scale / ((128 + 64) ** 0.5), bmm2_scale=1.0, bmm1_scale_log2_tensor=bmm1_log2_scale_tensor, bmm2_scale_tensor=bmm2_scale_tensor, From 1df7d3e27af5cab20b3d46f1b699b35252c2a5ba Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Sun, 10 Aug 2025 04:55:44 -0400 Subject: [PATCH 5/8] revert decode fmha change --- flashinfer/decode.py | 13 +++++++++---- tests/test_trtllm_gen_decode.py | 15 +-------------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index b5ceba6e3..4e4935511 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1281,6 +1281,7 @@ def run( self._block_tables, self._kv_lens_buffer, page_size, + self._max_kv_len, sinks, ] @@ -1807,6 +1808,7 @@ def _paged_run( workspace_buffer: torch.Tensor, block_tables: torch.Tensor, seq_lens: torch.Tensor, + max_seq_len: int, bmm1_scale: float, # todo(Yingyi): add dynamic scale tensor later bmm2_scale: float, window_left: int = -1, @@ -1828,7 +1830,7 @@ def _paged_run( workspace_buffer, block_tables, seq_lens, - block_tables.shape[1] * k_cache.shape[3], # max_seq_len + max_seq_len, bmm1_scale, bmm2_scale, -1, # o_sf_scale @@ -1895,6 +1897,7 @@ def paged_run( block_tables: Optional[torch.Tensor] = None, kv_lens_buffer: Optional[torch.Tensor] = None, page_size: Optional[int] = None, + max_seq_len: Optional[int] = None, sinks: Optional[torch.Tensor] = None, ) -> None: assert maybe_lse is None @@ -1904,6 +1907,7 @@ def paged_run( assert block_tables is not None assert kv_lens_buffer is not None assert page_size is not None + assert max_seq_len is not None o = module._paged_run( q.contiguous(), # NOTE(Siyuan): without contiguous, the result is incorrect paged_k_cache, @@ -1974,6 +1978,7 @@ def trtllm_batch_decode_with_kv_cache( workspace_buffer: torch.Tensor, block_tables: torch.Tensor, seq_lens: torch.Tensor, + max_seq_len: int, bmm1_scale: float, bmm2_scale: float, # todo(Yingyi): add dynamic scale tensor later window_left: int = -1, @@ -2002,6 +2007,9 @@ def trtllm_batch_decode_with_kv_cache( seq_lens : torch.Tensor A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]`` + max_seq_len : int + max sequence length for kv_cache + bmm1_scale : float fused scale for bmm1 input. @@ -2098,9 +2106,6 @@ def trtllm_batch_decode_with_kv_cache( else: raise ValueError(f"Invalid out_dtype: {out_dtype}") - page_size = k_cache.shape[3] - num_pages = block_tables.shape[1] - max_seq_len = num_pages * page_size run_func( out, out_scale_factor, diff --git a/tests/test_trtllm_gen_decode.py b/tests/test_trtllm_gen_decode.py index b450eb178..f2b9d52df 100644 --- a/tests/test_trtllm_gen_decode.py +++ b/tests/test_trtllm_gen_decode.py @@ -274,6 +274,7 @@ def test_trtllm_batch_decode_fmha( workspace_buffer, block_tables, seq_lens_gpu, + max_seq_len, q_scale * k_scale * sm_scale, # bmm1_scale v_scale / o_scale, # bmm2_scale window_left, # window_left @@ -578,17 +579,3 @@ def test_trtllm_batch_decode_mla( print("output:", output) print("o_ref:", o_ref) raise e - - -if __name__ == "__main__": - test_trtllm_batch_decode_fmha( - kv_layout="HND", - batch_size=4, - page_size=32, - num_kv_heads=2, - head_grp_size=4, - window_left=-1, - q_dtype="half", - o_dtype="half", - kv_cache_dtype="half", - ) From e7a5ff9272782f851fb75d05e18887519d378218 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Sun, 10 Aug 2025 04:57:08 -0400 Subject: [PATCH 6/8] upd revert --- flashinfer/decode.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 4e4935511..dfe1b4be8 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1897,7 +1897,7 @@ def paged_run( block_tables: Optional[torch.Tensor] = None, kv_lens_buffer: Optional[torch.Tensor] = None, page_size: Optional[int] = None, - max_seq_len: Optional[int] = None, + max_kv_len: Optional[int] = None, sinks: Optional[torch.Tensor] = None, ) -> None: assert maybe_lse is None @@ -1907,7 +1907,7 @@ def paged_run( assert block_tables is not None assert kv_lens_buffer is not None assert page_size is not None - assert max_seq_len is not None + assert max_kv_len is not None o = module._paged_run( q.contiguous(), # NOTE(Siyuan): without contiguous, the result is incorrect paged_k_cache, @@ -1915,6 +1915,7 @@ def paged_run( int_workspace_buffer, block_tables, kv_lens_buffer, + max_kv_len, sm_scale, 1.0, # NOTE(Siyuan): update this to expose bmm2 scale window_left, From b0010b70af5216be3dbb57fa1367a5c4303a3150 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Sun, 10 Aug 2025 05:04:06 -0400 Subject: [PATCH 7/8] upd context --- flashinfer/prefill.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 581517f8d..65b0cc097 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -3118,6 +3118,7 @@ def trtllm_batch_context_with_kv_cache( if isinstance(kv_cache, tuple): k_cache, v_cache = kv_cache + page_size = k_cache.shape[2] else: if kv_cache.shape[1] == 1: k_cache, v_cache = kv_cache, kv_cache @@ -3128,6 +3129,7 @@ def trtllm_batch_context_with_kv_cache( # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...]) # it doesn't change underlying storage k_cache, v_cache = kv_cache.unbind(dim=1) + page_size = k_cache.shape[3] run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_context sm_count = get_device_sm_count(query.device) @@ -3181,7 +3183,6 @@ def trtllm_batch_context_with_kv_cache( else: raise ValueError(f"Invalid out_dtype: {out_dtype}") - page_size = k_cache.shape[3] num_pages = block_tables.shape[1] max_kv_len = num_pages * page_size run_func( From 3c26b7443ca4f489f396346cb764ca4f27d9e673 Mon Sep 17 00:00:00 2001 From: Avery Yingyi Huang Date: Mon, 11 Aug 2025 02:43:04 -0400 Subject: [PATCH 8/8] restore decode --- flashinfer/decode.py | 19 ++++++++----------- tests/test_trtllm_gen_decode.py | 4 ++-- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 726717ddb..e9aa979d8 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1282,7 +1282,7 @@ def run( self._block_tables, self._kv_lens_buffer, page_size, - self._max_kv_len, + # self._max_kv_len, sinks, ] @@ -1809,7 +1809,7 @@ def _paged_run( workspace_buffer: torch.Tensor, block_tables: torch.Tensor, seq_lens: torch.Tensor, - max_seq_len: int, + # max_seq_len: int, bmm1_scale: float, # todo(Yingyi): add dynamic scale tensor later bmm2_scale: float, window_left: int = -1, @@ -1831,7 +1831,7 @@ def _paged_run( workspace_buffer, block_tables, seq_lens, - max_seq_len, + block_tables.shape[1] * k_cache.shape[3], # max_seq_len bmm1_scale, bmm2_scale, -1, # o_sf_scale @@ -1898,7 +1898,7 @@ def paged_run( block_tables: Optional[torch.Tensor] = None, kv_lens_buffer: Optional[torch.Tensor] = None, page_size: Optional[int] = None, - max_kv_len: Optional[int] = None, + # max_kv_len: Optional[int] = None, sinks: Optional[torch.Tensor] = None, ) -> None: assert maybe_lse is None @@ -1908,7 +1908,7 @@ def paged_run( assert block_tables is not None assert kv_lens_buffer is not None assert page_size is not None - assert max_kv_len is not None + # assert max_kv_len is not None o = module._paged_run( q.contiguous(), # NOTE(Siyuan): without contiguous, the result is incorrect paged_k_cache, @@ -1916,7 +1916,7 @@ def paged_run( int_workspace_buffer, block_tables, kv_lens_buffer, - max_kv_len, + # max_kv_len, sm_scale, 1.0, # NOTE(Siyuan): update this to expose bmm2 scale window_left, @@ -1980,7 +1980,7 @@ def trtllm_batch_decode_with_kv_cache( workspace_buffer: torch.Tensor, block_tables: torch.Tensor, seq_lens: torch.Tensor, - max_seq_len: int, + # max_seq_len: int, bmm1_scale: float, bmm2_scale: float, # todo(Yingyi): add dynamic scale tensor later window_left: int = -1, @@ -2009,9 +2009,6 @@ def trtllm_batch_decode_with_kv_cache( seq_lens : torch.Tensor A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]`` - max_seq_len : int - max sequence length for kv_cache - bmm1_scale : float fused scale for bmm1 input. @@ -2117,7 +2114,7 @@ def trtllm_batch_decode_with_kv_cache( workspace_buffer, block_tables, seq_lens, - max_seq_len, + block_tables.shape[1] * k_cache.shape[3], # max_seq_len bmm1_scale, bmm2_scale, o_sf_scale or -1.0, diff --git a/tests/test_trtllm_gen_decode.py b/tests/test_trtllm_gen_decode.py index 4d7cdeec7..d1a8fdf20 100644 --- a/tests/test_trtllm_gen_decode.py +++ b/tests/test_trtllm_gen_decode.py @@ -274,7 +274,7 @@ def test_trtllm_batch_decode_fmha( workspace_buffer, block_tables, seq_lens_gpu, - max_seq_len, + # max_seq_len, q_scale * k_scale * sm_scale, # bmm1_scale v_scale / o_scale, # bmm2_scale window_left, # window_left @@ -433,7 +433,7 @@ def test_trtllm_batch_decode_mla( # Sequence lengths and block tables seq_lens = [torch.randint(1, MAX_SEQ_LEN, (1,)).item() for _ in range(batch_size)] seq_lens[-1] = MAX_SEQ_LEN - max_seq_len = max(seq_lens) + # max_seq_len = max(seq_lens) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size