@@ -1173,7 +1173,7 @@ def _pp_schedule_and_propagate(self, microbatch_id: int):
11731173 is_dp_broadcast = self .dist .tp_size > 1 and self .enable_attention_dp
11741174 if self .dist .rank == 0 or (self .dist .is_first_pp_rank
11751175 and is_dp_broadcast ):
1176- scheduled_batch , fitting_disagg_gen_init_requests , num_fitting_reqs , all_gen_first = self ._schedule (
1176+ scheduled_batch , fitting_disagg_gen_init_requests , num_fitting_reqs = self ._schedule (
11771177 )
11781178 serializable_schedule = SerializableSchedulerOutput .from_scheduler_result (
11791179 scheduled_batch , fitting_disagg_gen_init_requests ,
@@ -1301,6 +1301,11 @@ def _executor_loop_pp(self):
13011301 self ._prepare_disagg_gen_init (
13021302 fitting_disagg_gen_init_requests )
13031303
1304+ all_gen_first = self .active_requests and all (
1305+ req .py_disaggregated_params
1306+ and req .py_disaggregated_params .schedule_style ==
1307+ DisaggScheduleStyle .GENERATION_FIRST
1308+ for req in self .active_requests )
13041309 if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests and not all_gen_first :
13051310 logger .warning (
13061311 "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
@@ -1605,7 +1610,7 @@ def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]):
16051610 # _handle_responses sees the request before it is terminated.
16061611 if self .kv_cache_transceiver :
16071612 self ._check_disagg_ctx_cache_transfer_status (0 )
1608- sample_state_scheduled_requests = executed_batch .sample_state . scheduled_requests
1613+ sample_state_scheduled_requests = executed_batch .scheduled_requests
16091614 attn_metadata = getattr (self .model_engine , 'attn_metadata' ,
16101615 None )
16111616 kv_cache_dtype_byte_size = getattr (self .model_engine ,
@@ -1787,7 +1792,7 @@ def _prepare_and_schedule_batch(self):
17871792 # that speculation is about to happen.
17881793 self ._prepare_draft_requests ()
17891794
1790- scheduled_batch , fitting_disagg_gen_init_requests , num_fitting_reqs , all_gen_first = self ._schedule (
1795+ scheduled_batch , fitting_disagg_gen_init_requests , num_fitting_reqs = self ._schedule (
17911796 )
17921797
17931798 if self .drafter is not None and not self .use_spec_decode :
@@ -1798,6 +1803,10 @@ def _prepare_and_schedule_batch(self):
17981803 # For requests that are fitting disagg gen init, also prepare resources for KV cache manager
17991804 self ._prepare_disagg_gen_init (fitting_disagg_gen_init_requests )
18001805
1806+ all_gen_first = self .active_requests and all (
1807+ req .py_disaggregated_params and req .py_disaggregated_params .
1808+ schedule_style == DisaggScheduleStyle .GENERATION_FIRST
1809+ for req in self .active_requests )
18011810 if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests and not all_gen_first :
18021811 logger .warning (
18031812 "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
@@ -2773,12 +2782,7 @@ def _schedule(self):
27732782 scheduled_requests .generation_requests = scheduler_output .generation_requests
27742783 scheduled_requests .paused_requests = scheduler_output .paused_requests
27752784
2776- all_gen_first = self .active_requests and all (
2777- req .py_disaggregated_params and req .py_disaggregated_params .
2778- schedule_style == DisaggScheduleStyle .GENERATION_FIRST
2779- for req in self .active_requests )
2780-
2781- return scheduled_requests , scheduler_output .fitting_disagg_gen_init_requests , scheduler_output .num_fitting_requests , all_gen_first
2785+ return scheduled_requests , scheduler_output .fitting_disagg_gen_init_requests , scheduler_output .num_fitting_requests
27822786
27832787 @nvtx_range ("_check_disagg_gen_transfer_status" )
27842788 def _check_disagg_gen_transfer_status (self ):
@@ -2839,17 +2843,16 @@ def _check_disagg_ctx_schedulable_status(self,
28392843 """
28402844 if not self .kv_cache_transceiver :
28412845 return
2842- ctx_only_requests = [
2846+ gen_first_ctx_requests = [
28432847 req for req in new_requests
28442848 if req .is_context_only_request and req .py_disaggregated_params .
28452849 schedule_style == DisaggScheduleStyle .GENERATION_FIRST
28462850 ]
28472851 # Always call prepare_context_requests when there are new requests
28482852 # or previously-waiting requests, so the tp_allgather consensus
28492853 # can promote requests whose peer info has arrived on all ranks.
2850- if ctx_only_requests or self .kv_cache_transceiver .wait_req_id_to_request :
2851- self .kv_cache_transceiver .prepare_context_requests (
2852- ctx_only_requests )
2854+ self .kv_cache_transceiver .prepare_context_requests (
2855+ gen_first_ctx_requests )
28532856
28542857 @nvtx_range ("_pad_attention_dp_dummy_request" )
28552858 def _pad_attention_dp_dummy_request (self ):
0 commit comments