Skip to content

Commit 33ba674

Browse files
committed
[TRTLLM-8922][fix] Add type annotation for _disaggregated_params member
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
1 parent 94dc566 commit 33ba674

File tree

3 files changed

+19

-15

lines changed

tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@ def cancel_request(self, req: LlmRequest):
207207
return self.impl.cancel_request(req)
208208

209209
def prepare_context_requests(self, requests: List[LlmRequest]):
210-
raise NotImplementedError
210+
# not implemented, an empty placeholder to allow being invoked unconditionally
211+
...
211212

212213
def get_disaggregated_params(self):
213214
# Cpp kv cache transceiver will set the disaggregated params to context response

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,7 +1173,7 @@ def _pp_schedule_and_propagate(self, microbatch_id: int):
11731173
is_dp_broadcast = self.dist.tp_size > 1 and self.enable_attention_dp
11741174
if self.dist.rank == 0 or (self.dist.is_first_pp_rank
11751175
and is_dp_broadcast):
1176-
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs, all_gen_first = self._schedule(
1176+
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
11771177
)
11781178
serializable_schedule = SerializableSchedulerOutput.from_scheduler_result(
11791179
scheduled_batch, fitting_disagg_gen_init_requests,
@@ -1301,6 +1301,11 @@ def _executor_loop_pp(self):
13011301
self._prepare_disagg_gen_init(
13021302
fitting_disagg_gen_init_requests)
13031303

1304+
all_gen_first = self.active_requests and all(
1305+
req.py_disaggregated_params
1306+
and req.py_disaggregated_params.schedule_style ==
1307+
DisaggScheduleStyle.GENERATION_FIRST
1308+
for req in self.active_requests)
13041309
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests and not all_gen_first:
13051310
logger.warning(
13061311
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
@@ -1605,7 +1610,7 @@ def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]):
16051610
# _handle_responses sees the request before it is terminated.
16061611
if self.kv_cache_transceiver:
16071612
self._check_disagg_ctx_cache_transfer_status(0)
1608-
sample_state_scheduled_requests = executed_batch.sample_state.scheduled_requests
1613+
sample_state_scheduled_requests = executed_batch.scheduled_requests
16091614
attn_metadata = getattr(self.model_engine, 'attn_metadata',
16101615
None)
16111616
kv_cache_dtype_byte_size = getattr(self.model_engine,
@@ -1787,7 +1792,7 @@ def _prepare_and_schedule_batch(self):
17871792
# that speculation is about to happen.
17881793
self._prepare_draft_requests()
17891794

1790-
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs, all_gen_first = self._schedule(
1795+
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
17911796
)
17921797

17931798
if self.drafter is not None and not self.use_spec_decode:
@@ -1798,6 +1803,10 @@ def _prepare_and_schedule_batch(self):
17981803
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
17991804
self._prepare_disagg_gen_init(fitting_disagg_gen_init_requests)
18001805

1806+
all_gen_first = self.active_requests and all(
1807+
req.py_disaggregated_params and req.py_disaggregated_params.
1808+
schedule_style == DisaggScheduleStyle.GENERATION_FIRST
1809+
for req in self.active_requests)
18011810
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests and not all_gen_first:
18021811
logger.warning(
18031812
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
@@ -2773,12 +2782,7 @@ def _schedule(self):
27732782
scheduled_requests.generation_requests = scheduler_output.generation_requests
27742783
scheduled_requests.paused_requests = scheduler_output.paused_requests
27752784

2776-
all_gen_first = self.active_requests and all(
2777-
req.py_disaggregated_params and req.py_disaggregated_params.
2778-
schedule_style == DisaggScheduleStyle.GENERATION_FIRST
2779-
for req in self.active_requests)
2780-
2781-
return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests, all_gen_first
2785+
return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
27822786

27832787
@nvtx_range("_check_disagg_gen_transfer_status")
27842788
def _check_disagg_gen_transfer_status(self):
@@ -2839,17 +2843,16 @@ def _check_disagg_ctx_schedulable_status(self,
28392843
"""
28402844
if not self.kv_cache_transceiver:
28412845
return
2842-
ctx_only_requests = [
2846+
gen_first_ctx_requests = [
28432847
req for req in new_requests
28442848
if req.is_context_only_request and req.py_disaggregated_params.
28452849
schedule_style == DisaggScheduleStyle.GENERATION_FIRST
28462850
]
28472851
# Always call prepare_context_requests when there are new requests
28482852
# or previously-waiting requests, so the tp_allgather consensus
28492853
# can promote requests whose peer info has arrived on all ranks.
2850-
if ctx_only_requests or self.kv_cache_transceiver.wait_req_id_to_request:
2851-
self.kv_cache_transceiver.prepare_context_requests(
2852-
ctx_only_requests)
2854+
self.kv_cache_transceiver.prepare_context_requests(
2855+
gen_first_ctx_requests)
28532856

28542857
@nvtx_range("_pad_attention_dp_dummy_request")
28552858
def _pad_attention_dp_dummy_request(self):

tensorrt_llm/llmapi/llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def __init__(self,
150150
self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
151151
self._orchestrator_type = kwargs.get("orchestrator_type", None)
152152
self._llm_id = None
153-
self._disaggregated_params = None
153+
self._disaggregated_params: Optional[dict] = None
154154

155155
log_level = logger.level
156156
logger.set_level("info") # force display the backend

0 commit comments

Comments (0)