Commit d43be7b

[None][fix] Avoid Double update for previous batch (NVIDIA#9888)
Signed-off-by: Yi Zhang <[email protected]>
1 parent 944c304

File tree: 3 files changed (+32, −15 lines)

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp (1 addition, 0 deletions)

@@ -170,6 +170,7 @@ void initBindings(nb::module_& m)
         .def_prop_ro("context_phase_params", &GenLlmReq::getContextPhaseParams)
         .def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
         .def_prop_ro("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
+        .def_prop_ro("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState)
         .def_prop_ro("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)
         .def_prop_ro("is_context_finished", &GenLlmReq::isContextFinished)
         .def_prop_ro("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp (1 addition, 0 deletions)

@@ -175,6 +175,7 @@ void initBindings(pybind11::module_& m)
         .def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams)
         .def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
         .def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
+        .def_property_readonly("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState)
         .def_property_readonly("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)
         .def_property_readonly("is_context_finished", &GenLlmReq::isContextFinished)
         .def_property_readonly("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)

tensorrt_llm/_torch/pyexecutor/py_executor.py (30 additions, 15 deletions)

@@ -1161,7 +1161,7 @@ def _executor_loop_pp(self):
                     f'{len(scheduled_batch.generation_requests)} generation requests'
                 )

-                can_queue = self._can_queue(scheduled_batch)
+                can_queue, _ = self._can_queue(scheduled_batch)
                 if not can_queue:
                     logger.debug(
                         f"microbatch {microbatch_id} cannot be queued, skipping"
@@ -1359,13 +1359,17 @@ def wait_on_pp_send_handles(self, microbatch_id):

     def _can_queue(self, scheduled_batch):

+        # can_queue_this_rank covers the case where the batch is non-empty on this rank but empty on other ranks.
+        # For bs == 1 we cannot pad a dummy request to make the batch non-empty, since that would raise the batch size to 2:
+        # 1 for the dummy request and 1 for the request that is about to complete but has not yet been updated.
         if self.enable_attention_dp:
             tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size)
             can_queue = 0 not in tp_batch_sizes
+            can_queue_this_rank = scheduled_batch.batch_size > 0
         else:
-            can_queue = scheduled_batch.batch_size > 0
+            can_queue = can_queue_this_rank = scheduled_batch.batch_size > 0

-        return can_queue
+        return can_queue, can_queue_this_rank

     def _prepare_and_schedule_batch(self):
         new_requests = self._fetch_and_activate_new_requests()
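
Lifted out of the executor class, the new two-value contract looks like this (a sketch; tp_allgather stands in for self.dist.tp_allgather and is assumed to return a list with every rank's batch size):

    # Sketch of the two-flag contract returned by _can_queue.
    def can_queue_flags(batch_size, enable_attention_dp, tp_allgather):
        if enable_attention_dp:
            tp_batch_sizes = tp_allgather(batch_size)
            # The whole attention-DP group can only launch when every rank has work.
            can_queue = 0 not in tp_batch_sizes
            # Track local work separately, so a rank that is non-empty while a
            # peer is empty knows to defer (not lose) the previous batch's update.
            can_queue_this_rank = batch_size > 0
        else:
            can_queue = can_queue_this_rank = batch_size > 0
        return can_queue, can_queue_this_rank

    # Example: this rank has 1 request, another rank has none.
    assert can_queue_flags(1, True, lambda bs: [1, 0]) == (False, True)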
@@ -1494,7 +1498,7 @@ def _executor_loop(self):

             finished_requests = []

-            can_queue = self._can_queue(scheduled_batch)
+            can_queue, _ = self._can_queue(scheduled_batch)
             if can_queue:
                 if self.kv_cache_transceiver:
                     # For generation requests which have completed KV cache transfer
@@ -1509,7 +1513,7 @@ def _executor_loop(self):

             # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
             if self.kv_connector_manager:
-                can_queue = self._can_queue(scheduled_batch)
+                can_queue, _ = self._can_queue(scheduled_batch)

             if can_queue:
                 # init_disagg_gen_requests must be before drafter loop, otherwise draft requests do not have initialized matchers.
@@ -1711,7 +1715,8 @@ def _executor_loop_overlap(self):

             self._pause_requests(scheduled_batch.paused_requests)

-            can_queue = self._can_queue(scheduled_batch)
+            can_queue, can_queue_this_rank = self._can_queue(
+                scheduled_batch)
             if can_queue:
                 if self.kv_cache_transceiver:
                     # For generation requests which have completed KV cache transfer
@@ -1741,8 +1746,13 @@ def _executor_loop_overlap(self):

             # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
             if self.kv_connector_manager:
-                can_queue = self._can_queue(scheduled_batch)
+                can_queue, can_queue_this_rank = self._can_queue(
+                    scheduled_batch)

+            # If the batch is non-empty on this rank but empty on other ranks,
+            # delay the update of the previous batch's sample state and let a
+            # later iteration perform it.
+            should_process_previous_batch = can_queue or not can_queue_this_rank
             if can_queue:

                 # The generation requests that are do not have batch_idx,
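
The flag combinations work out as follows (a sketch restating the line above; note that can_queue is True only when every rank has work, so the pair (True, False) cannot occur):

    # can_queue | can_queue_this_rank | should_process_previous_batch
    #   True    |   True              |  True   (normal step)
    #   False   |   False             |  True   (this rank is empty too; safe to update)
    #   False   |   True              |  False  (local work exists but a peer rank is
    #                                            empty: defer so the previous batch is
    #                                            updated exactly once, in a later iteration)
    should_process_previous_batch = can_queue or not can_queue_this_rank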
@@ -1792,10 +1802,10 @@ def _executor_loop_overlap(self):
                     scheduled_batch, previous_tensors_device,
                     num_accepted_tokens_device)

-            if self.previous_batch is not None:
+            if self.previous_batch is not None and should_process_previous_batch:
                 self._update_requests(self.previous_batch.sample_state)

-            if self.drafter is not None and self.use_spec_decode:
+            if self.drafter is not None and self.use_spec_decode and should_process_previous_batch:
                 # Cleanup previous draft resources used in the draft model
                 self.drafter.cleanup_previous_draft_resources()

@@ -1822,8 +1832,10 @@ def _executor_loop_overlap(self):
                 ctx_transmission_reqs = self._send_kv_async(
                     scheduled_batch.all_requests())

-            if self.previous_batch is not None:
+            if self.previous_batch is not None and should_process_previous_batch:
                 self._process_previous_batch()
+            else:
+                self._enqueue_responses([])

             if can_queue:
                 if self.enable_iter_perf_stats:
@@ -1835,6 +1847,9 @@ def _executor_loop_overlap(self):
                         iter_start_time=iter_start_time,
                         iter_stats=iter_stats,
                         ctx_transmission_reqs=ctx_transmission_reqs)
+            elif not can_queue_this_rank:
+                # If the batch is empty on this rank, we need to clear the previous batch.
+                self.previous_batch = None

             if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
             ):
@@ -2194,10 +2209,10 @@ def _pad_attention_dp_dummy_request(self):
         if self.kv_cache_transceiver is None:
             num_active_request = len(self.active_requests)
         else:
-            num_active_request = sum([
-                0 if req.is_disagg_generation_init_state
-                or req.is_disagg_generation_transmission_in_progress else 1
-                for req in self.active_requests
+            num_active_request = len([
+                req for req in self.active_requests
+                if not (req.is_disagg_generation_init_state
+                        or req.is_disagg_generation_transmission_in_progress)
             ])

         if self.expected_num_active_requests - num_active_request > 0 and num_active_request == 0:
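
This last rewrite is behavior-preserving. A quick self-contained check (with a hypothetical Req stand-in class) showing that the old sum-of-0/1 form and the new filtered-len form count the same requests:

    class Req:
        def __init__(self, init_state, transmitting):
            self.is_disagg_generation_init_state = init_state
            self.is_disagg_generation_transmission_in_progress = transmitting

    active_requests = [Req(True, False), Req(False, True), Req(False, False)]

    old = sum([
        0 if req.is_disagg_generation_init_state
        or req.is_disagg_generation_transmission_in_progress else 1
        for req in active_requests
    ])
    new = len([
        req for req in active_requests
        if not (req.is_disagg_generation_init_state
                or req.is_disagg_generation_transmission_in_progress)
    ])
    assert old == new == 1  # only the third request counts as active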
@@ -2393,7 +2408,7 @@ def _check_disagg_gen_cache_transfer_status(self, atLeastNum: int = 0):

     def _forward_step(
             self,
-            scheduled_requests,
+            scheduled_requests: ScheduledRequests,
             new_tensors_device: Optional[SampleStateTensors] = None,
             num_accepted_tokens_device: Optional[torch.Tensor] = None):
         ExpertStatistic.set_iter(self.iter_counter)
