@@ -77,7 +77,6 @@ def forward(self,
                 scheduled_requests: ScheduledRequests,
                 resource_manager: ResourceManager,
                 new_tensors_device: Optional[SampleStateTensors],
-                extra_model_inputs: Optional[Dict[str, Any]],
                 gather_context_logits: bool = False):
         raise NotImplementedError

@@ -338,6 +337,7 @@ def __init__(
         spec_config: Optional[SpecConfig] = None,
         guided_decoding_config: Optional[GuidedDecodingConfig] = None,
         lora_config: Optional[LoraConfig] = None,
+        is_draft_model: bool = False,
     ):
         self.ub_buffers = None
         self.batch_size = batch_size
@@ -353,10 +353,7 @@ def __init__(
         self.pytorch_backend_config = pytorch_backend_config
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
-        # We keep a reference to the last used spec metadata to
-        # accommodate certain target/draft model use cases. See
-        # py_executor.py for how this is used.
-        self.last_spec_metadata = None
+        self.is_draft_model = is_draft_model

         self.in_warmup = False

@@ -530,6 +527,15 @@ def wrapper(self, *args, **kwargs):

         return wrapper

+    @contextlib.contextmanager
+    def no_cuda_graph(self):
+        _run_cuda_graphs = self._run_cuda_graphs
+        self._run_cuda_graphs = False
+        try:
+            yield
+        finally:
+            self._run_cuda_graphs = _run_cuda_graphs
+
     @with_warmup_flag
     def warmup(self, resource_manager: ResourceManager) -> None:
         kv_cache_manager = resource_manager.get_resource_manager(
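The hunk above promotes the cuda-graph toggle from a helper nested inside warmup() into a reusable method on the engine. Below is a minimal, self-contained sketch of the same save-and-restore pattern; the EngineSketch class and its usage are illustrative stand-ins, not code from this file, and only the _run_cuda_graphs flag and no_cuda_graph name mirror the diff.

    import contextlib

    class EngineSketch:
        """Illustrative stand-in for the engine class touched by this diff."""

        def __init__(self):
            # Mirrors the engine's self._run_cuda_graphs toggle.
            self._run_cuda_graphs = True

        @contextlib.contextmanager
        def no_cuda_graph(self):
            # Save the flag, force eager execution inside the block, and
            # restore the original value on exit even if the body raises.
            saved = self._run_cuda_graphs
            self._run_cuda_graphs = False
            try:
                yield
            finally:
                self._run_cuda_graphs = saved

    engine = EngineSketch()
    with engine.no_cuda_graph():
        assert engine._run_cuda_graphs is False  # forward passes here run eagerly
    assert engine._run_cuda_graphs is True       # original setting restored

Making this a method (rather than a closure re-created inside warmup) lets other call sites such as the autotuning and piecewise-compile warmups reuse it via self.no_cuda_graph().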
@@ -654,7 +660,7 @@ def get_autotune_warmup_request():
             result.context_requests = requests
             result.generation_requests = []

-            return result, _create_extra_inputs(1, maximum_tunable_num_tokens)
+            return result

         @contextlib.contextmanager
         def release_batch(result):
@@ -668,29 +674,6 @@ def release_batch(result):
                         if spec_resource_manager is not None:
                             spec_resource_manager.free_resources(req)

-        @contextlib.contextmanager
-        def no_cuda_graph():
-            _run_cuda_graphs = self._run_cuda_graphs
-            self._run_cuda_graphs = False
-            try:
-                yield
-            finally:
-                self._run_cuda_graphs = _run_cuda_graphs
-
-        def _create_extra_inputs(bs, num_tokens_per_request):
-            if self.spec_config is None:
-                extra_model_inputs = None
-            else:
-                warmup_inputs_creator = getattr(self.model,
-                                                "get_warmup_extra_inputs", None)
-                if callable(warmup_inputs_creator):
-                    extra_model_inputs = warmup_inputs_creator(
-                        bs, num_tokens_per_request)
-                else:
-                    extra_model_inputs = None
-
-            return extra_model_inputs
-
         # TODO: current warmup_request is not suitable for star attention
         cp_type = self.mapping.cp_config.get('cp_type', None)
         if cp_type == 'star_attention':
@@ -712,7 +695,7 @@ def disable_optimization(backend: Backend):
                 set_enable_piecewise_cuda_graph_capture_flag(True)

         # Disable cuda graph capture here so that we can properly capture it later
-        with no_cuda_graph():
+        with self.no_cuda_graph():
             available_tokens = kv_cache_manager.get_num_available_tokens(
                 self.max_draft_len)
             warmup_batch_size = [1, self.batch_size // 2]
@@ -733,17 +716,14 @@ def disable_optimization(backend: Backend):
                         logger.info(
                             f"Run warmup for batch size={bs}, pure {'context' if num_tokens_per_request > 1 else 'generation'} phase"
                         )
-                        self.forward(
-                            batch,
-                            new_tensors_device=None,
-                            resource_manager=resource_manager,
-                            extra_model_inputs=_create_extra_inputs(
-                                bs, num_tokens_per_request))
+                        self.forward(batch,
+                                     new_tensors_device=None,
+                                     resource_manager=resource_manager)
                         torch.cuda.synchronize()

         if self.pytorch_backend_config.autotuner_enabled:
-            with no_cuda_graph(), autotune():
-                result, extra_model_inputs = get_autotune_warmup_request()
+            with self.no_cuda_graph(), autotune():
+                result = get_autotune_warmup_request()
                 with release_batch(result) as batch:
                     if batch is None:
                         # No KV cache space!
@@ -753,8 +733,7 @@ def disable_optimization(backend: Backend):
                             f"Run autotuning warmup for batch size={1}")
                         self.forward(batch,
                                      new_tensors_device=None,
-                                     resource_manager=resource_manager,
-                                     extra_model_inputs=extra_model_inputs)
+                                     resource_manager=resource_manager)
                         torch.cuda.synchronize()

                 logger.info(f"Autotuner Cache size after warmup " +
@@ -783,12 +762,11 @@ def disable_optimization(backend: Backend):
                     )
                     self.forward(batch,
                                  new_tensors_device=None,
-                                 resource_manager=resource_manager,
-                                 extra_model_inputs=_create_extra_inputs(bs, 1))
+                                 resource_manager=resource_manager)
                     torch.cuda.synchronize()

         if self._torch_compile_piecewise_cuda_graph:
-            with no_cuda_graph():
+            with self.no_cuda_graph():
                 with release_batch(
                         get_torch_compile_warmup_request(1,
                                                          bs)) as batch:
@@ -797,17 +775,12 @@ def disable_optimization(backend: Backend):
                     )

                     for _ in range(3):
-                        self.forward(
-                            batch,
-                            new_tensors_device=None,
-                            resource_manager=resource_manager,
-                            extra_model_inputs=_create_extra_inputs(
-                                1, bs))
-                    self.forward(
-                        batch,
-                        new_tensors_device=None,
-                        resource_manager=resource_manager,
-                        extra_model_inputs=_create_extra_inputs(1, bs))
+                        self.forward(batch,
+                                     new_tensors_device=None,
+                                     resource_manager=resource_manager)
+                    self.forward(batch,
+                                 new_tensors_device=None,
+                                 resource_manager=resource_manager)
                     torch.cuda.synchronize()
                     gc.collect()
                     torch.cuda.empty_cache()
@@ -851,15 +824,17 @@ def _set_up_spec_metadata(
             self.spec_config,
             self.batch_size,
             max_num_tokens=self.max_num_tokens,
-            spec_resource_manager=spec_resource_manager)
+            spec_resource_manager=spec_resource_manager,
+            is_draft_model=self.is_draft_model)

         if self.spec_metadata is not None:
             return self.spec_metadata
         self.spec_metadata = get_spec_metadata(
             self.spec_config,
             self.batch_size,
             max_num_tokens=self.max_num_tokens,
-            spec_resource_manager=spec_resource_manager)
+            spec_resource_manager=spec_resource_manager,
+            is_draft_model=self.is_draft_model)
         return self.spec_metadata

     def _get_padded_batch(self, scheduled_requests: ScheduledRequests,
@@ -1155,7 +1130,6 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
             inputs['attn_metadata'].kv_lens_cuda[
                 num_ctx_requests:num_seqs] += (
                     self.previous_kv_lens_offsets_cuda[:num_gen_requests])
-
         return inputs

     def _prepare_tp_inputs(
@@ -1476,6 +1450,7 @@ def _prepare_tp_inputs(
         lora_params = self._get_lora_params_from_requests(
             scheduled_requests, attn_metadata)

+        # Prepare inputs
         inputs = {
             'attn_metadata': attn_metadata,
             'input_ids': self.input_ids_cuda[:total_num_tokens],
@@ -2027,7 +2002,6 @@ def forward(self,
                 scheduled_requests: ScheduledRequests,
                 resource_manager: ResourceManager,
                 new_tensors_device: Optional[SampleStateTensors] = None,
-                extra_model_inputs: Optional[Dict[str, Any]] = None,
                 gather_context_logits: bool = False):

         kv_cache_manager = resource_manager.get_resource_manager(
@@ -2055,9 +2029,6 @@ def forward(self,
         if kv_cache_manager is None:
             inputs, gather_ids = self._prepare_tp_inputs_no_cache(
                 scheduled_requests, attn_metadata, spec_metadata)
-            if extra_model_inputs is not None:
-                inputs.update(extra_model_inputs)
-            self.last_spec_metadata = spec_metadata

             with MoeLoadBalancerIterContext(moe_load_balancer):
                 return self._forward_step(inputs, gather_ids,
@@ -2081,9 +2052,6 @@ def forward(self,
                                                          attn_metadata,
                                                          spec_metadata,
                                                          new_tensors_device)
-            if extra_model_inputs is not None:
-                inputs.update(extra_model_inputs)
-            self.last_spec_metadata = spec_metadata

         self.iter_counter += 1

@@ -2104,16 +2072,15 @@ def capture_forward_fn(inputs: Dict[str, Any]):
                 pool = maybe_graph.capture(
                     capture_forward_fn,
                     self._cuda_graph_mem_pool,
-                    extra_model_inputs,
                 )
                 self._cuda_graph_mem_pool = pool

                 # here we don't need to use context since cuda graph capture didn't run kernel.
                 # maybe we need a cleaner way to do this.
-                outputs = maybe_graph.run(inputs, extra_model_inputs)
+                outputs = maybe_graph.run(inputs)
             else:
                 with MoeLoadBalancerIterContext(moe_load_balancer):
-                    outputs = maybe_graph.run(inputs, extra_model_inputs)
+                    outputs = maybe_graph.run(inputs)

         # Note: To overlap the CPU and GPU computation as much as possible,
         # guided_decoder.build should be called immediately after the launch of the single step;
@@ -2148,6 +2115,8 @@ def _forward_step(self,
                       gather_ids: Optional[torch.Tensor],
                       gather_context_logits: bool = False) -> Dict[str, Any]:
         inputs = self._preprocess_inputs(inputs)
+        if inputs.get('spec_metadata', None):
+            gather_ids = inputs['spec_metadata'].gather_ids
         if self.without_logits:
             outputs = self.model_forward(**inputs)
             return outputs
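The new branch in _forward_step lets speculative-decoding metadata, when present in the prepared inputs, override the gather_ids that pick which token positions produce logits. A minimal sketch of that override pattern follows; SpecMetadataSketch and resolve_gather_ids are invented names for illustration, and only the 'spec_metadata' key and its gather_ids attribute come from the hunk above.

    from dataclasses import dataclass
    from typing import Any, Dict, Optional

    import torch

    @dataclass
    class SpecMetadataSketch:
        # Hypothetical stand-in; the real metadata carries the positions
        # whose logits speculative decoding needs to sample or verify.
        gather_ids: Optional[torch.Tensor] = None

    def resolve_gather_ids(
            inputs: Dict[str, Any],
            gather_ids: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        # Same shape as the new check: a truthy 'spec_metadata' entry wins
        # over whatever gather_ids the caller passed in.
        if inputs.get('spec_metadata', None):
            return inputs['spec_metadata'].gather_ids
        return gather_ids

    inputs = {'spec_metadata': SpecMetadataSketch(gather_ids=torch.tensor([3, 7]))}
    print(resolve_gather_ids(inputs, gather_ids=None))  # tensor([3, 7])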