@@ -1567,6 +1567,24 @@ def _executor_loop_overlap(self):
                     # For generation requests which have completed KV cache transfer
                     self._prepare_disagg_gen_transmission_complete(
                         scheduled_batch)
+
+                has_draft_batch = self.drafter is not None and self.previous_batch is not None and self.use_spec_decode and self.drafter.should_forward_draft_model(
+                    scheduled_batch)
+                # Reset the draft tokens to avoid preparing resources for the draft model.
+                if self.drafter is not None and self.use_spec_decode and not has_draft_batch:
+                    self.use_spec_decode = False
+                    # We are not running the draft model. Remove the draft tokens and turn off spec
+                    # decode so that the requests get handled correctly.
+                    # One corner case: when we have at least one context request, we have to keep spec
+                    # dec on. This ensures that we capture hidden states for requests that haven't done
+                    # prefill yet.
+                    self.use_spec_decode = False
+                    self.model_engine.enable_spec_decode = len(
+                        scheduled_batch.context_requests) > 0
+                    if not self.model_engine.enable_spec_decode:
+                        for request in scheduled_batch.all_requests():
+                            request.py_draft_tokens = []
+
                 self.resource_manager.prepare_resources(scheduled_batch)

                 self._kv_connector_start_batch(scheduled_batch)
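
Editorial note: the hunk above moves the "will the draft model run this step?" decision ahead of resource preparation, so the executor never reserves draft-model resources for an iteration that will not use them. Below is a minimal standalone sketch of that gating logic; the _Request/_Batch classes and the gate_spec_decode helper are simplified stand-ins invented for illustration (not the real tensorrt_llm types), and only the control flow mirrors the diff.

# Sketch only: stand-in types, not tensorrt_llm classes.
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class _Request:
    py_draft_tokens: List[int] = field(default_factory=list)


@dataclass
class _Batch:
    context_requests: List[_Request] = field(default_factory=list)
    generation_requests: List[_Request] = field(default_factory=list)

    def all_requests(self) -> List[_Request]:
        return self.context_requests + self.generation_requests


def gate_spec_decode(batch: _Batch, use_spec_decode: bool,
                     has_draft_batch: bool) -> Tuple[bool, bool]:
    """Return (use_spec_decode, enable_spec_decode) for this iteration.

    When the draft model will not run, spec decode is switched off for
    sampling, but the engine keeps it enabled while context requests are
    present so their hidden states are still captured during prefill.
    Draft tokens are cleared only once the engine fully disables it.
    """
    enable_spec_decode = use_spec_decode
    if use_spec_decode and not has_draft_batch:
        use_spec_decode = False
        enable_spec_decode = len(batch.context_requests) > 0
        if not enable_spec_decode:
            for request in batch.all_requests():
                request.py_draft_tokens = []
    return use_spec_decode, enable_spec_decode


# Example: a generation-only batch with stale draft tokens and no draft run.
batch = _Batch(generation_requests=[_Request(py_draft_tokens=[1, 2])])
print(gate_spec_decode(batch, use_spec_decode=True, has_draft_batch=False))
# (False, False) -- and the request's py_draft_tokens list is now empty.
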
@@ -1602,8 +1620,11 @@ def _executor_loop_overlap(self):
                 # so we'll set the target model's input to None and skip updating the target requests after target model forward.
                 use_previous_draft_tokens = self.has_previous_draft_tokens
                 num_accepted_tokens_device = None
-                if self.drafter is not None and (self.use_spec_decode or
-                                                 use_previous_draft_tokens):
+
+                target_inputs = None
+                num_accepted_tokens_device = None
+
+                if has_draft_batch:
                     target_inputs, num_accepted_tokens_device = self._handle_speculative_decoding(
                         scheduled_batch, previous_tensors,
                         previous_tensors_device)
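
Editorial note: with has_draft_batch computed once at the top of the loop, the overlap path is entered only when the draft model will actually run; otherwise target_inputs and num_accepted_tokens_device stay None and the target forward falls back to plain inter-iteration overlap. A hedged caller-side sketch, with handle_spec_decoding and forward_target as hypothetical stand-ins for the real methods:

from typing import Any, Callable, Optional, Tuple


def overlap_step(
    has_draft_batch: bool,
    handle_spec_decoding: Callable[[], Tuple[Any, Any]],
    forward_target: Callable[[Optional[Any], Optional[Any]], None],
) -> None:
    # Defaults mean "no speculative inputs this iteration".
    target_inputs: Optional[Any] = None
    num_accepted_tokens_device: Optional[Any] = None

    if has_draft_batch:
        # Accept tokens from the previous target step and launch the draft
        # model so it overlaps with preparing the next target step.
        target_inputs, num_accepted_tokens_device = handle_spec_decoding()

    # With (None, None) the target forward skips the draft-token update and
    # behaves like the non-speculative overlap loop.
    forward_target(target_inputs, num_accepted_tokens_device)
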
@@ -2746,44 +2767,20 @@ def _handle_speculative_decoding(
     ) -> Tuple[Optional[SampleStateTensorsMTP], Optional[torch.Tensor]]:
         with request_context(is_draft=self.draft_model_engine is not None,
                              scheduled_requests=scheduled_batch):
-            # Do an early checking to see if we need to forward the draft model.
-            # If needed, the overlap should happen between the target requests and the draft requests.
-            # Otherwise, we can still do overlap between the previous target requests and the current target requests.
-            has_draft_batch = (
-                self.previous_batch is not None and self.use_spec_decode
-                and self.drafter.should_forward_draft_model(scheduled_batch))
-
-            new_target_inputs = None
-            num_accepted_tokens_device = None
-            if has_draft_batch:
-                target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
-                assert target_outputs is not None, "target_outputs should not be None"
-                new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
-                    scheduled_batch=scheduled_batch,
-                    target_inputs=target_inputs,
-                    target_outputs=target_outputs)
-
-            if has_draft_batch:
-                self.drafter.generate_draft_tokens_with_overlap(
-                    scheduled_batch, self.resource_manager,
-                    previous_tensors.device if previous_tensors else None,
-                    new_target_inputs, num_accepted_tokens_device)
-
-                # Pad draft tokens to the max draft length for CUDA graph compatibility
-                self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
-            else:
-                self.has_previous_draft_tokens = False
-                # We are not running the draft model. Remove the draft tokens and turn off spec
-                # decode so that the requests get handled correctly.
-                # One corner case: when we have at least one context request, we have to keep spec
-                # dec on. This ensures that we capture hidden states for requests that haven't done
-                # prefill yet.
-                self.use_spec_decode = False
-                self.model_engine.enable_spec_decode = len(
-                    scheduled_batch.context_requests) > 0
-                if not self.model_engine.enable_spec_decode:
-                    for request in scheduled_batch.all_requests():
-                        request.py_draft_tokens = []
+            target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
+            assert target_outputs is not None, "target_outputs should not be None"
+            new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
+                scheduled_batch=scheduled_batch,
+                target_inputs=target_inputs,
+                target_outputs=target_outputs)
+
+            self.drafter.generate_draft_tokens_with_overlap(
+                scheduled_batch, self.resource_manager,
+                previous_tensors.device if previous_tensors else None,
+                new_target_inputs, num_accepted_tokens_device)
+
+            # Pad draft tokens to the max draft length for CUDA graph compatibility
+            self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None

         return new_target_inputs, num_accepted_tokens_device

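
Editorial note: the retained comment about padding draft tokens is the key constraint in this hunk. CUDA graphs replay with fixed tensor shapes, so each iteration must present max_draft_len draft tokens even when fewer were produced or accepted. A small illustration of that idea, assuming a made-up pad_draft_tokens helper and pad id rather than the drafter's real padding routine:

import torch


def pad_draft_tokens(draft_tokens: torch.Tensor, max_draft_len: int,
                     pad_id: int = 0) -> torch.Tensor:
    """Right-pad a [batch, n_draft] token tensor to [batch, max_draft_len]."""
    n_draft = draft_tokens.size(-1)
    if n_draft == max_draft_len:
        return draft_tokens
    pad = torch.full((draft_tokens.size(0), max_draft_len - n_draft),
                     pad_id, dtype=draft_tokens.dtype,
                     device=draft_tokens.device)
    return torch.cat([draft_tokens, pad], dim=-1)


# Two requests, each with two draft tokens, padded out to a fixed length of 4
# so the captured CUDA graph always sees the same shape.
tokens = torch.tensor([[11, 12], [21, 22]])
print(pad_draft_tokens(tokens, max_draft_len=4))
# tensor([[11, 12,  0,  0],
#         [21, 22,  0,  0]])
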