
Commit fcae852 (parent: 1a78e7a)

[None][fix] Fix KV cache clearing with KV Connector API (#8750)

Signed-off-by: jthomson04 <jwillthomson19@gmail.com>

2 files changed: 34 additions & 10 deletions

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 0 additions & 1 deletion

@@ -1113,7 +1113,6 @@ def _kv_connector_terminate_requests(self):
             else:
                 self.ctx_in_transmission_requests[req.py_request_id] = (
                     request, block_id, counter - 1)
-                break
 
     def _kv_connector_wait_for_save(self):
         if self.kv_connector_manager is not None:
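
The fix itself is the removed `break`. Judging from the surrounding diff, `_kv_connector_terminate_requests` walks the requests whose KV blocks are still being transmitted, counting each one down and freeing its blocks once the countdown expires; the `break` stopped the walk at the first request that was still pending, so every request behind it was never counted down and its KV cache was never cleared. A minimal sketch of that pattern (plain Python with illustrative names, not TensorRT-LLM internals):

# Sketch of the bug, assuming each in-transmission request counts down once per
# executor iteration and its KV blocks are freed when the countdown reaches zero.
in_transmission = {"req-0": 2, "req-1": 1, "req-2": 1}  # request id -> countdown

def tick(early_break: bool) -> None:
    for req_id, counter in list(in_transmission.items()):
        if counter - 1 == 0:
            del in_transmission[req_id]        # transmission done: free KV blocks
        else:
            in_transmission[req_id] = counter - 1
            if early_break:
                break                          # old behavior: skip every later request

tick(early_break=True)
assert list(in_transmission) == ["req-0", "req-1", "req-2"]  # nothing was freed

in_transmission = {"req-0": 2, "req-1": 1, "req-2": 1}
tick(early_break=False)
assert list(in_transmission) == ["req-0"]      # fixed: req-1 and req-2 freed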

tests/integration/defs/llmapi/test_llm_api_connector.py

Lines changed: 34 additions & 9 deletions

@@ -42,15 +42,16 @@ def model_with_connector():
     )
 
     def model_fn(*args, **kwargs):
-        return LLM(
-            *args,
-            **kwargs,
-            model=f"{llm_models_root()}/Qwen2-0.5B",
-            backend="pytorch",
-            kv_connector_config=kv_connector_config,
-            cuda_graph_config=None,
-            kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1),
-        )
+
+        default_kwargs = {
+            "model": f"{llm_models_root()}/Qwen2-0.5B",
+            "backend": "pytorch",
+            "kv_connector_config": kv_connector_config,
+            "cuda_graph_config": None,
+            "kv_cache_config": KvCacheConfig(free_gpu_memory_fraction=0.1)
+        }
+
+        return LLM(*args, **{**default_kwargs, **kwargs})
 
     yield model_fn, mock_scheduler, mock_worker
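
This fixture refactor is what makes the new test below possible: the old `model_fn` forwarded `**kwargs` alongside explicit keyword arguments, so calling it with, say, `kv_cache_config=...` would have raised `TypeError: got multiple values for keyword argument 'kv_cache_config'`. With the dict merge, the fixture values become overridable defaults, because later entries win when dicts are unpacked into a literal. A tiny illustration (values are made up):

# Later unpackings win, so caller kwargs override the fixture defaults.
default_kwargs = {"backend": "pytorch", "cuda_graph_config": None}
kwargs = {"cuda_graph_config": {"batch_sizes": [1]}}  # hypothetical override
merged = {**default_kwargs, **kwargs}
assert merged == {"backend": "pytorch", "cuda_graph_config": {"batch_sizes": [1]}}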

@@ -399,3 +400,27 @@ def test_connector_disagg_prefill(enforce_single_worker, model_with_connector,
     assert len(req.new_tokens) == 48
 
     assert scheduler.request_finished.call_count == 1
+
+
+@pytest.mark.threadleak(enabled=False)
+def test_connector_multi_request(enforce_single_worker, model_with_connector):
+    model_fn, scheduler, worker = model_with_connector
+
+    model = model_fn(disable_overlap_scheduler=True,
+                     kv_cache_config=KvCacheConfig(max_tokens=120))
+
+    sampling_params = SamplingParams(ignore_eos=True, max_tokens=4)
+
+    scheduler.get_num_new_matched_tokens.return_value = 0, False
+    scheduler.request_finished.return_value = True
+    worker.get_finished.side_effect = lambda finished_gen, load_async: (
+        finished_gen, load_async)
+
+    model.generate([[0] * 48, [1] * 48],
+                   sampling_params=[
+                       SamplingParams(ignore_eos=True, max_tokens=4),
+                       SamplingParams(ignore_eos=True, max_tokens=3)
+                   ])
+
+    # The KV cache of both prior requests should be freed, allowing the third request to run.
+    model.generate([2] * 110, sampling_params=sampling_params)
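
The test is sized so that it only passes if terminated requests really release their KV cache. Assuming one KV slot per prompt or generated token and ignoring block-size rounding (both assumptions here, since the real accounting is block-granular), the budget works out as:

kv_budget = 120                      # KvCacheConfig(max_tokens=120)
first_batch = (48 + 4) + (48 + 3)    # two prompts plus generated tokens = 103
third_request = 110 + 4              # 110-token prompt plus max_tokens=4 = 114
assert first_batch <= kv_budget                  # the first two requests fit together
assert third_request <= kv_budget                # the third fits an empty cache
assert first_batch + third_request > kv_budget   # but not if the first batch leaked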
