@@ -12,10 +12,11 @@
 import copy
 from collections import defaultdict
 from dataclasses import dataclass
-from types import SimpleNamespace
+from types import MethodType, SimpleNamespace
 from typing import Dict, List, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from strenum import StrEnum
 from torch._prims_common import DeviceLikeType

@@ -30,9 +31,11 @@
 from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
 from tensorrt_llm._torch.speculative import get_spec_drafter
+from tensorrt_llm._torch.speculative.eagle3 import Eagle3ResourceManager
 from tensorrt_llm._utils import nvtx_range
 from tensorrt_llm.llmapi.llm_args import (
     ContextChunkingPolicy,
+    EagleDecodingConfig,
     LoadFormat,
     SamplerType,
     SpeculativeConfig,
@@ -51,6 +54,7 @@
 from ...pyexecutor.scheduler import (
     BindCapacityScheduler,
     BindMicroBatchScheduler,
+    RequestList,
     ScheduledRequests,
     SimpleScheduler,
 )
@@ -107,6 +111,90 @@ def calculate_max_num_blocks(
         return self.num_blocks, 0


+class ADHiddenStateManager(Eagle3ResourceManager):
+    def __init__(
+        self,
+        cache_seq_interface: CachedSequenceInterface,
+        config: EagleDecodingConfig,
+        max_num_requests: int,
+        max_seq_len: int,
+        max_num_tokens: int,
+    ):
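+        # Infer dtype and hidden size from the exported graph's hidden-state cache buffer.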
+        hidden_state_buffer = self._get_hidden_state_buffers(cache_seq_interface)[0]
+        dtype = hidden_state_buffer.dtype
+        hidden_size = hidden_state_buffer.shape[1]
+
+        super().__init__(config, dtype, hidden_size, max_num_requests, max_seq_len, max_num_tokens)
+
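+        # Device-side scratch buffer of flat write indices; refilled each iteration by
+        # prepare_hidden_states_capture().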
+        self.hidden_state_write_indices: torch.Tensor = torch.empty(
+            max_num_tokens, dtype=torch.long, device="cuda"
+        )
+
+    def _get_hidden_state_buffers(
+        self, cache_seq_interface: CachedSequenceInterface
+    ) -> List[torch.Tensor]:
+        hidden_state_buffers = []
+        for name, tensor in cache_seq_interface.named_args.items():
+            if "hidden_states_cache" in name:
+                hidden_state_buffers.append(tensor)
+
+        if not hidden_state_buffers:
+            raise ValueError(
+                "No hidden_state_buffers found in cache_seq_interface. Check if we are actually running Eagle3."
+            )
+        return hidden_state_buffers
+
+    def prepare_hidden_states_capture(
+        self, ordered_requests: RequestList, cache_seq_interface: CachedSequenceInterface
+    ) -> None:
+        """Prepare hidden-state capture by computing the indices at which the hidden states will be written."""
+        seq_lens = cache_seq_interface.info.seq_len
+        num_tokens = sum(seq_lens)
+
+        start_idx = 0
+        hidden_states_write_indices = []
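+        # Reserve at least max_total_draft_tokens + 1 rows per slot so hidden states written
+        # later during drafting still fit in the same region.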
+        for request, seq_len in zip(ordered_requests, seq_lens):
+            request_id = request.request_id
+            slot_id = self.slot_manager.get_slot(request_id)
+            self.start_indices[slot_id] = start_idx
+            hidden_states_write_indices.extend(range(start_idx, start_idx + seq_len))
+            start_idx += max(seq_len, self.max_total_draft_tokens + 1)
+            assert start_idx < self.hidden_states.shape[0], (
+                f"start_idx {start_idx} exceeds hidden_states capacity {self.hidden_states.shape[0]}"
+            )
+
+        if len(hidden_states_write_indices) != num_tokens:
+            raise ValueError(
+                f"len(hidden_states_write_indices) ({len(hidden_states_write_indices)}) != "
+                f"num_tokens ({num_tokens}). Check whether ordered_requests matches up with seq_lens."
+            )
+
+        hidden_state_write_indices_host = torch.tensor(
+            hidden_states_write_indices, dtype=torch.long
+        )
+
+        self.hidden_state_write_indices[:num_tokens].copy_(
+            hidden_state_write_indices_host, non_blocking=True
+        )
+
+    def capture_hidden_states(self, cache_seq_interface: CachedSequenceInterface) -> None:
+        """Capture the configured hidden states written by the model, in a format the
+        draft model can consume.
+        """
+        full_hidden_states = self._get_hidden_state_buffers(cache_seq_interface)
+        if not full_hidden_states:
+            return
+
+        num_tokens = sum(cache_seq_interface.info.seq_len)
+
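+        # Eagle3 can capture hidden states from several layers; concatenate the per-layer
+        # buffers along the hidden dimension.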
+        hidden_states = [hidden_state[:num_tokens] for hidden_state in full_hidden_states]
+        hidden_states = torch.cat(hidden_states, dim=1).to(dtype=self.dtype)
+
+        token_idx = self.hidden_state_write_indices[:num_tokens]
+        self.hidden_states[:, : hidden_states.shape[1]].index_copy_(0, token_idx, hidden_states)
+
+
+
 def construct_draft_llm_args(
     ad_config: LlmArgs,
 ) -> TorchLlmArgs:
@@ -331,6 +419,10 @@ def _prepare_inputs(
         kv_cache_manager = resource_manager.get_resource_manager(
             ResourceManagerType.KV_CACHE_MANAGER
         )
+        # resource manager for hidden state capture
+        spec_resource_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.SPEC_RESOURCE_MANAGER
+        )

         # requests in order of context, generate
         context_requests = scheduled_requests.context_requests
@@ -341,6 +433,7 @@ def _prepare_inputs(
             r for r in scheduled_requests.generation_requests if get_draft_token_length(r) == 0
         ]
         gen_requests = extend_requests + generation_requests
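+        # The capture path assumes this ordering matches cache_seq_interface.info.seq_len
+        # positionally (see prepare_hidden_states_capture).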
+        ordered_requests = context_requests + gen_requests
         # info to be extracted
         input_ids: List[List[int]] = []
         input_pos: List[int] = []
@@ -481,17 +574,32 @@ def _build_input_ids(request) -> Tuple[List[int], List[int]]:
             scatter_ref=dummy_token,
         )

+        if spec_resource_manager is not None and isinstance(
+            spec_resource_manager, ADHiddenStateManager
+        ):
+            spec_resource_manager.prepare_hidden_states_capture(
+                ordered_requests, self.cache_seq_interface
+            )
+
         self.iter_states["num_ctx_requests"] = num_ctx_requests
         self.iter_states["num_ctx_tokens"] = num_ctx_tokens
         # TODO: handle extend requests and draft requests for specdec
         self.iter_states["num_generation_tokens"] = num_generation_tokens
         return last_logit_only

     @nvtx_range("ad_compute_logits")
-    def _compute_logits(self) -> List[torch.Tensor]:
+    def _compute_logits(self, resource_manager: ResourceManager) -> List[torch.Tensor]:
         # run the model
         logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]

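+        # Capture hidden states right after the target forward pass so the draft model
+        # sees this iteration's activations.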
+        spec_resource_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.SPEC_RESOURCE_MANAGER
+        )
+        if spec_resource_manager is not None and isinstance(
+            spec_resource_manager, ADHiddenStateManager
+        ):
+            spec_resource_manager.capture_hidden_states(self.cache_seq_interface)
+
         # TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
         logits = logits.float()

@@ -518,7 +626,7 @@ def forward(
         self.iter_counter += 1

         # compute all logits
-        logits = self._compute_logits()
+        logits = self._compute_logits(resource_manager)

         # gather+cat logits
         logits_flat = torch.cat(
@@ -529,8 +637,30 @@ def forward(
         return {"logits": logits_flat}


+def share_embedding_weights(
+    target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
+):
+    # This function is needed for Eagle and other speculative decoding methods that copy the
+    # embed_tokens submodule from the target model. It is not needed for MTP and other
+    # speculative decoding methods that use the draft model engine directly.
+
+    submodule = target_model_engine.model.model.embed_tokens
+
+    world_size = mpi_world_size()
+    assert world_size <= 1, f"This code assumes tp<=1. World size: {world_size}"
+
+    # Note: this simple forward implementation assumes tp=1.
+    # TODO(govind): Handle the tp>1 case.
+    def new_embedding_forward(self, input_ids):
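+        # Plain dense lookup; valid because tp == 1 is asserted above (no vocab sharding).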
+        return F.embedding(input_ids, self.weight)
+
+    submodule.forward = MethodType(new_embedding_forward, submodule)
+
+    draft_model_engine.load_weights_from_target_model(target_model_engine.model)
+
+
 def create_draft_model_engine_maybe(
-    ad_config: LlmArgs, engine, dist_mapping: Mapping, mpi_dist: MPIDist
+    ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, mpi_dist: MPIDist
 ) -> Optional[PyTorchModelEngine]:
     """Create a draft model engine for speculative decoding.

@@ -558,14 +688,18 @@ def create_draft_model_engine_maybe(
         chunked_prefill=ad_config.enable_chunked_prefill,
         cache_reuse=kv_cache_config.enable_block_reuse,
         has_speculative_draft_tokens=has_spec_drafter,
-        chunk_size=engine.llm_args.max_num_tokens,
+        chunk_size=target_engine.llm_args.max_num_tokens,
     )

     # Construct TorchLlmArgs for the draft model
     draft_llm_args = construct_draft_llm_args(
         ad_config=ad_config,
     )

+    # The chain drafter is not currently supported for AutoDeploy.
+    # TODO(govind): Add support when optimizing 2-model spec dec performance.
+    drafting_loop_wrapper = None
+
     draft_model_engine = PyTorchModelEngine(
         model_path=draft_spec_config.speculative_model_dir,
         llm_args=draft_llm_args,
@@ -574,7 +708,11 @@ def create_draft_model_engine_maybe(
         dist=mpi_dist,
         spec_config=draft_spec_config,
         is_draft_model=True,
-        drafting_loop_wrapper=None,
+        drafting_loop_wrapper=drafting_loop_wrapper,
+    )
+
+    share_embedding_weights(
+        target_model_engine=target_engine, draft_model_engine=draft_model_engine
     )

     draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
@@ -668,21 +806,32 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     engine = ADEngine.build_from_config(ad_config=ad_config)

     spec_config = ad_config.speculative_config
-    if spec_config is not None and not spec_config.spec_dec_mode.is_draft_target():
+    if spec_config is not None and not (
+        spec_config.spec_dec_mode.is_draft_target() or spec_config.spec_dec_mode.is_eagle3()
+    ):
         raise ValueError(
-            "Currently, AutoDeploy only supports speculative decoding in draft target mode."
+            "Currently, AutoDeploy only supports speculative decoding in draft-target or Eagle3 mode."
         )

     if spec_config is not None and ad_config.guided_decoding_backend is not None:
         raise ValueError(
             "Guided decoding is not currently supported for speculative decoding in AutoDeploy."
         )

-    # Speculative resource manager not needed for DraftTargetDecoding.
-    spec_resource_manager = None
-
     draft_model_engine = create_draft_model_engine_maybe(
-        ad_config=ad_config, engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
+        ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
+    )
+
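+    # Eagle3 drafting consumes hidden states captured from the target model, so it needs
+    # the hidden-state manager; draft-target mode does not.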
+    spec_resource_manager = (
+        ADHiddenStateManager(
+            cache_seq_interface=engine.cache_seq_interface,
+            config=spec_config,
+            max_num_requests=ad_config.max_batch_size,
+            max_seq_len=engine.llm_args.max_seq_len,
+            max_num_tokens=engine.llm_args.max_num_tokens,
+        )
+        if isinstance(spec_config, EagleDecodingConfig)
+        else None
     )

     # check kvcache config for partial block reuse