24 changes: 19 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -2731,6 +2731,23 @@ def _balance_adp_requests(self, context_requests: list[LlmRequest],
        balanced_context_requests = context_requests
        return balanced_context_requests

+    @staticmethod
+    def _compute_scheduled_tokens(context_requests, generation_requests):
+        """Compute the total number of scheduled tokens for batch waiting decisions.
+
+        For context requests, KV cache reusable tokens are subtracted (only for
+        the first context chunk), with a minimum of 1 compute token per request.
+        For generation requests, each contributes 1 + num_draft_tokens.
+        """
+        num_scheduled_ctx_tokens = 0
+        for ctx_req in context_requests:
+            req_tokens = len(ctx_req.get_tokens(0))
+            reusable = ctx_req.estimated_reusable_tokens if ctx_req.is_first_context_chunk else 0
+            num_scheduled_ctx_tokens += max(1, req_tokens - reusable)
+        num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
+                                       for gen_req in generation_requests)
+        return num_scheduled_ctx_tokens + num_scheduled_gen_tokens
+
    def _waiting_requests(self, context_requests: list[LlmRequest],
                          generation_requests: list[LlmRequest]):
        """
@@ -2740,11 +2757,8 @@ def _waiting_requests(self, context_requests: list[LlmRequest],
        - The number of waiting iterations is smaller than `self.batch_wait_timeout_iters`.
        """

-        num_scheduled_ctx_tokens = sum(
-            len(ctx_req.get_tokens(0)) for ctx_req in context_requests)
-        num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
-                                       for gen_req in generation_requests)
-        num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens
+        num_scheduled_tokens = self._compute_scheduled_tokens(
+            context_requests, generation_requests)

        should_waiting = self.batch_wait_iters_count < self.batch_wait_timeout_iters and num_scheduled_tokens < self.batch_wait_max_tokens_ratio * self.max_num_tokens
        if should_waiting:
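For intuition, here is a small self-contained sketch of the accounting performed by the _compute_scheduled_tokens helper added above (the request shapes and executor settings below are hypothetical, chosen to mirror the unit tests that follow):

# Sketch only: each context request is modeled as a tuple of
# (prompt_tokens, reusable_tokens, is_first_chunk); each generation
# request is modeled by its num_draft_tokens.
def sketch_scheduled_tokens(ctx_reqs, gen_draft_tokens):
    # Reusable KV cache tokens count only for first context chunks,
    # and every context request contributes at least 1 compute token.
    ctx = sum(max(1, tokens - (reusable if first_chunk else 0))
              for tokens, reusable, first_chunk in ctx_reqs)
    # Each generation request contributes 1 + num_draft_tokens.
    gen = sum(1 + draft for draft in gen_draft_tokens)
    return ctx + gen

# (600 tokens, 500 reusable, first chunk) -> 100 compute tokens;
# (100, 100, first chunk) -> floored to 1; one generation request
# with 3 draft tokens -> 4; total 105.
assert sketch_scheduled_tokens([(600, 500, True), (100, 100, True)], [3]) == 105
# With max_num_tokens=1000 and batch_wait_max_tokens_ratio=0.5 the waiting
# threshold is 500 tokens, so this batch (105 < 500) would keep waiting
# until batch_wait_timeout_iters is exhausted.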
77 changes: 77 additions & 0 deletions tests/unittest/_torch/executor/test_py_executor.py
@@ -17,6 +17,7 @@
SHUTDOWN_REQUEST_ID,
RequestQueueItem,
)
from tensorrt_llm._torch.pyexecutor.py_executor import PyExecutor
from tensorrt_llm._torch.pyexecutor.scheduler import FCFSWaitingQueue


@@ -178,3 +179,79 @@ def test_getter_methods(mock_executor):
assert mock_executor.get_expected_num_active_requests() == 5
assert mock_executor._get_new_active_requests_queue_latency() == 10.5
assert mock_executor.get_waiting_queue_size() == 1


# ---------------------------------------------------------------------------
# Tests for _waiting_requests with KV cache reuse awareness
# ---------------------------------------------------------------------------


def _make_ctx_request(num_tokens, estimated_reusable_tokens=0, is_first_context_chunk=True):
"""Helper to create a mock context request."""
req = Mock()
req.get_tokens = Mock(return_value=list(range(num_tokens)))
req.estimated_reusable_tokens = estimated_reusable_tokens
req.is_first_context_chunk = is_first_context_chunk
return req


class MockPyExecutorForWaiting:
"""Mock for testing _waiting_requests.

Calls PyExecutor._waiting_requests directly to avoid mirroring logic.
"""

# Expose the static method so self._compute_scheduled_tokens works
# when PyExecutor._waiting_requests is called with this mock as self.
_compute_scheduled_tokens = staticmethod(PyExecutor._compute_scheduled_tokens)

def __init__(
self, max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5, batch_wait_timeout_iters=3
):
self.max_num_tokens = max_num_tokens
self.batch_wait_max_tokens_ratio = batch_wait_max_tokens_ratio
self.batch_wait_timeout_iters = batch_wait_timeout_iters
self.batch_wait_iters_count = 0

def _waiting_requests(self, context_requests, generation_requests):
return PyExecutor._waiting_requests(self, context_requests, generation_requests)


class TestWaitingRequests:
def test_no_reuse_counts_all_tokens(self):
"""Without KV cache reuse, all context tokens are counted."""
executor = MockPyExecutorForWaiting(max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5)
# 100 tokens < 500 threshold => should wait
ctx_reqs = [_make_ctx_request(100, estimated_reusable_tokens=0)]
result = executor._waiting_requests(ctx_reqs, [])
assert result == [] # waiting

def test_reuse_reduces_token_count(self):
"""With KV cache reuse, only compute tokens are counted."""
executor = MockPyExecutorForWaiting(max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5)
# 600 total tokens, 500 reusable => 100 compute tokens < 500 threshold
ctx_reqs = [
_make_ctx_request(600, estimated_reusable_tokens=500, is_first_context_chunk=True)
]
result = executor._waiting_requests(ctx_reqs, [])
assert result == [] # waiting because compute tokens = 100

def test_reuse_not_applied_for_non_first_chunk(self):
"""Reusable tokens are ignored for non-first context chunks."""
executor = MockPyExecutorForWaiting(max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5)
# 600 tokens, reusable=500 but is_first_context_chunk=False => counts all 600
ctx_reqs = [
_make_ctx_request(600, estimated_reusable_tokens=500, is_first_context_chunk=False)
]
result = executor._waiting_requests(ctx_reqs, [])
assert result == ctx_reqs # not waiting, 600 >= 500

def test_compute_tokens_at_least_one(self):
"""Each request contributes at least 1 compute token."""
executor = MockPyExecutorForWaiting(max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5)
# 100 tokens, 100 reusable => max(1, 0) = 1 compute token
ctx_reqs = [
_make_ctx_request(100, estimated_reusable_tokens=100, is_first_context_chunk=True)
]
result = executor._waiting_requests(ctx_reqs, [])
assert result == [] # 1 token < 500, should wait
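As an aside on the test design: PyExecutor._waiting_requests is invoked with the mock as self, so every attribute the real method reads, including the _compute_scheduled_tokens static method, must exist on the stand-in. A minimal illustration of this duck-typed-self pattern (the class names here are invented for the example):

class Real:
    @staticmethod
    def _helper(x):
        return x + 1

    def run(self, x):
        # Attribute lookup goes through self, so any object passed as
        # self must also expose _helper.
        return self._helper(x)

class Stub:
    # Re-expose the static method so Real.run accepts a Stub as self.
    _helper = staticmethod(Real._helper)

assert Real.run(Stub(), 41) == 42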