diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 13076973068..664f4c35811 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -2731,6 +2731,23 @@ def _balance_adp_requests(self, context_requests: list[LlmRequest],
             balanced_context_requests = context_requests
         return balanced_context_requests
 
+    @staticmethod
+    def _compute_scheduled_tokens(context_requests, generation_requests):
+        """Compute the total number of scheduled tokens for batch waiting decisions.
+
+        For context requests, KV cache reusable tokens are subtracted (only for
+        the first context chunk), with a minimum of 1 compute token per request.
+        For generation requests, each contributes 1 + num_draft_tokens.
+        """
+        num_scheduled_ctx_tokens = 0
+        for ctx_req in context_requests:
+            req_tokens = len(ctx_req.get_tokens(0))
+            reusable = ctx_req.estimated_reusable_tokens if ctx_req.is_first_context_chunk else 0
+            num_scheduled_ctx_tokens += max(1, req_tokens - reusable)
+        num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
+                                       for gen_req in generation_requests)
+        return num_scheduled_ctx_tokens + num_scheduled_gen_tokens
+
     def _waiting_requests(self, context_requests: list[LlmRequest],
                           generation_requests: list[LlmRequest]):
         """
@@ -2740,11 +2757,8 @@ def _waiting_requests(self, context_requests: list[LlmRequest],
         - The number of waiting iterations is smaller than
           `self.batch_wait_timeout_iters`.
         """
-        num_scheduled_ctx_tokens = sum(
-            len(ctx_req.get_tokens(0)) for ctx_req in context_requests)
-        num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
-                                       for gen_req in generation_requests)
-        num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens
+        num_scheduled_tokens = self._compute_scheduled_tokens(
+            context_requests, generation_requests)
 
         should_waiting = self.batch_wait_iters_count < self.batch_wait_timeout_iters and num_scheduled_tokens < self.batch_wait_max_tokens_ratio * self.max_num_tokens
         if should_waiting:
diff --git a/tests/unittest/_torch/executor/test_py_executor.py b/tests/unittest/_torch/executor/test_py_executor.py
index 9fb59305980..518a41980ef 100644
--- a/tests/unittest/_torch/executor/test_py_executor.py
+++ b/tests/unittest/_torch/executor/test_py_executor.py
@@ -17,6 +17,7 @@
     SHUTDOWN_REQUEST_ID,
     RequestQueueItem,
 )
+from tensorrt_llm._torch.pyexecutor.py_executor import PyExecutor
 from tensorrt_llm._torch.pyexecutor.scheduler import FCFSWaitingQueue
 
 
@@ -178,3 +179,79 @@ def test_getter_methods(mock_executor):
     assert mock_executor.get_expected_num_active_requests() == 5
     assert mock_executor._get_new_active_requests_queue_latency() == 10.5
     assert mock_executor.get_waiting_queue_size() == 1
+
+
+# ---------------------------------------------------------------------------
+# Tests for _waiting_requests with KV cache reuse awareness
+# ---------------------------------------------------------------------------
+
+
+def _make_ctx_request(num_tokens, estimated_reusable_tokens=0, is_first_context_chunk=True):
+    """Helper to create a mock context request."""
+    req = Mock()
+    req.get_tokens = Mock(return_value=list(range(num_tokens)))
+    req.estimated_reusable_tokens = estimated_reusable_tokens
+    req.is_first_context_chunk = is_first_context_chunk
+    return req
+
+
+class MockPyExecutorForWaiting:
+    """Mock for testing _waiting_requests.
+
+    Calls PyExecutor._waiting_requests directly to avoid mirroring logic.
+    """
+
+    # Expose the static method so self._compute_scheduled_tokens works
+    # when PyExecutor._waiting_requests is called with this mock as self.
+    _compute_scheduled_tokens = staticmethod(PyExecutor._compute_scheduled_tokens)
+
+    def __init__(
+        self, max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5, batch_wait_timeout_iters=3
+    ):
+        self.max_num_tokens = max_num_tokens
+        self.batch_wait_max_tokens_ratio = batch_wait_max_tokens_ratio
+        self.batch_wait_timeout_iters = batch_wait_timeout_iters
+        self.batch_wait_iters_count = 0
+
+    def _waiting_requests(self, context_requests, generation_requests):
+        return PyExecutor._waiting_requests(self, context_requests, generation_requests)
+
+
+class TestWaitingRequests:
+    def test_no_reuse_counts_all_tokens(self):
+        """Without KV cache reuse, all context tokens are counted."""
+        executor = MockPyExecutorForWaiting(max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5)
+        # 100 tokens < 500 threshold => should wait
+        ctx_reqs = [_make_ctx_request(100, estimated_reusable_tokens=0)]
+        result = executor._waiting_requests(ctx_reqs, [])
+        assert result == []  # waiting
+
+    def test_reuse_reduces_token_count(self):
+        """With KV cache reuse, only compute tokens are counted."""
+        executor = MockPyExecutorForWaiting(max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5)
+        # 600 total tokens, 500 reusable => 100 compute tokens < 500 threshold
+        ctx_reqs = [
+            _make_ctx_request(600, estimated_reusable_tokens=500, is_first_context_chunk=True)
+        ]
+        result = executor._waiting_requests(ctx_reqs, [])
+        assert result == []  # waiting because compute tokens = 100
+
+    def test_reuse_not_applied_for_non_first_chunk(self):
+        """Reusable tokens are ignored for non-first context chunks."""
+        executor = MockPyExecutorForWaiting(max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5)
+        # 600 tokens, reusable=500 but is_first_context_chunk=False => counts all 600
+        ctx_reqs = [
+            _make_ctx_request(600, estimated_reusable_tokens=500, is_first_context_chunk=False)
+        ]
+        result = executor._waiting_requests(ctx_reqs, [])
+        assert result == ctx_reqs  # not waiting, 600 >= 500
+
+    def test_compute_tokens_at_least_one(self):
+        """Each request contributes at least 1 compute token."""
+        executor = MockPyExecutorForWaiting(max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5)
+        # 100 tokens, 100 reusable => max(1, 0) = 1 compute token
+        ctx_reqs = [
+            _make_ctx_request(100, estimated_reusable_tokens=100, is_first_context_chunk=True)
+        ]
+        result = executor._waiting_requests(ctx_reqs, [])
+        assert result == []  # 1 token < 500, should wait