3030from tensorrt_llm ._torch .pyexecutor .llm_request import get_draft_token_length
3131from tensorrt_llm ._torch .pyexecutor .py_executor_creator import get_guided_decoding_config
3232from tensorrt_llm ._torch .pyexecutor .seq_slot_manager import SeqSlotManager
33- from tensorrt_llm ._torch .speculative import _get_spec_resource_manager , get_spec_drafter
33+ from tensorrt_llm ._torch .speculative import get_spec_drafter
3434from tensorrt_llm ._torch .speculative .eagle3 import Eagle3ResourceManager
3535from tensorrt_llm ._utils import nvtx_range
3636from tensorrt_llm .llmapi .llm_args import (
@@ -111,6 +111,90 @@ def calculate_max_num_blocks(
111111 return self .num_blocks , 0
112112
113113
class ADHiddenStateManager(Eagle3ResourceManager):
    """Eagle3 resource manager for AutoDeploy that captures target-model hidden
    states out of the cached-sequence interface so the draft model can consume them.

    The target model writes per-token hidden states into one or more
    ``hidden_states_cache`` buffers exposed by ``cache_seq_interface.named_args``;
    this manager scatters them into the flat ``self.hidden_states`` buffer owned
    by the Eagle3 base class, at per-request slots sized to also hold draft tokens.
    """

    def __init__(
        self,
        cache_seq_interface: CachedSequenceInterface,
        config: EagleDecodingConfig,
        max_num_requests: int,
        max_seq_len: int,
        max_num_tokens: int,
    ):
        # Derive dtype / hidden size from the first hidden-state buffer exposed by
        # the interface (the helper raises if no such buffer exists).
        hidden_state_buffer = self._get_hidden_state_buffers(cache_seq_interface)[0]
        dtype = hidden_state_buffer.dtype
        hidden_size = hidden_state_buffer.shape[1]

        super().__init__(config, dtype, hidden_size, max_num_requests, max_seq_len, max_num_tokens)

        # Device-side scatter indices, refreshed every step by
        # prepare_hidden_states_capture and consumed by capture_hidden_states.
        self.hidden_state_write_indices: torch.Tensor = torch.empty(
            max_num_tokens, dtype=torch.long, device="cuda"
        )

    def _get_hidden_state_buffers(
        self, cache_seq_interface: CachedSequenceInterface
    ) -> List[torch.Tensor]:
        """Collect all ``hidden_states_cache`` tensors from the interface.

        Raises:
            ValueError: if no such buffer exists (i.e. Eagle3 is not actually running).
        """
        hidden_state_buffers = [
            tensor
            for name, tensor in cache_seq_interface.named_args.items()
            if "hidden_states_cache" in name
        ]
        if not hidden_state_buffers:
            raise ValueError(
                "No hidden_state_buffers found in cache_seq_interface. Check if we are actually running Eagle3."
            )
        return hidden_state_buffers

    def prepare_hidden_states_capture(
        self, ordered_requests: RequestList, cache_seq_interface: CachedSequenceInterface
    ) -> None:
        """Prepare the hidden states for capture by establishing indices that the hidden states will be written to."""
        seq_lens = cache_seq_interface.info.seq_len
        num_tokens = sum(seq_lens)

        start_idx = 0
        write_indices: List[int] = []
        for request, seq_len in zip(ordered_requests, seq_lens):
            slot_id = self.slot_manager.get_slot(request.request_id)
            self.start_indices[slot_id] = start_idx
            write_indices.extend(range(start_idx, start_idx + seq_len))
            # Reserve at least enough room per request for the draft tokens
            # appended during speculation (+1 for the accepted token).
            start_idx += max(seq_len, self.max_total_draft_tokens + 1)
            # start_idx is an exclusive end, so an exact fit (==) is valid.
            assert start_idx <= self.hidden_states.shape[0], (
                f"start_idx {start_idx} exceeds hidden_states capacity "
                f"{self.hidden_states.shape[0]}"
            )

        if len(write_indices) != num_tokens:
            # Implicit string concatenation instead of a backslash continuation,
            # which would embed a run of literal spaces into the message.
            raise ValueError(
                f"len(hidden_state_write_indices) ({len(write_indices)}) != num_tokens "
                f"({num_tokens}). Check whether ordered_requests matches up with seq_lens."
            )

        # Pin the staging tensor so the H2D copy can actually overlap
        # (non_blocking=True on pageable memory degrades to a synchronous copy).
        indices_host = torch.tensor(write_indices, dtype=torch.long).pin_memory()
        self.hidden_state_write_indices[:num_tokens].copy_(indices_host, non_blocking=True)

    def capture_hidden_states(self, cache_seq_interface: CachedSequenceInterface) -> None:
        """Capture configured hidden states that have been written by the model,
        in a format that can be used by the draft model.
        """
        # Raises if no buffers exist, so the list below is always non-empty.
        buffers = self._get_hidden_state_buffers(cache_seq_interface)

        num_tokens = sum(cache_seq_interface.info.seq_len)

        # Concatenate per-buffer features along the hidden dimension and cast
        # to the dtype expected by the Eagle3 hidden-state store.
        hidden_states = torch.cat(
            [buf[:num_tokens] for buf in buffers], dim=1
        ).to(dtype=self.dtype)

        token_idx = self.hidden_state_write_indices[:num_tokens]
        self.hidden_states[:, : hidden_states.shape[1]].index_copy_(0, token_idx, hidden_states)
197+
114198def construct_draft_llm_args (
115199 ad_config : LlmArgs ,
116200) -> TorchLlmArgs :
@@ -360,48 +444,6 @@ def __init__(
360444 # start fresh with fixed seed
361445 torch .manual_seed (42 )
362446
363- def _prepare_hidden_state_capture (
364- self , ordered_requests : RequestList , resource_manager : ResourceManager
365- ) -> None :
366- spec_resource_manager = resource_manager .get_resource_manager (
367- ResourceManagerType .SPEC_RESOURCE_MANAGER
368- )
369- if spec_resource_manager is None or not isinstance (
370- spec_resource_manager , Eagle3ResourceManager
371- ):
372- return
373-
374- caches = []
375- for name , tensor in self .cache_seq_interface .named_args .items ():
376- if "hidden_states_cache" in name :
377- caches .append ((name , tensor ))
378-
379- seq_lens = self .cache_seq_interface .info .seq_len
380- num_tokens = sum (seq_lens )
381- max_total_draft_tokens = getattr (spec_resource_manager , "max_total_draft_tokens" , 0 )
382-
383- start_idx = 0
384- hidden_state_write_indices = []
385- for request , seq_len in zip (ordered_requests , seq_lens ):
386- request_id = request .request_id
387- slot_id = spec_resource_manager .slot_manager .get_slot (request_id )
388- spec_resource_manager .start_indices [slot_id ] = start_idx
389- hidden_state_write_indices .extend (range (start_idx , start_idx + seq_len ))
390- start_idx += max (seq_len , max_total_draft_tokens + 1 )
391- assert start_idx < spec_resource_manager .hidden_states .shape [0 ], (
392- f"start_idx { start_idx } exceeds hidden_states capacity { spec_resource_manager .hidden_states .shape [0 ]} "
393- )
394-
395- assert len (hidden_state_write_indices ) == num_tokens
396-
397- self .hidden_state_write_indices_host = torch .tensor (
398- hidden_state_write_indices , dtype = torch .long
399- )
400-
401- self .hidden_state_write_indices_gpu [:num_tokens ].copy_ (
402- self .hidden_state_write_indices_host , non_blocking = True
403- )
404-
405447 @nvtx_range ("ad_prepare_inputs" )
406448 def _prepare_inputs (
407449 self ,
@@ -414,7 +456,10 @@ def _prepare_inputs(
414456 kv_cache_manager = resource_manager .get_resource_manager (
415457 ResourceManagerType .KV_CACHE_MANAGER
416458 )
417-
459+ # resource manager for hidden state capture
460+ spec_resource_manager = resource_manager .get_resource_manager (
461+ ResourceManagerType .SPEC_RESOURCE_MANAGER
462+ )
418463 # requests in order of context, generate
419464 context_requests = scheduled_requests .context_requests
420465 extend_requests = [
@@ -425,7 +470,6 @@ def _prepare_inputs(
425470 ]
426471 gen_requests = extend_requests + generation_requests
427472 ordered_requests = context_requests + gen_requests
428-
429473 # info to be extracted
430474 input_ids : List [List [int ]] = []
431475 input_pos : List [int ] = []
@@ -566,58 +610,38 @@ def _build_input_ids(request) -> Tuple[List[int], List[int]]:
566610 scatter_ref = dummy_token ,
567611 )
568612
613+ if spec_resource_manager is not None and isinstance (
614+ spec_resource_manager , ADHiddenStateManager
615+ ):
616+ spec_resource_manager .prepare_hidden_states_capture (
617+ ordered_requests , self .cache_seq_interface
618+ )
619+
569620 self .iter_states ["num_ctx_requests" ] = num_ctx_requests
570621 self .iter_states ["num_ctx_tokens" ] = num_ctx_tokens
571622 # TODO: handle extend requests and draft requests for specdec
572623 self .iter_states ["num_generation_tokens" ] = num_generation_tokens
573624
574- self ._prepare_hidden_state_capture (ordered_requests , resource_manager )
575-
576625 return last_logit_only
577626
578627 @nvtx_range ("ad_compute_logits" )
579628 def _compute_logits (self , resource_manager : ResourceManager ) -> List [torch .Tensor ]:
580629 # run the model
581630 logits : torch .Tensor = self .model (** self .cache_seq_interface .named_args )[0 ]
582- self ._capture_hidden_states_cache (resource_manager )
583-
584- # TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
585- logits = logits .float ()
586-
587- # return a list of tensors
588- return self .cache_seq_interface .info .unnest_sequences (logits )
589631
590- def _capture_hidden_states_cache (self , resource_manager : ResourceManager ) -> None :
591- """Capture and print hidden_states_cache tensor passed to the model."""
592632 spec_resource_manager = resource_manager .get_resource_manager (
593633 ResourceManagerType .SPEC_RESOURCE_MANAGER
594634 )
595- if spec_resource_manager is None or not isinstance (
596- spec_resource_manager , Eagle3ResourceManager
635+ if spec_resource_manager is not None and isinstance (
636+ spec_resource_manager , ADHiddenStateManager
597637 ):
598- return
638+ spec_resource_manager . capture_hidden_states ( self . cache_seq_interface )
599639
600- caches = []
601- for name , tensor in self .cache_seq_interface .named_args .items ():
602- if "hidden_states_cache" in name :
603- caches .append ((name , tensor ))
604-
605- if not caches :
606- return
607-
608- seq_lens = self .cache_seq_interface .info .seq_len
609- num_tokens = sum (seq_lens )
610-
611- used_caches = [cache [:num_tokens ] for _ , cache in caches ]
612-
613- eagle3_hidden_states = spec_resource_manager .hidden_states
614- hidden_states_cache_value = torch .cat (used_caches , dim = 1 ) if used_caches else None
615- hidden_states_cache_value = hidden_states_cache_value .to (dtype = eagle3_hidden_states .dtype )
640+ # TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
641+ logits = logits .float ()
616642
617- token_idx = self .hidden_state_write_indices_gpu [:num_tokens ]
618- eagle3_hidden_states [:, : hidden_states_cache_value .shape [1 ]].index_copy_ (
619- 0 , token_idx , hidden_states_cache_value
620- )
643+ # return a list of tensors
644+ return self .cache_seq_interface .info .unnest_sequences (logits )
621645
622646 def get_max_num_sequences (self ) -> int :
623647 """Maximum number of sequences supported by the engine."""
@@ -837,15 +861,16 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
837861 ad_config = ad_config , target_engine = engine , dist_mapping = dist_mapping , mpi_dist = mpi_dist
838862 )
839863
840- target_model_dtype = torch .bfloat16 # TODO: Get this from the model engine.
841- target_hidden_size = 4096 # TODO: Get this from the model engine.
842-
843- spec_resource_manager = _get_spec_resource_manager (
844- target_model_engine = engine ,
845- max_seq_len = engine .llm_args .max_seq_len ,
846- model_dtype = target_model_dtype ,
847- hidden_size = target_hidden_size ,
848- draft_model_engine = draft_model_engine ,
864+ spec_resource_manager = (
865+ ADHiddenStateManager (
866+ cache_seq_interface = engine .cache_seq_interface ,
867+ config = spec_config ,
868+ max_num_requests = ad_config .max_batch_size ,
869+ max_seq_len = engine .llm_args .max_seq_len ,
870+ max_num_tokens = engine .llm_args .max_num_tokens ,
871+ )
872+ if isinstance (spec_config , EagleDecodingConfig )
873+ else None
849874 )
850875
851876 # check kvcache config for partial block reuse
0 commit comments