
Commit f9adac3

[feat] Enable chunked context for flashinfer (#4132)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
1 parent 0fd59d6 commit f9adac3

File tree: 4 files changed, +33 -11 lines changed

tensorrt_llm/_torch/attention_backend/flashinfer.py

16 additions & 8 deletions

@@ -186,11 +186,6 @@ def prepare(self) -> None:
         assert self.request_ids is not None
         block_ids_per_seq = self.kv_cache_manager.get_batch_cache_indices(
             self.request_ids)
-        paged_kv_indices = torch.tensor(
-            [x for block_ids in block_ids_per_seq for x in block_ids],
-            dtype=torch.int32)
-        self._paged_kv_indices[:paged_kv_indices.size(0)].copy_(
-            paged_kv_indices, non_blocking=True)

         # number of tokens in the kv cache for each sequence in the batch
         cached_token_lens = torch.tensor(
@@ -212,13 +207,26 @@ def prepare(self) -> None:
             1])

         # number of cache blocks used by each sequence in the cache
-        self.num_blocks = [len(block_ids) for block_ids in block_ids_per_seq]
+        # NOTE: do not use len(block_ids) - that will give you a number
+        # that can be too big if using chunked prefill/kv cache reuse
+        # since we allocate all blocks ahead of time.
+        num_blocks = ((kv_lens + self.page_size - 1) // self.page_size)
+        self.num_blocks = num_blocks.tolist()
         self.num_context_blocks = sum(self.num_blocks[:self.num_contexts])
         self.num_generation_blocks = sum(self.num_blocks[self.num_contexts:])

+        paged_kv_indices_list = []
+        for i, block_ids in enumerate(block_ids_per_seq):
+            paged_kv_indices_list.extend(block_ids[:self.num_blocks[i]])
+
+        paged_kv_indices = torch.tensor(paged_kv_indices_list,
+                                        dtype=torch.int32)
+
+        self._paged_kv_indices[:paged_kv_indices.size(0)].copy_(
+            paged_kv_indices, non_blocking=True)
+
         # number of tokens in the last cache block used by each sequence
-        paged_kv_last_page_len = kv_lens - (torch.Tensor(
-            self.num_blocks).int().cuda(non_blocking=True) - 1) * self.page_size
+        paged_kv_last_page_len = kv_lens - (num_blocks - 1) * self.page_size
         self._paged_kv_last_page_len[:paged_kv_last_page_len.size(0)].copy_(
             paged_kv_last_page_len, non_blocking=True)
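The key change above replaces len(block_ids) with a ceil division over the number of cached tokens, so that only blocks actually holding tokens are counted when blocks have been preallocated for chunked prefill or KV-cache reuse. A minimal standalone sketch of that arithmetic (illustrative page size, sequence lengths, and block IDs, not real KV-cache-manager state):

import torch

# Illustrative values only (hypothetical page size, sequence lengths, block IDs).
page_size = 4
kv_lens = torch.tensor([6, 9], dtype=torch.int32)   # cached tokens per sequence
block_ids_per_seq = [[0, 1, 2, 3], [4, 5, 6, 7]]     # 4 blocks preallocated per sequence

# Ceil division: count only the blocks that actually hold cached tokens.
num_blocks = (kv_lens + page_size - 1) // page_size  # tensor([2, 3])

# len(block_ids) would report 4 blocks per sequence and pull in unused pages;
# slicing to num_blocks keeps only the occupied ones.
paged_kv_indices = [
    bid for ids, n in zip(block_ids_per_seq, num_blocks.tolist()) for bid in ids[:n]
]

# Tokens sitting in the last, possibly partial, page of each sequence.
paged_kv_last_page_len = kv_lens - (num_blocks - 1) * page_size

print(num_blocks.tolist())              # [2, 3]
print(paged_kv_indices)                 # [0, 1, 4, 5, 6]
print(paged_kv_last_page_len.tolist())  # [2, 1]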

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

1 addition & 3 deletions

@@ -46,9 +46,7 @@ def create_py_executor(executor_config: ExecutorConfig,
         )
         executor_config.kv_cache_config.enable_block_reuse = False

-    if pytorch_backend_config.attn_backend in [
-            "FLASHINFER", "FLASHINFER_STAR_ATTENTION"
-    ] and executor_config.enable_chunked_context:
+    if pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION" and executor_config.enable_chunked_context:
         logger.warning(
             f"Disabling chunked context for {pytorch_backend_config.attn_backend} backend"
         )

tests/integration/defs/accuracy/test_llm_api_pytorch.py

14 additions & 0 deletions

@@ -54,6 +54,20 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"

+    @pytest.mark.skip_less_device_memory(32000)
+    @parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
+    def test_chunked_prefill(self, attn_backend):
+        pytorch_config = PyTorchConfig(attn_backend=attn_backend, )
+        llm = LLM(self.MODEL_PATH,
+                  enable_chunked_prefill=True,
+                  max_num_tokens=64,
+                  pytorch_backend_config=pytorch_config)
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device_memory(32000)
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])

tests/integration/test_lists/test-db/l0_dgx_h100.yml

2 additions & 0 deletions

@@ -31,6 +31,8 @@ l0_dgx_h100:
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=True-attn_backend=TRTLLM-torch_compile=True]
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM]
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=FLASHINFER]
   - disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_mixed[TinyLlama-1.1B-Chat-v1.0]
