7 changes: 5 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -2723,8 +2723,11 @@ def _waiting_requests(self, context_requests: list[LlmRequest],
         - The number of waiting iterations is smaller than `self.batch_wait_timeout_iters`.
         """

-        num_scheduled_ctx_tokens = sum(
-            len(ctx_req.get_tokens(0)) for ctx_req in context_requests)
+        num_scheduled_ctx_tokens = 0
+        for ctx_req in context_requests:
+            req_tokens = len(ctx_req.get_tokens(0))
+            reusable = ctx_req.estimated_reusable_tokens if ctx_req.is_first_context_chunk else 0
+            num_scheduled_ctx_tokens += max(1, req_tokens - reusable)
         num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
                                        for gen_req in generation_requests)
         num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens
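
The practical effect: a prompt that mostly hits the KV cache no longer inflates the scheduled-token count, so the executor keeps waiting for a fuller batch instead of launching a nearly-free prefill. A back-of-the-envelope sketch with hypothetical values (`max_num_tokens` and `batch_wait_max_tokens_ratio` chosen to match the tests below):

```python
max_num_tokens = 1000
batch_wait_max_tokens_ratio = 0.5
threshold = batch_wait_max_tokens_ratio * max_num_tokens  # 500 tokens

# Before this change: the raw prompt length is compared against the threshold.
old_count = 600                # 600 >= 500 -> stop waiting, schedule now

# After this change: reusable tokens are discounted on the first context chunk,
# and every request still contributes at least one token.
new_count = max(1, 600 - 500)  # 100 < 500  -> keep waiting for more requests
```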
89 changes: 89 additions & 0 deletions tests/unittest/_torch/executor/test_py_executor.py
@@ -178,3 +178,92 @@ def test_getter_methods(mock_executor):
assert mock_executor.get_expected_num_active_requests() == 5
assert mock_executor._get_new_active_requests_queue_latency() == 10.5
assert mock_executor.get_waiting_queue_size() == 1


# ---------------------------------------------------------------------------
# Tests for _waiting_requests with KV cache reuse awareness
# ---------------------------------------------------------------------------


def _make_ctx_request(num_tokens, estimated_reusable_tokens=0, is_first_context_chunk=True):
"""Helper to create a mock context request."""
req = Mock()
req.get_tokens = Mock(return_value=list(range(num_tokens)))
req.estimated_reusable_tokens = estimated_reusable_tokens
req.is_first_context_chunk = is_first_context_chunk
return req


class MockPyExecutorForWaiting:
"""Mock for testing _waiting_requests."""

def __init__(
self, max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5, batch_wait_timeout_iters=3
):
self.max_num_tokens = max_num_tokens
self.batch_wait_max_tokens_ratio = batch_wait_max_tokens_ratio
self.batch_wait_timeout_iters = batch_wait_timeout_iters
self.batch_wait_iters_count = 0

def _waiting_requests(self, context_requests, generation_requests):
"""Mirror of PyExecutor._waiting_requests."""
num_scheduled_ctx_tokens = 0
for ctx_req in context_requests:
req_tokens = len(ctx_req.get_tokens(0))
reusable = ctx_req.estimated_reusable_tokens if ctx_req.is_first_context_chunk else 0
num_scheduled_ctx_tokens += max(1, req_tokens - reusable)
num_scheduled_gen_tokens = sum(
1 + gen_req.num_draft_tokens for gen_req in generation_requests
)
num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens

should_waiting = (
self.batch_wait_iters_count < self.batch_wait_timeout_iters
and num_scheduled_tokens < self.batch_wait_max_tokens_ratio * self.max_num_tokens
)
if should_waiting:
self.batch_wait_iters_count += 1
return []

self.batch_wait_iters_count = 0
return context_requests


class TestWaitingRequests:
def test_no_reuse_counts_all_tokens(self):
"""Without KV cache reuse, all context tokens are counted."""
executor = MockPyExecutorForWaiting(max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5)
# 100 tokens < 500 threshold => should wait
ctx_reqs = [_make_ctx_request(100, estimated_reusable_tokens=0)]
result = executor._waiting_requests(ctx_reqs, [])
assert result == [] # waiting

def test_reuse_reduces_token_count(self):
"""With KV cache reuse, only compute tokens are counted."""
executor = MockPyExecutorForWaiting(max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5)
# 600 total tokens, 500 reusable => 100 compute tokens < 500 threshold
ctx_reqs = [
_make_ctx_request(600, estimated_reusable_tokens=500, is_first_context_chunk=True)
]
result = executor._waiting_requests(ctx_reqs, [])
assert result == [] # waiting because compute tokens = 100

def test_reuse_not_applied_for_non_first_chunk(self):
"""Reusable tokens are ignored for non-first context chunks."""
executor = MockPyExecutorForWaiting(max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5)
# 600 tokens, reusable=500 but is_first_context_chunk=False => counts all 600
ctx_reqs = [
_make_ctx_request(600, estimated_reusable_tokens=500, is_first_context_chunk=False)
]
result = executor._waiting_requests(ctx_reqs, [])
assert result == ctx_reqs # not waiting, 600 >= 500

def test_compute_tokens_at_least_one(self):
"""Each request contributes at least 1 compute token."""
executor = MockPyExecutorForWaiting(max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5)
# 100 tokens, 100 reusable => max(1, 0) = 1 compute token
ctx_reqs = [
_make_ctx_request(100, estimated_reusable_tokens=100, is_first_context_chunk=True)
]
result = executor._waiting_requests(ctx_reqs, [])
assert result == [] # 1 token < 500, should wait
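
One path the tests above leave implicit is the timeout escape hatch from the docstring: once `batch_wait_timeout_iters` waiting iterations have elapsed, even an under-filled batch is released and the counter resets. A quick sketch driving the same mock (hypothetical usage, not part of the PR):

```python
executor = MockPyExecutorForWaiting(
    max_num_tokens=1000, batch_wait_max_tokens_ratio=0.5, batch_wait_timeout_iters=3)
ctx_reqs = [_make_ctx_request(100)]  # 100 compute tokens, well under the 500-token threshold

for _ in range(3):  # iterations 1-3: under the threshold and under the timeout -> wait
    assert executor._waiting_requests(ctx_reqs, []) == []
assert executor._waiting_requests(ctx_reqs, []) == ctx_reqs  # iteration 4: timeout hit, release
assert executor.batch_wait_iters_count == 0  # counter resets once the batch is released
```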