Skip to content

Commit df05160

Browse files
committed
[TRTLLM-8922][fix] Add type annotation for _disaggregated_params member
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
1 parent 3b62d78 commit df05160

File tree

3 files changed

+19
-15
lines changed

3 files changed

+19
-15
lines changed

tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -207,7 +207,8 @@ def cancel_request(self, req: LlmRequest):
207207
return self.impl.cancel_request(req)
208208

209209
def prepare_context_requests(self, requests: List[LlmRequest]):
210-
raise NotImplementedError
210+
# not implemented, an empty placeholder to allow being invoked unconditionally
211+
...
211212

212213
def get_disaggregated_params(self):
213214
# Cpp kv cache transceiver will set the disaggregated params to context response

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 16 additions & 13 deletions
Original file line number · Diff line number · Diff line change
@@ -1152,7 +1152,7 @@ def _pp_schedule_and_propagate(self, microbatch_id: int):
11521152
is_dp_broadcast = self.dist.tp_size > 1 and self.enable_attention_dp
11531153
if self.dist.rank == 0 or (self.dist.is_first_pp_rank
11541154
and is_dp_broadcast):
1155-
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs, all_gen_first = self._schedule(
1155+
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
11561156
)
11571157
serializable_schedule = SerializableSchedulerOutput.from_scheduler_result(
11581158
scheduled_batch, fitting_disagg_gen_init_requests,
@@ -1280,6 +1280,11 @@ def _executor_loop_pp(self):
12801280
self._prepare_disagg_gen_init(
12811281
fitting_disagg_gen_init_requests)
12821282

1283+
all_gen_first = self.active_requests and all(
1284+
req.py_disaggregated_params
1285+
and req.py_disaggregated_params.schedule_style ==
1286+
DisaggScheduleStyle.GENERATION_FIRST
1287+
for req in self.active_requests)
12831288
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests and not all_gen_first:
12841289
logger.warning(
12851290
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
@@ -1584,7 +1589,7 @@ def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]):
15841589
# _handle_responses sees the request before it is terminated.
15851590
if self.kv_cache_transceiver:
15861591
self._check_disagg_ctx_cache_transfer_status(0)
1587-
sample_state_scheduled_requests = executed_batch.sample_state.scheduled_requests
1592+
sample_state_scheduled_requests = executed_batch.scheduled_requests
15881593
attn_metadata = getattr(self.model_engine, 'attn_metadata',
15891594
None)
15901595
kv_cache_dtype_byte_size = getattr(self.model_engine,
@@ -1766,7 +1771,7 @@ def _prepare_and_schedule_batch(self):
17661771
# that speculation is about to happen.
17671772
self._prepare_draft_requests()
17681773

1769-
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs, all_gen_first = self._schedule(
1774+
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
17701775
)
17711776

17721777
if self.drafter is not None and not self.use_spec_decode:
@@ -1777,6 +1782,10 @@ def _prepare_and_schedule_batch(self):
17771782
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
17781783
self._prepare_disagg_gen_init(fitting_disagg_gen_init_requests)
17791784

1785+
all_gen_first = self.active_requests and all(
1786+
req.py_disaggregated_params and req.py_disaggregated_params.
1787+
schedule_style == DisaggScheduleStyle.GENERATION_FIRST
1788+
for req in self.active_requests)
17801789
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests and not all_gen_first:
17811790
logger.warning(
17821791
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
@@ -2719,12 +2728,7 @@ def _schedule(self):
27192728
scheduled_requests.generation_requests = scheduler_output.generation_requests
27202729
scheduled_requests.paused_requests = scheduler_output.paused_requests
27212730

2722-
all_gen_first = self.active_requests and all(
2723-
req.py_disaggregated_params and req.py_disaggregated_params.
2724-
schedule_style == DisaggScheduleStyle.GENERATION_FIRST
2725-
for req in self.active_requests)
2726-
2727-
return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests, all_gen_first
2731+
return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
27282732

27292733
@nvtx_range("_check_disagg_gen_transfer_status")
27302734
def _check_disagg_gen_transfer_status(self):
@@ -2785,17 +2789,16 @@ def _check_disagg_ctx_schedulable_status(self,
27852789
"""
27862790
if not self.kv_cache_transceiver:
27872791
return
2788-
ctx_only_requests = [
2792+
gen_first_ctx_requests = [
27892793
req for req in new_requests
27902794
if req.is_context_only_request and req.py_disaggregated_params.
27912795
schedule_style == DisaggScheduleStyle.GENERATION_FIRST
27922796
]
27932797
# Always call prepare_context_requests when there are new requests
27942798
# or previously-waiting requests, so the tp_allgather consensus
27952799
# can promote requests whose peer info has arrived on all ranks.
2796-
if ctx_only_requests or self.kv_cache_transceiver.wait_req_id_to_request:
2797-
self.kv_cache_transceiver.prepare_context_requests(
2798-
ctx_only_requests)
2800+
self.kv_cache_transceiver.prepare_context_requests(
2801+
gen_first_ctx_requests)
27992802

28002803
@nvtx_range("_pad_attention_dp_dummy_request")
28012804
def _pad_attention_dp_dummy_request(self):

tensorrt_llm/llmapi/llm.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -150,7 +150,7 @@ def __init__(self,
150150
self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
151151
self._orchestrator_type = kwargs.get("orchestrator_type", None)
152152
self._llm_id = None
153-
self._disaggregated_params = None
153+
self._disaggregated_params: Optional[dict] = None
154154

155155
log_level = logger.level
156156
logger.set_level("info") # force display the backend

0 commit comments

Comments (0)