@@ -1152,7 +1152,7 @@ def _pp_schedule_and_propagate(self, microbatch_id: int):
11521152 is_dp_broadcast = self .dist .tp_size > 1 and self .enable_attention_dp
11531153 if self .dist .rank == 0 or (self .dist .is_first_pp_rank
11541154 and is_dp_broadcast ):
1155- scheduled_batch , fitting_disagg_gen_init_requests , num_fitting_reqs , all_gen_first = self ._schedule (
1155+ scheduled_batch , fitting_disagg_gen_init_requests , num_fitting_reqs = self ._schedule (
11561156 )
11571157 serializable_schedule = SerializableSchedulerOutput .from_scheduler_result (
11581158 scheduled_batch , fitting_disagg_gen_init_requests ,
@@ -1280,6 +1280,11 @@ def _executor_loop_pp(self):
12801280 self ._prepare_disagg_gen_init (
12811281 fitting_disagg_gen_init_requests )
12821282
1283+ all_gen_first = self .active_requests and all (
1284+ req .py_disaggregated_params
1285+ and req .py_disaggregated_params .schedule_style ==
1286+ DisaggScheduleStyle .GENERATION_FIRST
1287+ for req in self .active_requests )
12831288 if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests and not all_gen_first :
12841289 logger .warning (
12851290 "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
@@ -1584,7 +1589,7 @@ def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]):
15841589 # _handle_responses sees the request before it is terminated.
15851590 if self .kv_cache_transceiver :
15861591 self ._check_disagg_ctx_cache_transfer_status (0 )
1587- sample_state_scheduled_requests = executed_batch .sample_state . scheduled_requests
1592+ sample_state_scheduled_requests = executed_batch .scheduled_requests
15881593 attn_metadata = getattr (self .model_engine , 'attn_metadata' ,
15891594 None )
15901595 kv_cache_dtype_byte_size = getattr (self .model_engine ,
@@ -1766,7 +1771,7 @@ def _prepare_and_schedule_batch(self):
17661771 # that speculation is about to happen.
17671772 self ._prepare_draft_requests ()
17681773
1769- scheduled_batch , fitting_disagg_gen_init_requests , num_fitting_reqs , all_gen_first = self ._schedule (
1774+ scheduled_batch , fitting_disagg_gen_init_requests , num_fitting_reqs = self ._schedule (
17701775 )
17711776
17721777 if self .drafter is not None and not self .use_spec_decode :
@@ -1777,6 +1782,10 @@ def _prepare_and_schedule_batch(self):
17771782 # For requests that are fitting disagg gen init, also prepare resources for KV cache manager
17781783 self ._prepare_disagg_gen_init (fitting_disagg_gen_init_requests )
17791784
1785+ all_gen_first = self .active_requests and all (
1786+ req .py_disaggregated_params and req .py_disaggregated_params .
1787+ schedule_style == DisaggScheduleStyle .GENERATION_FIRST
1788+ for req in self .active_requests )
17801789 if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests and not all_gen_first :
17811790 logger .warning (
17821791 "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
@@ -2719,12 +2728,7 @@ def _schedule(self):
27192728 scheduled_requests .generation_requests = scheduler_output .generation_requests
27202729 scheduled_requests .paused_requests = scheduler_output .paused_requests
27212730
2722- all_gen_first = self .active_requests and all (
2723- req .py_disaggregated_params and req .py_disaggregated_params .
2724- schedule_style == DisaggScheduleStyle .GENERATION_FIRST
2725- for req in self .active_requests )
2726-
2727- return scheduled_requests , scheduler_output .fitting_disagg_gen_init_requests , scheduler_output .num_fitting_requests , all_gen_first
2731+ return scheduled_requests , scheduler_output .fitting_disagg_gen_init_requests , scheduler_output .num_fitting_requests
27282732
27292733 @nvtx_range ("_check_disagg_gen_transfer_status" )
27302734 def _check_disagg_gen_transfer_status (self ):
@@ -2785,17 +2789,16 @@ def _check_disagg_ctx_schedulable_status(self,
27852789 """
27862790 if not self .kv_cache_transceiver :
27872791 return
2788- ctx_only_requests = [
2792+ gen_first_ctx_requests = [
27892793 req for req in new_requests
27902794 if req .is_context_only_request and req .py_disaggregated_params .
27912795 schedule_style == DisaggScheduleStyle .GENERATION_FIRST
27922796 ]
27932797 # Always call prepare_context_requests when there are new requests
27942798 # or previously-waiting requests, so the tp_allgather consensus
27952799 # can promote requests whose peer info has arrived on all ranks.
2796- if ctx_only_requests or self .kv_cache_transceiver .wait_req_id_to_request :
2797- self .kv_cache_transceiver .prepare_context_requests (
2798- ctx_only_requests )
2800+ self .kv_cache_transceiver .prepare_context_requests (
2801+ gen_first_ctx_requests )
27992802
28002803 @nvtx_range ("_pad_attention_dp_dummy_request" )
28012804 def _pad_attention_dp_dummy_request (self ):
0 commit comments