
Commit 32f5391

[TRTLLM-909][feat] Overlap context chunks in pipeline parallel mode (#9308)
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
1 parent afc52d7 commit 32f5391

7 files changed with 290 additions and 104 deletions

cpp/tensorrt_llm/batch_manager/pauseRequests.cpp

Lines changed: 2 additions & 2 deletions
@@ -50,8 +50,8 @@ void tensorrt_llm::batch_manager::PauseRequests::operator()(RequestVector& reque
     for (auto& llmReq : requestsToPause)
     {
         auto const reqId = llmReq->mRequestId;
-        inflightReqIds.erase(reqId);
-        TLLM_LOG_DEBUG("request with ID %lu removed from DECODER model inflight set", reqId);
+        auto const removed = inflightReqIds.erase(reqId);
+        TLLM_LOG_DEBUG("request with ID %lu removed from DECODER model inflight set: %d", reqId, removed);

         // If a request in this context had been flagged to be paused, pause it right away
         if (reqIdsToPause.find(reqId) != reqIdsToPause.end())

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

Lines changed: 12 additions & 5 deletions
@@ -881,8 +881,6 @@ void TrtGptModelInflightBatching::forwardSync()
         }
     }

-    (*mPauseRequests)(currRequests.contextRequests, mInflightReqIds, mReqIdsToPause, true, *mSeqSlotManager,
-        mKvCacheManager, mCrossKvCacheManager, mPeftCacheManager);
     (*mPauseRequests)(currRequests.generationRequests, mInflightReqIds, mReqIdsToPause, true, *mSeqSlotManager,
         mKvCacheManager, mCrossKvCacheManager, mPeftCacheManager);

@@ -1051,14 +1049,23 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
     {
         NVTX3_SCOPED_RANGE(updateInflightReqIds);
         // Add requests to in-flight set, so they can be skipped in other micro batches
-        for (auto const& requests : {currRequests.contextRequests, currRequests.generationRequests})
+        for (auto const& llmReq : currRequests.contextRequests)
         {
-            for (auto const& llmReq : requests)
+            // Context requests that are chunking are not added to inflight set, so they are scheduled in the
+            // next micro batch.
+            if (llmReq->isLastContextChunk())
             {
-                TLLM_LOG_DEBUG("request with ID %lu added to DECODER model inflight set", llmReq->mRequestId);
+                TLLM_LOG_DEBUG(
+                    "Context request with ID %lu added to DECODER model inflight set", llmReq->mRequestId);
                 mInflightReqIds.insert(llmReq->mRequestId);
             }
         }
+        for (auto const& llmReq : currRequests.generationRequests)
+        {
+            TLLM_LOG_DEBUG(
+                "Generation request with ID %lu added to DECODER model inflight set", llmReq->mRequestId);
+            mInflightReqIds.insert(llmReq->mRequestId);
+        }
     }

     (*mAssignReqSeqSlots)(*mSeqSlotManager, currRequests.contextRequests, currRequests.generationRequests);
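
This C++ change is the core of the overlap: a chunked context request only enters the DECODER model's inflight set once its final chunk is scheduled, so the scheduler can hand the next chunk to the following micro batch while earlier chunks are still moving through the pipeline. Below is a minimal, self-contained Python sketch of that policy; the Req class and schedule_micro_batch helper are illustrative stand-ins, not TensorRT-LLM APIs.

```python
class Req:
    def __init__(self, request_id, total_chunks):
        self.request_id = request_id
        self.total_chunks = total_chunks
        self.scheduled_chunks = 0

    @property
    def is_last_context_chunk(self):
        # True when the chunk being scheduled now is the final one.
        return self.scheduled_chunks == self.total_chunks - 1


def schedule_micro_batch(active_reqs, inflight_ids):
    # The scheduler skips requests that are already in flight in another micro batch.
    batch = [r for r in active_reqs if r.request_id not in inflight_ids]
    for r in batch:
        if r.is_last_context_chunk:
            inflight_ids.add(r.request_id)  # mirrors the loop in forwardAsync above
        r.scheduled_chunks += 1
    return batch


active = [Req(request_id=1, total_chunks=2)]
inflight = set()
first = schedule_micro_batch(active, inflight)   # chunk 1 goes into micro batch 0
second = schedule_micro_batch(active, inflight)  # chunk 2 overlaps in micro batch 1
third = schedule_micro_batch(active, inflight)   # request now in flight, nothing scheduled
assert [len(first), len(second), len(third)] == [1, 1, 0]
```
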

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 3 additions & 3 deletions
@@ -678,8 +678,10 @@ def create_py_executor_instance(

     spec_config = model_engine.spec_config

+    max_num_sequences = max_batch_size * mapping.pp_size
+
     logger.info(
-        f"max_seq_len={max_seq_len}, max_num_requests={max_batch_size}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}"
+        f"max_seq_len={max_seq_len}, max_num_requests={max_num_sequences}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}"
     )

     for key, value in llm_args.extra_resource_managers.items():

@@ -764,8 +766,6 @@ def create_py_executor_instance(
             lora_config.trtllm_modules_to_hf_modules,
             lora_config.swap_gate_up_proj_lora_b_weight)

-    max_num_sequences = max_batch_size * mapping.pp_size
-
     resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager(
         max_num_sequences)

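Moving max_num_sequences up makes the sequence-slot sizing visible in the startup log: with pipeline parallelism, up to pp_size micro batches can be in flight at once, each holding up to max_batch_size sequences. A rough sketch of the arithmetic, with illustrative numbers that are not taken from the commit:

```python
max_batch_size = 8   # requests per micro batch (illustrative)
pp_size = 4          # pipeline-parallel stages (illustrative)

# One micro batch can be resident per pipeline stage, so the SeqSlotManager
# needs max_batch_size slots for each of the pp_size micro batches in flight.
max_num_sequences = max_batch_size * pp_size
assert max_num_sequences == 32
```
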
tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 65 additions & 16 deletions
@@ -106,6 +106,7 @@ class BatchState:
 class BatchStatePP(BatchState):
     microbatch_id: int = -1
     scheduled_ctx_reqs: list[LlmRequest] = None
+    finished_ctx_reqs: list[LlmRequest] = None


 class PyExecutor:

@@ -232,6 +233,8 @@ def __init__(self,
                      | None] = [None] * self.num_micro_batches
         self.send_handles = [None] * self.num_micro_batches

+        # Set of request IDs that are currently in flight across all micro batches.
+        # The scheduler will avoid scheduling requests that are already in flight.
         self.inflight_req_ids = ReqIdsSet()

         # During warmup, we don't enable the profiler

@@ -694,7 +697,7 @@ def get_queued_req_stats(request_id: int) -> RequestStats:
         return req_stats

     def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
-                           scheduled_batch) -> IterationStats:
+                           scheduled_batch, micro_batch_id) -> IterationStats:
         stats.iter_latency_ms = iter_latency_ms

         stats.num_queued_requests = self.executor_request_queue.get_request_queue_size(

@@ -735,7 +738,7 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
         stats.inflight_batching_stats.num_paused_requests = len(
             scheduled_batch.paused_requests)
         stats.inflight_batching_stats.avg_num_decoded_tokens_per_iter = 0
-        stats.inflight_batching_stats.micro_batch_id = 0
+        stats.inflight_batching_stats.micro_batch_id = micro_batch_id
         if stats.specdec_stats is not None:
             stats.specdec_stats.draft_overhead = 0.0 if iter_latency_ms <= 0.0 else float(
                 stats.specdec_stats.iter_latency_ms) / float(iter_latency_ms)

@@ -748,9 +751,13 @@ def _append_iter_stats(self,
         with self.stats_lock:
             self.stats.append((stats, req_stats))

-    def _process_iter_stats(self, finished_requests: list[LlmRequest],
-                            active_requests: List[LlmRequest],
-                            batch_state: BatchState):
+    def _process_iter_stats(
+        self,
+        finished_requests: list[LlmRequest],
+        active_requests: List[LlmRequest],
+        batch_state: BatchState,
+        micro_batch_id: int = 0,
+    ):
         iter_end_time = time.time()
         iter_latency_ms = (iter_end_time - batch_state.iter_start_time) * 1e3
         if batch_state.iter_stats is None:

@@ -763,9 +770,10 @@ def _process_iter_stats(self, finished_requests: list[LlmRequest],
                      and self.enable_iter_perf_stats) else None

         self._append_iter_stats(
-            self._update_iter_stats(
-                batch_state.iter_stats, iter_latency_ms, len(finished_requests),
-                batch_state.sample_state.scheduled_requests), req_stats)
+            self._update_iter_stats(batch_state.iter_stats, iter_latency_ms,
+                                    len(finished_requests),
+                                    batch_state.sample_state.scheduled_requests,
+                                    micro_batch_id), req_stats)

     def _executor_loop_cleanup(self):

@@ -825,6 +833,7 @@ def _executor_loop_pp(self):
                 self.num_scheduled_requests = scheduled_batch.batch_size

                 logger.debug(
+                    f'iteration {self.iter_counter}, microbatch {microbatch_id}, '
                     f'has {len(self.active_requests)} active_requests, '
                     f'scheduled {len(scheduled_batch.context_requests)} context requests and '
                     f'{len(scheduled_batch.generation_requests)} generation requests'

@@ -833,9 +842,13 @@ def _executor_loop_pp(self):
                 can_queue = self._can_queue(scheduled_batch)

                 if not can_queue:
+                    logger.debug(
+                        f"microbatch {microbatch_id} cannot be queued, skipping"
+                    )
                     self.micro_batches[microbatch_id] = None
                 else:
-                    self._add_inflight_ids(scheduled_batch)
+                    logger.debug(f"microbatch {microbatch_id} can be queued")
+                    finished_ctx_reqs = self._add_inflight_ids(scheduled_batch)

                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer

@@ -895,6 +908,7 @@ def _executor_loop_pp(self):
                         iter_stats=iter_stats,
                         microbatch_id=microbatch_id,
                         scheduled_ctx_reqs=scheduled_batch.context_requests,
+                        finished_ctx_reqs=finished_ctx_reqs,
                     )

                     self.micro_batches[microbatch_id] = batch_state

@@ -945,6 +959,8 @@ def _executor_loop_pp(self):
                 finished_requests = []
                 if previous_batch is not None:
                     with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
+                        sample_state = previous_batch.sample_state
+                        sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs
                         self._update_requests(previous_batch.sample_state)

                         if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
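
Swapping finished_ctx_reqs into the previous batch's sample state (above) keeps _update_requests from advancing context requests that still have chunks pending; only requests whose final context chunk ran in that micro batch are updated. A later hunk restores the full scheduled_ctx_reqs list before iteration stats are emitted. The following is a minimal sketch of that two-step swap, using toy stand-ins rather than the executor's real dataclasses:

```python
from dataclasses import dataclass, field


@dataclass
class ScheduledRequests:          # toy stand-in
    context_requests: list = field(default_factory=list)
    generation_requests: list = field(default_factory=list)


@dataclass
class SampleState:                # toy stand-in
    scheduled_requests: ScheduledRequests


@dataclass
class BatchStatePP:               # toy stand-in for the real BatchStatePP
    sample_state: SampleState
    scheduled_ctx_reqs: list      # everything scheduled, including partial chunks
    finished_ctx_reqs: list       # only requests whose last context chunk ran


def handle_previous_batch(prev: BatchStatePP, update_requests):
    # Narrow the context list so partially chunked requests are not updated.
    prev.sample_state.scheduled_requests.context_requests = prev.finished_ctx_reqs
    update_requests(prev.sample_state)


def emit_iter_stats(prev: BatchStatePP, process_iter_stats, microbatch_id):
    # Report the full scheduled set in the iteration stats.
    prev.sample_state.scheduled_requests.context_requests = prev.scheduled_ctx_reqs
    process_iter_stats(prev, microbatch_id)
```
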
@@ -976,7 +992,8 @@ def _executor_loop_pp(self):
                         self.resource_manager.update_resources(
                             previous_scheduled_batch, attn_metadata,
                             kv_cache_dtype_byte_size)
-                        self._remove_inflight_ids(previous_scheduled_batch)
+
+                        self._remove_inflight_ids(previous_batch)

                     self.wait_on_pp_send_handles(prev_microbatch_id)
                     self.micro_batches[prev_microbatch_id] = None

@@ -993,9 +1010,11 @@ def _executor_loop_pp(self):
                 microbatch_id = (microbatch_id + 1) % self.num_micro_batches

                 if self.enable_iter_perf_stats and previous_batch is not None:
+                    sample_state = previous_batch.sample_state
+                    sample_state.scheduled_requests.context_requests = previous_batch.scheduled_ctx_reqs
                     self._process_iter_stats(finished_requests,
                                              self.active_requests,
-                                             previous_batch)
+                                             previous_batch, microbatch_id)

                 self.iter_counter += 1

@@ -2481,13 +2500,43 @@ def _pause_requests(self, requests_to_pause):
                 self._terminate_request(req)

     def _add_inflight_ids(self, scheduled_requests):
-        """Add reqids of current requests to self.inflight_req_ids."""
-        for req in scheduled_requests.all_requests():
+        """Add request IDs of current requests to self.inflight_req_ids.
+
+        Non-final context chunks are not added to the inflight set, so the scheduler can keep scheduling further
+        context chunks while earlier ones are in the PP pipeline. Only context requests that finish context phase
+        are inserted into the inflight set and collected into finished_ctx_reqs.
+        All generation requests are still inserted into the inflight set.
+        """
+        finished_ctx_reqs = []
+        for req in scheduled_requests.context_requests:
+            if req.is_last_context_chunk:
+                logger.debug(
+                    f"Context request with ID {req.request_id} added to DECODER model inflight set"
+                )
+                self.inflight_req_ids.insert(req.request_id)
+                finished_ctx_reqs.append(req)
+        for req in scheduled_requests.generation_requests:
+            logger.debug(
+                f"Generation request with ID {req.request_id} added to DECODER model inflight set"
+            )
             self.inflight_req_ids.insert(req.request_id)
+        return finished_ctx_reqs
+
+    def _remove_inflight_ids(self, batch_state: BatchStatePP):
+        """Remove request IDs of current requests from self.inflight_req_ids.

-    def _remove_inflight_ids(self, scheduled_requests):
-        """Remove reqids of current requests from self.inflight_req_ids."""
-        for req in scheduled_requests.all_requests():
+        Context IDs are erased from the inflight set using batch_state.finished_ctx_reqs.
+        Generation IDs are erased using batch_state.sample_state.scheduled_requests.generation_requests.
+        """
+        for req in batch_state.finished_ctx_reqs:
+            logger.debug(
+                f"Context request with ID {req.request_id} removed from DECODER model inflight set"
+            )
+            self.inflight_req_ids.erase(req.request_id)
+        for req in batch_state.sample_state.scheduled_requests.generation_requests:
+            logger.debug(
+                f"Generation request with ID {req.request_id} removed from DECODER model inflight set"
+            )
             self.inflight_req_ids.erase(req.request_id)

     def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
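
Taken together, _add_inflight_ids and _remove_inflight_ids stay symmetric even though chunking context requests are never inserted: the list returned by the former is stored on BatchStatePP.finished_ctx_reqs and replayed by the latter. A compact, self-contained sketch of that round trip, using a plain Python set and a toy Req class rather than the executor's ReqIdsSet and LlmRequest:

```python
class Req:
    def __init__(self, request_id, is_last_context_chunk=True):
        self.request_id = request_id
        self.is_last_context_chunk = is_last_context_chunk


inflight_ids = set()


def add_inflight_ids(context_reqs, generation_reqs):
    # Only context requests finishing their context phase become in-flight.
    finished_ctx_reqs = [r for r in context_reqs if r.is_last_context_chunk]
    for r in finished_ctx_reqs:
        inflight_ids.add(r.request_id)
    for r in generation_reqs:
        inflight_ids.add(r.request_id)
    return finished_ctx_reqs


def remove_inflight_ids(finished_ctx_reqs, generation_reqs):
    # Erase exactly the IDs that were inserted for this micro batch.
    for r in finished_ctx_reqs + generation_reqs:
        inflight_ids.discard(r.request_id)


ctx = [Req(1, is_last_context_chunk=False), Req(2, is_last_context_chunk=True)]
gen = [Req(3)]
finished = add_inflight_ids(ctx, gen)
assert inflight_ids == {2, 3}      # request 1 still has chunks left; not in flight
remove_inflight_ids(finished, gen)
assert inflight_ids == set()       # add/remove stay symmetric via finished_ctx_reqs
```
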
