@@ -658,6 +658,42 @@ def _run_cuda_graph_warmup(self, resource_manager: ResourceManager):
658658 self ._capture_generation_cuda_graphs (resource_manager )
659659 self ._capture_piecewise_cuda_graphs (resource_manager )
660660
661+ def _graphs_for_dynamic_draft_length (self ):
662+ """
663+ Compute the set of (batch_size, draft_len) pairs that are actually reachable.
664+ Used in dynamic draft length feature.
665+ """
666+ graphs_to_capture = []
667+ schedule_thresholds = sorted (self .spec_config .draft_len_schedule .keys ())
668+
669+ # Only iterate over actual CUDA graph batch sizes, not all possible batch sizes
670+ for graph_bs in self ._cuda_graph_batch_sizes :
671+ idx = bisect .bisect_right (schedule_thresholds , graph_bs )
672+ if idx == 0 :
673+ draft_len = 0 # Defensive
674+ else :
675+ draft_len = self .spec_config .draft_len_schedule [
676+ schedule_thresholds [idx - 1 ]]
677+
678+ graphs_to_capture .append ((graph_bs , draft_len ))
679+
680+ return list (
681+ set (graphs_to_capture )) # Use set to remove duplicates if any
682+
683+ # def _round_up_to_graph_size(self, actual_bs: int) -> int:
684+ # """Round up actual batch size to nearest CUDA graph batch size using binary search."""
685+ # if not self._cuda_graph_batch_sizes:
686+ # return 0
687+
688+ # idx = bisect.bisect_left(self._cuda_graph_batch_sizes, actual_bs)
689+
690+ # # If exact match or idx points to next larger size
691+ # if idx < len(self._cuda_graph_batch_sizes):
692+ # return self._cuda_graph_batch_sizes[idx]
693+
694+ # # actual_bs is larger than all available sizes
695+ # return self._cuda_graph_batch_sizes[-1]
696+
661697 def _capture_generation_cuda_graphs (self ,
662698 resource_manager : ResourceManager ):
663699 """Captures CUDA graphs for pure generation steps."""
@@ -674,38 +710,48 @@ def _capture_generation_cuda_graphs(self,
674710 cuda_graph_batch_sizes = sorted (self ._cuda_graph_batch_sizes ,
675711 reverse = True )
676712 # Create CUDA graphs for different draft lengths
677- draft_lengths = []
713+ # draft_lengths = []
678714 if self .is_draft_model :
679715 if self .model_is_wrapped and self .is_spec_decode and spec_resource_manager is not None and isinstance (
680716 spec_resource_manager , Eagle3ResourceManager ):
681717 # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
682- draft_lengths . append ( self .original_max_total_draft_tokens )
718+ draft_len = self .original_max_total_draft_tokens
683719 else :
684- draft_lengths .append (self .max_total_draft_tokens )
720+ draft_len = self .max_total_draft_tokens
721+ graphs_to_capture = [(bs , draft_len )
722+ for bs in cuda_graph_batch_sizes ]
723+ elif (self .spec_config
724+ and hasattr (self .spec_config , 'draft_len_schedule' )
725+ and self .spec_config .draft_len_schedule is not None ):
726+ # target model with draft_len_schedule: compute exact reachable set
727+ graphs_to_capture = self ._graphs_for_dynamic_draft_length ()
685728 else :
686729 # For non-draft model, we also capture the CUDA graph instance for draft length 0,
687730 # so that when we disable spec decode at runtime, we can still run the captured graph.
688731 # Note that for one engine mode, we are not able to turn off spec decode at runtime.
732+ graphs_to_capture = []
689733 if (self .max_total_draft_tokens > 0
690734 and not self .spec_config .spec_dec_mode .use_one_engine ()
691735 # Assume that speculation is always on if the user didn't give us a max_concurrency
692736 # value. This will save on memory.
693737 and self .spec_config .max_concurrency is not None ):
694- draft_lengths .append (0 )
695- draft_lengths = [self .max_total_draft_tokens ]
738+ graphs_to_capture .extend ([(bs , 0 )
739+ for bs in cuda_graph_batch_sizes ])
740+ else :
741+ graphs_to_capture .extend ([(bs , self .max_total_draft_tokens )
742+ for bs in cuda_graph_batch_sizes ])
696743
697- for bs in cuda_graph_batch_sizes :
744+ graphs_to_capture = sorted (graphs_to_capture , reverse = True )
745+ for bs , draft_len in graphs_to_capture :
698746 if bs > self .batch_size :
699747 continue
700-
701- for draft_len in draft_lengths :
702- warmup_request = self ._create_cuda_graph_warmup_request (
703- resource_manager , bs , draft_len )
704- with self ._release_batch_context (warmup_request ,
705- resource_manager ) as batch :
706- if batch is None :
707- # No KV cache space, cannot continue capturing graphs
708- return
748+ warmup_request = self ._create_cuda_graph_warmup_request (
749+ resource_manager , bs , draft_len )
750+ with self ._release_batch_context (warmup_request ,
751+ resource_manager ) as batch :
752+ if batch is None :
753+ # No KV cache space, cannot continue capturing graphs
754+ return
709755
710756 logger .info (
711757 f"Run generation-only CUDA graph warmup for batch size={ bs } , draft_len={ draft_len } "
0 commit comments