22 changes: 9 additions & 13 deletions flashinfer/decode.py
@@ -1282,7 +1282,7 @@ def run(
self._block_tables,
self._kv_lens_buffer,
page_size,
- self._max_kv_len,
+ # self._max_kv_len,
sinks,
]

@@ -1809,7 +1809,7 @@ def _paged_run(
workspace_buffer: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
- max_seq_len: int,
+ # max_seq_len: int,
bmm1_scale: float, # todo(Yingyi): add dynamic scale tensor later
bmm2_scale: float,
window_left: int = -1,
@@ -1831,7 +1831,7 @@ def _paged_run(
workspace_buffer,
block_tables,
seq_lens,
- max_seq_len,
+ block_tables.shape[1] * k_cache.shape[3], # max_seq_len
bmm1_scale,
bmm2_scale,
-1, # o_sf_scale
@@ -1898,7 +1898,7 @@ def paged_run(
block_tables: Optional[torch.Tensor] = None,
kv_lens_buffer: Optional[torch.Tensor] = None,
page_size: Optional[int] = None,
- max_kv_len: Optional[int] = None,
+ # max_kv_len: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
) -> None:
assert maybe_lse is None
@@ -1908,15 +1908,15 @@ def paged_run(
assert block_tables is not None
assert kv_lens_buffer is not None
assert page_size is not None
- assert max_kv_len is not None
+ # assert max_kv_len is not None
o = module._paged_run(
q.contiguous(), # NOTE(Siyuan): without contiguous, the result is incorrect
paged_k_cache,
paged_v_cache,
int_workspace_buffer,
block_tables,
kv_lens_buffer,
- max_kv_len,
+ # max_kv_len,
sm_scale,
1.0, # NOTE(Siyuan): update this to expose bmm2 scale
window_left,
@@ -1980,7 +1980,7 @@ def trtllm_batch_decode_with_kv_cache(
workspace_buffer: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
- max_seq_len: int,
+ # max_seq_len: int,
bmm1_scale: float,
bmm2_scale: float, # todo(Yingyi): add dynamic scale tensor later
window_left: int = -1,
@@ -2009,9 +2009,6 @@ def trtllm_batch_decode_with_kv_cache(
seq_lens : torch.Tensor
A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``

- max_seq_len : int
- max sequence length for kv_cache

bmm1_scale : float
fused scale for bmm1 input.

@@ -2117,7 +2114,7 @@ def trtllm_batch_decode_with_kv_cache(
workspace_buffer,
block_tables,
seq_lens,
- max_seq_len,
+ block_tables.shape[1] * k_cache.shape[3], # max_seq_len
bmm1_scale,
bmm2_scale,
o_sf_scale or -1.0,
@@ -2186,7 +2183,6 @@ def trtllm_batch_decode_with_kv_cache_mla(
qk_rope_head_dim: int,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
- max_seq_len: int,
out: Optional[torch.Tensor] = None,
bmm1_scale: Optional[float] = 1.0,
bmm2_scale: Optional[float] = 1.0,
@@ -2204,7 +2200,6 @@ def trtllm_batch_decode_with_kv_cache_mla(
qk_rope_head_dim: qk_rope_head_dim, must be 64
block_tables: page_table of kv cache, [batch_size, num_pages]
seq_lens: query_len
- max_seq_len: max sequence length for kv_cache
out: output tensor, if not provided, will be allocated internally
bmm1_scale: fused scale for mla bmm1 input.
bmm2_scale: fused scale for mla bmm2 input.
@@ -2261,6 +2256,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
query.device,
"out",
)
+ max_seq_len = block_tables.shape[1] * block_size

if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
# dynamic scale factors
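Across these decode.py hunks, the kernel-facing `max_seq_len` is now derived from the block-table geometry instead of being passed in by callers. A minimal sketch of the invariant this relies on; the shapes below are illustrative assumptions, not values taken from the diff:

```python
import torch

# Illustrative shapes (assumed): every sequence owns at most `num_pages` pages
# of `page_size` tokens, so num_pages * page_size is a valid upper bound on
# any kv length the kernel can see.
batch_size, num_pages, page_size = 4, 8, 16
block_tables = torch.zeros(batch_size, num_pages, dtype=torch.int32)
seq_lens = torch.randint(1, num_pages * page_size + 1, (batch_size,))

# What the diff now computes in place of the old max_seq_len argument:
derived_max_seq_len = block_tables.shape[1] * page_size
assert int(seq_lens.max()) <= derived_max_seq_len  # holds by construction
```

The derived bound can overshoot the true maximum when block tables are padded, so this trades a potentially looser launch bound for a simpler API.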
7 changes: 4 additions & 3 deletions flashinfer/prefill.py
@@ -3063,7 +3063,6 @@ def trtllm_batch_context_with_kv_cache(
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_q_len: int,
- max_kv_len: int,
bmm1_scale: float,
bmm2_scale: float,
batch_size: int,
@@ -3092,8 +3091,6 @@ def trtllm_batch_context_with_kv_cache(
A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``
max_q_len : int
max sequence length for query
- max_kv_len : int
- max sequence length for kv_cache
bmm1_scale : float
fused scale for bmm1 input.
bmm2_scale : float
@@ -3126,6 +3123,7 @@ def trtllm_batch_context_with_kv_cache(

if isinstance(kv_cache, tuple):
k_cache, v_cache = kv_cache
+ page_size = k_cache.shape[2]
else:
if kv_cache.shape[1] == 1:
k_cache, v_cache = kv_cache, kv_cache
@@ -3136,6 +3134,7 @@ def trtllm_batch_context_with_kv_cache(
# NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...])
# it doesn't change underlying storage
k_cache, v_cache = kv_cache.unbind(dim=1)
+ page_size = k_cache.shape[3]

run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_context
sm_count = get_device_sm_count(query.device)
@@ -3189,6 +3188,8 @@ def trtllm_batch_context_with_kv_cache(
else:
raise ValueError(f"Invalid out_dtype: {out_dtype}")

+ num_pages = block_tables.shape[1]
+ max_kv_len = num_pages * page_size
run_func(
out,
out_scale_factor,
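In the prefill path, `page_size` is first recovered from whichever `kv_cache` form the caller passed, and `max_kv_len` is then computed from it. A sketch of that logic as a standalone helper; `derive_max_kv_len` is a hypothetical name, and the layout comments are assumptions inferred from the indices used in the diff:

```python
from typing import Tuple, Union

import torch

def derive_max_kv_len(
    kv_cache: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor],
    block_tables: torch.Tensor,
) -> int:
    """Upper-bound the kv length from block-table width and page size."""
    if isinstance(kv_cache, tuple):
        # Separate k/v caches: the diff reads page_size at dim 2,
        # consistent with an assumed [num_pages, heads, page_size, head_dim].
        k_cache, _ = kv_cache
        page_size = k_cache.shape[2]
    else:
        # Combined cache: the diff reads the page dimension at index 3.
        page_size = kv_cache.shape[3]
    num_pages = block_tables.shape[1]
    return num_pages * page_size
```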
2 changes: 0 additions & 2 deletions tests/test_trtllm_gen_context.py
@@ -191,7 +191,6 @@ def test_trtllm_batch_context_wrapper(
block_tables=block_tables,
seq_lens=seq_lens,
max_q_len=qo_len,
- max_kv_len=kv_len,
bmm1_scale=q_scale / math.sqrt(head_dim),
bmm2_scale=1,
batch_size=batch_size,
@@ -381,7 +380,6 @@ def test_trtllm_batch_prefill(
block_tables,
seq_lens_gpu,
max_q_len,
- max_seq_len,
q_scale * k_scale * sm_scale, # bmm1_scale
v_scale / o_scale, # bmm2_scale
batch_size,
5 changes: 2 additions & 3 deletions tests/test_trtllm_gen_decode.py
@@ -274,7 +274,7 @@ def test_trtllm_batch_decode_fmha(
workspace_buffer,
block_tables,
seq_lens_gpu,
- max_seq_len,
+ # max_seq_len,
q_scale * k_scale * sm_scale, # bmm1_scale
v_scale / o_scale, # bmm2_scale
window_left, # window_left
@@ -433,7 +433,7 @@ def test_trtllm_batch_decode_mla(
# Sequence lengths and block tables
seq_lens = [torch.randint(1, MAX_SEQ_LEN, (1,)).item() for _ in range(batch_size)]
seq_lens[-1] = MAX_SEQ_LEN
- max_seq_len = max(seq_lens)
+ # max_seq_len = max(seq_lens)
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device)

blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size
@@ -497,7 +497,6 @@ def test_trtllm_batch_decode_mla(
qk_rope_head_dim=qk_rope_head_dim,
block_tables=block_tables,
seq_lens=seq_lens_tensor,
- max_seq_len=max_seq_len,
bmm1_scale=scale / ((128 + 64) ** 0.5),
bmm2_scale=1.0,
bmm1_scale_log2_tensor=bmm1_log2_scale_tensor,
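With the internal derivation in place, the MLA test simply drops `max_seq_len` from its call; `trtllm_batch_decode_with_kv_cache_mla` now computes `block_tables.shape[1] * block_size` itself. A sketch with illustrative numbers (not the test's actual sizes) showing why the derived bound always covers the dropped `max(seq_lens)` value:

```python
import torch

# Mirrors the test's setup: block tables are sized by ceil-dividing each
# sequence length by the page size, so the widest row bounds every sequence.
batch_size, page_size, MAX_SEQ_LEN = 4, 32, 1024
seq_lens = torch.randint(1, MAX_SEQ_LEN + 1, (batch_size,))
seq_lens[-1] = MAX_SEQ_LEN

blocks_per_seq = (seq_lens + page_size - 1) // page_size  # ceil division
num_pages = int(blocks_per_seq.max())  # block-table width, as in the test
derived_max_seq_len = num_pages * page_size

# The old explicit argument max(seq_lens) is therefore safe to drop.
assert derived_max_seq_len >= int(seq_lens.max())
```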