 import types
 from collections import defaultdict
 from dataclasses import dataclass
-from types import SimpleNamespace
+from types import MethodType, SimpleNamespace
 from typing import Dict, List, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from strenum import StrEnum
 from torch._prims_common import DeviceLikeType

 from tensorrt_llm._torch.attention_backend.interface import AttentionRuntimeFeatures
+from tensorrt_llm._torch.auto_deploy.utils._graph import get_input_embeddings, get_lm_head_weights
+from tensorrt_llm._torch.models.modeling_speculative import Eagle3ForCausalLM
 from tensorrt_llm._torch.pyexecutor._util import (
     _create_kv_cache_manager,
     get_decoding_mode,
@@ ... @@
 from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
 from tensorrt_llm._torch.speculative import get_spec_drafter
+from tensorrt_llm._torch.speculative.eagle3 import Eagle3ResourceManager
 from tensorrt_llm._utils import nvtx_range
 from tensorrt_llm.llmapi.llm_args import (
     ContextChunkingPolicy,
+    EagleDecodingConfig,
     LoadFormat,
     SamplerType,
     TorchLlmArgs,
@@ ... @@
 from ...pyexecutor.scheduler import (
     BindCapacityScheduler,
     BindMicroBatchScheduler,
+    RequestList,
     ScheduledRequests,
     SimpleScheduler,
 )
@@ -113,6 +119,90 @@ def calculate_max_num_blocks(
         return self.num_blocks, 0


+class ADHiddenStateManager(Eagle3ResourceManager):
+    def __init__(
+        self,
+        cache_seq_interface: CachedSequenceInterface,
+        config: EagleDecodingConfig,
+        max_num_requests: int,
+        max_seq_len: int,
+        max_num_tokens: int,
+    ):
+        hidden_state_buffer = self._get_hidden_state_buffers(cache_seq_interface)[0]
+        dtype = hidden_state_buffer.dtype
+        hidden_size = hidden_state_buffer.shape[1]
+
+        super().__init__(config, dtype, hidden_size, max_num_requests, max_seq_len, max_num_tokens)
+
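+        # Per-token destination indices into the hidden-state capture buffer; filled by
+        # prepare_hidden_states_capture() and consumed by capture_hidden_states().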
+        self.hidden_state_write_indices: torch.Tensor = torch.empty(
+            max_num_tokens, dtype=torch.long, device="cuda"
+        )
+
+    def _get_hidden_state_buffers(
+        self, cache_seq_interface: CachedSequenceInterface
+    ) -> List[torch.Tensor]:
+        hidden_state_buffers = []
+        for name, tensor in cache_seq_interface.named_args.items():
+            if "hidden_states_cache" in name:
+                hidden_state_buffers.append(tensor)
+
+        if not hidden_state_buffers:
+            raise ValueError(
+                "No hidden_state_buffers found in cache_seq_interface. Check if we are actually running Eagle3."
+            )
+        return hidden_state_buffers
+
+    def prepare_hidden_states_capture(
+        self, ordered_requests: RequestList, cache_seq_interface: CachedSequenceInterface
+    ) -> None:
158+ """Prepare the hidden states for capture by establishing indices that the hidden states will be written to."""
+        seq_lens = cache_seq_interface.info.seq_len
+        num_tokens = sum(seq_lens)
+
+        start_idx = 0
+        hidden_states_write_indices = []
+        for request, seq_len in zip(ordered_requests, seq_lens):
+            request_id = request.request_id
+            slot_id = self.slot_manager.get_slot(request_id)
+            self.start_indices[slot_id] = start_idx
+            hidden_states_write_indices.extend(range(start_idx, start_idx + seq_len))
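+            # Reserve at least max_total_draft_tokens + 1 rows per request so later
+            # speculative steps have room for their hidden states as well.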
+            start_idx += max(seq_len, self.max_total_draft_tokens + 1)
+            assert start_idx < self.hidden_states.shape[0], (
+                f"start_idx {start_idx} exceeds hidden_states capacity {self.hidden_states.shape[0]}"
+            )
+
+        if len(hidden_states_write_indices) != num_tokens:
+            raise ValueError(
+                f"len(hidden_state_write_indices) ({len(hidden_states_write_indices)}) != num_tokens "
+                f"({num_tokens}). Check whether ordered_requests matches up with seq_lens."
+            )
+
+        hidden_state_write_indices_host = torch.tensor(
+            hidden_states_write_indices, dtype=torch.long
+        )
+
+        self.hidden_state_write_indices[:num_tokens].copy_(
+            hidden_state_write_indices_host, non_blocking=True
+        )
+
+    def capture_hidden_states(self, cache_seq_interface: CachedSequenceInterface) -> None:
+        """Capture the configured hidden states written by the model,
+        in a format that the draft model can consume.
+        """
+        full_hidden_states = self._get_hidden_state_buffers(cache_seq_interface)
+        if not full_hidden_states:
+            return
+
+        num_tokens = sum(cache_seq_interface.info.seq_len)
+
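+        # Concatenate all captured hidden-state buffers (typically one per Eagle3 capture
+        # layer) along the hidden dimension before scattering into self.hidden_states.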
+        hidden_states = [hidden_state[:num_tokens] for hidden_state in full_hidden_states]
+        hidden_states = torch.cat(hidden_states, dim=1)
+        hidden_states = hidden_states.to(dtype=self.dtype)
+
+        token_idx = self.hidden_state_write_indices[:num_tokens]
+        self.hidden_states[:, : hidden_states.shape[1]].index_copy_(0, token_idx, hidden_states)
+
+
 def construct_draft_llm_args(
     ad_config: LlmArgs,
 ) -> TorchLlmArgs:
@@ -461,6 +551,10 @@ def _prepare_inputs(
         kv_cache_manager = resource_manager.get_resource_manager(
             ResourceManagerType.KV_CACHE_MANAGER
         )
+        # resource manager for hidden state capture
+        spec_resource_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.SPEC_RESOURCE_MANAGER
+        )

         # requests in order of context, generate
         context_requests = scheduled_requests.context_requests
@@ -471,6 +565,7 @@ def _prepare_inputs(
             r for r in scheduled_requests.generation_requests if get_draft_token_length(r) == 0
         ]
         gen_requests = extend_requests + generation_requests
+        ordered_requests = context_requests + gen_requests
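+        # this ordering must match cache_seq_interface.info.seq_len so hidden-state
+        # capture indices line up with the right requests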
         # info to be extracted
         input_ids: List[List[int]] = []
         position_ids: List[List[int]] = []
@@ -670,6 +765,13 @@ def _build_input_ids(request) -> Tuple[List[int], List[int], bool]:

         self.cache_seq_interface.info.run_host_prepare_for_attention_forward()

+        if spec_resource_manager is not None and isinstance(
+            spec_resource_manager, ADHiddenStateManager
+        ):
+            spec_resource_manager.prepare_hidden_states_capture(
+                ordered_requests, self.cache_seq_interface
+            )
+
         self.iter_states["num_ctx_requests"] = num_ctx_requests
         self.iter_states["num_ctx_tokens"] = num_ctx_tokens
         # TODO: handle extend requests and draft requests for specdec
@@ -710,14 +812,74 @@ def forward(
         outputs = {
             "logits": self._compute_logits(),
         }
+
+        # save hidden states after running model.forward() in _compute_logits()
+        spec_resource_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.SPEC_RESOURCE_MANAGER
+        )
+        if spec_resource_manager is not None and isinstance(
+            spec_resource_manager, ADHiddenStateManager
+        ):
+            spec_resource_manager.capture_hidden_states(self.cache_seq_interface)
+
         if self.mapping is not None:
             self._execute_logit_post_processors(scheduled_requests, outputs)

         return outputs


+def share_target_weights_with_draft(
+    target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
+):
+    """
+    Certain speculative decoding methods (e.g. Eagle3) require sharing the target model's embedding
+    and lm_head weights with the draft model. This function performs that sharing when needed.
+    """
+
+    assert isinstance(draft_model_engine.model, Eagle3ForCausalLM), (
+        f"Expected draft_model_engine.model to be Eagle3ForCausalLM, got {type(draft_model_engine.model)}"
+    )
+
+    def share_embedding_weights_with_draft(
+        target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
+    ):
+        embedding_weight = get_input_embeddings(target_model_engine.model)
+
+        world_size = mpi_world_size()
+        assert world_size <= 1, f"This code assumes tp<=1. World size: {world_size}"
+
+        # Note: This simple forward function implementation assumes tp=1.
+        # TODO(govind): Handle the tp>1 case.
+        def new_embedding_forward(self, input_ids):
+            return F.embedding(input_ids, self.weight)
+
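+        # The draft model may be built without its own embedding table; if so, attach a
+        # minimal module that reuses the target model's embedding weight.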
+        if draft_model_engine.model.model.embed_tokens is None:
+            submodule = torch.nn.Module()
+            submodule.forward = MethodType(new_embedding_forward, submodule)
+            submodule.weight = embedding_weight
+            draft_model_engine.model.model.embed_tokens = submodule
+
+    def share_lm_head_weights_with_draft(
+        target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
+    ):
+        vocab_size = target_model_engine.cache_seq_interface.info.vocab_size_padded
+
+        lm_head_weight = get_lm_head_weights(target_model_engine.model)
+
+        assert lm_head_weight.shape[0] == vocab_size, (
+            f"Expected lm_head weight first dimension to be vocab_size={vocab_size}, "
+            f"but got shape {lm_head_weight.shape}"
+        )
+
+        if draft_model_engine.model.load_lm_head_from_target:
+            draft_model_engine.model.lm_head.weight = lm_head_weight
+
+    share_embedding_weights_with_draft(target_model_engine, draft_model_engine)
+    share_lm_head_weights_with_draft(target_model_engine, draft_model_engine)
+
+
 def create_draft_model_engine_maybe(
-    ad_config: LlmArgs, engine, dist_mapping: Mapping, mpi_dist: MPIDist
+    ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, mpi_dist: MPIDist
 ) -> Optional[PyTorchModelEngine]:
     """Create a draft model engine for speculative decoding.

@@ -745,14 +907,18 @@ def create_draft_model_engine_maybe(
         chunked_prefill=ad_config.enable_chunked_prefill,
         cache_reuse=kv_cache_config.enable_block_reuse,
         has_speculative_draft_tokens=has_spec_drafter,
-        chunk_size=engine.llm_args.max_num_tokens,
+        chunk_size=target_engine.llm_args.max_num_tokens,
     )

     # Construct TorchLlmArgs for the draft model
     draft_llm_args = construct_draft_llm_args(
         ad_config=ad_config,
     )

+    # The chain drafter is not currently supported for AutoDeploy.
+    # TODO(govind): Add chain-drafter support when we want to optimize 2-model spec dec performance.
+    drafting_loop_wrapper = None
+
     draft_model_engine = PyTorchModelEngine(
         model_path=draft_spec_config.speculative_model_dir,
         llm_args=draft_llm_args,
@@ -761,9 +927,14 @@ def create_draft_model_engine_maybe(
         dist=mpi_dist,
         spec_config=draft_spec_config,
         is_draft_model=True,
-        drafting_loop_wrapper=None,
+        drafting_loop_wrapper=drafting_loop_wrapper,
     )

+    if draft_spec_config.spec_dec_mode.is_eagle3():
+        share_target_weights_with_draft(
+            target_model_engine=target_engine, draft_model_engine=draft_model_engine
+        )
+
     draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER

     return draft_model_engine
@@ -855,21 +1026,32 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping)

     spec_config = ad_config.speculative_config
-    if spec_config is not None and not spec_config.spec_dec_mode.is_draft_target():
+    if spec_config is not None and not (
+        spec_config.spec_dec_mode.is_draft_target() or spec_config.spec_dec_mode.is_eagle3()
+    ):
         raise ValueError(
-            "Currently, AutoDeploy only supports speculative decoding in draft target mode."
+            "Currently, AutoDeploy only supports speculative decoding in draft target or eagle3 mode."
         )

     if spec_config is not None and ad_config.guided_decoding_backend is not None:
         raise ValueError(
             "Guided decoding is not currently supported for speculative decoding in AutoDeploy."
         )

-    # Speculative resource manager not needed for DraftTargetDecoding.
-    spec_resource_manager = None
-
     draft_model_engine = create_draft_model_engine_maybe(
-        ad_config=ad_config, engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
+        ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
+    )
+
+    spec_resource_manager = (
+        ADHiddenStateManager(
+            cache_seq_interface=engine.cache_seq_interface,
+            config=spec_config,
+            max_num_requests=ad_config.max_batch_size,
+            max_seq_len=engine.llm_args.max_seq_len,
+            max_num_tokens=engine.llm_args.max_num_tokens,
+        )
+        if isinstance(spec_config, EagleDecodingConfig)
+        else None
     )

     # check kvcache config for partial block reuse