Skip to content

Commit 4e59e4d

Browse files
Added an ADHiddenStateManager class which extracts the hidden size and dtype from the cached sequence interface
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent 023c7fa commit 4e59e4d

File tree

3 files changed

+130
-120
lines changed

3 files changed

+130
-120
lines changed

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 113 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from tensorrt_llm._torch.pyexecutor.llm_request import get_draft_token_length
3131
from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config
3232
from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
33-
from tensorrt_llm._torch.speculative import _get_spec_resource_manager, get_spec_drafter
33+
from tensorrt_llm._torch.speculative import get_spec_drafter
3434
from tensorrt_llm._torch.speculative.eagle3 import Eagle3ResourceManager
3535
from tensorrt_llm._utils import nvtx_range
3636
from tensorrt_llm.llmapi.llm_args import (
@@ -111,6 +111,90 @@ def calculate_max_num_blocks(
111111
return self.num_blocks, 0
112112

113113

114+
class ADHiddenStateManager(Eagle3ResourceManager):
    """Eagle3 hidden-state manager for the auto-deploy executor flow.

    Unlike the base ``Eagle3ResourceManager``, this class infers the hidden-state
    dtype and hidden size directly from the ``hidden_states_cache`` buffers held
    by the ``CachedSequenceInterface``, so the caller does not need to hard-code
    them.
    """

    def __init__(
        self,
        cache_seq_interface: CachedSequenceInterface,
        config: EagleDecodingConfig,
        max_num_requests: int,
        max_seq_len: int,
        max_num_tokens: int,
    ):
        # Derive dtype / hidden size from the first hidden-state cache buffer.
        # NOTE(review): assumes the buffer is 2D (tokens, hidden) — shape[1] is
        # taken as the hidden size; confirm against the cache layout.
        hidden_state_buffer = self._get_hidden_state_buffers(cache_seq_interface)[0]
        dtype = hidden_state_buffer.dtype
        hidden_size = hidden_state_buffer.shape[1]

        super().__init__(config, dtype, hidden_size, max_num_requests, max_seq_len, max_num_tokens)

        # Device-resident scratch buffer of per-token write destinations,
        # refreshed by prepare_hidden_states_capture() before each forward pass.
        self.hidden_state_write_indices: torch.Tensor = torch.empty(
            max_num_tokens, dtype=torch.long, device="cuda"
        )

    def _get_hidden_state_buffers(
        self, cache_seq_interface: CachedSequenceInterface
    ) -> List[torch.Tensor]:
        """Return all ``hidden_states_cache`` tensors from the interface args.

        Raises:
            ValueError: if no such buffer exists (i.e. we are not running Eagle3).
        """
        hidden_state_buffers = [
            tensor
            for name, tensor in cache_seq_interface.named_args.items()
            if "hidden_states_cache" in name
        ]
        if not hidden_state_buffers:
            raise ValueError(
                "No hidden_state_buffers found in cache_seq_interface. Check if we are actually running Eagle3."
            )
        return hidden_state_buffers

    def prepare_hidden_states_capture(
        self, ordered_requests: RequestList, cache_seq_interface: CachedSequenceInterface
    ) -> None:
        """Prepare the hidden states for capture by establishing indices that the hidden states will be written to."""
        seq_lens = cache_seq_interface.info.seq_len
        num_tokens = sum(seq_lens)

        start_idx = 0
        hidden_state_write_indices: List[int] = []
        for request, seq_len in zip(ordered_requests, seq_lens):
            slot_id = self.slot_manager.get_slot(request.request_id)
            self.start_indices[slot_id] = start_idx
            hidden_state_write_indices.extend(range(start_idx, start_idx + seq_len))
            # Reserve at least max_total_draft_tokens + 1 rows per request so
            # draft-token hidden states fit even for short sequences.
            start_idx += max(seq_len, self.max_total_draft_tokens + 1)
            assert start_idx < self.hidden_states.shape[0], (
                f"start_idx {start_idx} exceeds hidden_states capacity {self.hidden_states.shape[0]}"
            )

        if len(hidden_state_write_indices) != num_tokens:
            # Fixed: the original message used a backslash continuation inside
            # the f-string, which embedded source indentation into the text.
            raise ValueError(
                f"len(hidden_state_write_indices) ({len(hidden_state_write_indices)}) "
                f"!= num_tokens ({num_tokens}). Check whether ordered_requests "
                "matches up with seq_lens."
            )

        hidden_state_write_indices_host = torch.tensor(
            hidden_state_write_indices, dtype=torch.long
        )
        # Async H2D copy; consumed by capture_hidden_states() after the forward.
        self.hidden_state_write_indices[:num_tokens].copy_(
            hidden_state_write_indices_host, non_blocking=True
        )

    def capture_hidden_states(self, cache_seq_interface: CachedSequenceInterface) -> None:
        """Capture configured hidden states that have been written by the model,
        in a format that can be used by the draft model.
        """
        # _get_hidden_state_buffers() raises when no buffer exists, so the list
        # is guaranteed non-empty here (the old early-return guard was dead code).
        full_hidden_states = self._get_hidden_state_buffers(cache_seq_interface)

        num_tokens = sum(cache_seq_interface.info.seq_len)

        # Concatenate the per-layer caches along the hidden dimension and cast
        # to the dtype of the Eagle3 hidden-state pool.
        hidden_states = torch.cat(
            [buffer[:num_tokens] for buffer in full_hidden_states], dim=1
        ).to(dtype=self.dtype)

        # Scatter the captured rows into the slots reserved by
        # prepare_hidden_states_capture().
        token_idx = self.hidden_state_write_indices[:num_tokens]
        self.hidden_states[:, : hidden_states.shape[1]].index_copy_(0, token_idx, hidden_states)
197+
114198
def construct_draft_llm_args(
115199
ad_config: LlmArgs,
116200
) -> TorchLlmArgs:
@@ -360,48 +444,6 @@ def __init__(
360444
# start fresh with fixed seed
361445
torch.manual_seed(42)
362446

363-
def _prepare_hidden_state_capture(
364-
self, ordered_requests: RequestList, resource_manager: ResourceManager
365-
) -> None:
366-
spec_resource_manager = resource_manager.get_resource_manager(
367-
ResourceManagerType.SPEC_RESOURCE_MANAGER
368-
)
369-
if spec_resource_manager is None or not isinstance(
370-
spec_resource_manager, Eagle3ResourceManager
371-
):
372-
return
373-
374-
caches = []
375-
for name, tensor in self.cache_seq_interface.named_args.items():
376-
if "hidden_states_cache" in name:
377-
caches.append((name, tensor))
378-
379-
seq_lens = self.cache_seq_interface.info.seq_len
380-
num_tokens = sum(seq_lens)
381-
max_total_draft_tokens = getattr(spec_resource_manager, "max_total_draft_tokens", 0)
382-
383-
start_idx = 0
384-
hidden_state_write_indices = []
385-
for request, seq_len in zip(ordered_requests, seq_lens):
386-
request_id = request.request_id
387-
slot_id = spec_resource_manager.slot_manager.get_slot(request_id)
388-
spec_resource_manager.start_indices[slot_id] = start_idx
389-
hidden_state_write_indices.extend(range(start_idx, start_idx + seq_len))
390-
start_idx += max(seq_len, max_total_draft_tokens + 1)
391-
assert start_idx < spec_resource_manager.hidden_states.shape[0], (
392-
f"start_idx {start_idx} exceeds hidden_states capacity {spec_resource_manager.hidden_states.shape[0]}"
393-
)
394-
395-
assert len(hidden_state_write_indices) == num_tokens
396-
397-
self.hidden_state_write_indices_host = torch.tensor(
398-
hidden_state_write_indices, dtype=torch.long
399-
)
400-
401-
self.hidden_state_write_indices_gpu[:num_tokens].copy_(
402-
self.hidden_state_write_indices_host, non_blocking=True
403-
)
404-
405447
@nvtx_range("ad_prepare_inputs")
406448
def _prepare_inputs(
407449
self,
@@ -414,7 +456,10 @@ def _prepare_inputs(
414456
kv_cache_manager = resource_manager.get_resource_manager(
415457
ResourceManagerType.KV_CACHE_MANAGER
416458
)
417-
459+
# resource manager for hidden state capture
460+
spec_resource_manager = resource_manager.get_resource_manager(
461+
ResourceManagerType.SPEC_RESOURCE_MANAGER
462+
)
418463
# requests in order of context, generate
419464
context_requests = scheduled_requests.context_requests
420465
extend_requests = [
@@ -425,7 +470,6 @@ def _prepare_inputs(
425470
]
426471
gen_requests = extend_requests + generation_requests
427472
ordered_requests = context_requests + gen_requests
428-
429473
# info to be extracted
430474
input_ids: List[List[int]] = []
431475
input_pos: List[int] = []
@@ -566,58 +610,38 @@ def _build_input_ids(request) -> Tuple[List[int], List[int]]:
566610
scatter_ref=dummy_token,
567611
)
568612

613+
if spec_resource_manager is not None and isinstance(
614+
spec_resource_manager, ADHiddenStateManager
615+
):
616+
spec_resource_manager.prepare_hidden_states_capture(
617+
ordered_requests, self.cache_seq_interface
618+
)
619+
569620
self.iter_states["num_ctx_requests"] = num_ctx_requests
570621
self.iter_states["num_ctx_tokens"] = num_ctx_tokens
571622
# TODO: handle extend requests and draft requests for specdec
572623
self.iter_states["num_generation_tokens"] = num_generation_tokens
573624

574-
self._prepare_hidden_state_capture(ordered_requests, resource_manager)
575-
576625
return last_logit_only
577626

578627
@nvtx_range("ad_compute_logits")
579628
def _compute_logits(self, resource_manager: ResourceManager) -> List[torch.Tensor]:
580629
# run the model
581630
logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]
582-
self._capture_hidden_states_cache(resource_manager)
583-
584-
# TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
585-
logits = logits.float()
586-
587-
# return a list of tensors
588-
return self.cache_seq_interface.info.unnest_sequences(logits)
589631

590-
def _capture_hidden_states_cache(self, resource_manager: ResourceManager) -> None:
591-
"""Capture and print hidden_states_cache tensor passed to the model."""
592632
spec_resource_manager = resource_manager.get_resource_manager(
593633
ResourceManagerType.SPEC_RESOURCE_MANAGER
594634
)
595-
if spec_resource_manager is None or not isinstance(
596-
spec_resource_manager, Eagle3ResourceManager
635+
if spec_resource_manager is not None and isinstance(
636+
spec_resource_manager, ADHiddenStateManager
597637
):
598-
return
638+
spec_resource_manager.capture_hidden_states(self.cache_seq_interface)
599639

600-
caches = []
601-
for name, tensor in self.cache_seq_interface.named_args.items():
602-
if "hidden_states_cache" in name:
603-
caches.append((name, tensor))
604-
605-
if not caches:
606-
return
607-
608-
seq_lens = self.cache_seq_interface.info.seq_len
609-
num_tokens = sum(seq_lens)
610-
611-
used_caches = [cache[:num_tokens] for _, cache in caches]
612-
613-
eagle3_hidden_states = spec_resource_manager.hidden_states
614-
hidden_states_cache_value = torch.cat(used_caches, dim=1) if used_caches else None
615-
hidden_states_cache_value = hidden_states_cache_value.to(dtype=eagle3_hidden_states.dtype)
640+
# TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
641+
logits = logits.float()
616642

617-
token_idx = self.hidden_state_write_indices_gpu[:num_tokens]
618-
eagle3_hidden_states[:, : hidden_states_cache_value.shape[1]].index_copy_(
619-
0, token_idx, hidden_states_cache_value
620-
)
643+
# return a list of tensors
644+
return self.cache_seq_interface.info.unnest_sequences(logits)
621645

622646
def get_max_num_sequences(self) -> int:
623647
"""Maximum number of sequences supported by the engine."""
@@ -837,15 +861,16 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
837861
ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
838862
)
839863

840-
target_model_dtype = torch.bfloat16 # TODO: Get this from the model engine.
841-
target_hidden_size = 4096 # TODO: Get this from the model engine.
842-
843-
spec_resource_manager = _get_spec_resource_manager(
844-
target_model_engine=engine,
845-
max_seq_len=engine.llm_args.max_seq_len,
846-
model_dtype=target_model_dtype,
847-
hidden_size=target_hidden_size,
848-
draft_model_engine=draft_model_engine,
864+
spec_resource_manager = (
865+
ADHiddenStateManager(
866+
cache_seq_interface=engine.cache_seq_interface,
867+
config=spec_config,
868+
max_num_requests=ad_config.max_batch_size,
869+
max_seq_len=engine.llm_args.max_seq_len,
870+
max_num_tokens=engine.llm_args.max_num_tokens,
871+
)
872+
if isinstance(spec_config, EagleDecodingConfig)
873+
else None
849874
)
850875

851876
# check kvcache config for partial block reuse

tensorrt_llm/_torch/speculative/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from .ngram import NGramDrafter, NGramPoolManager
66
from .save_hidden_state import SaveHiddenStatesDrafter
77
from .spec_tree_manager import SpecTreeManager
8-
from .utils import (_get_spec_resource_manager, get_num_extra_kv_tokens,
9-
get_num_spec_layers, get_spec_decoder, get_spec_drafter,
10-
get_spec_metadata, get_spec_resource_manager,
11-
get_spec_worker, update_spec_config_from_model_config)
8+
from .utils import (get_num_extra_kv_tokens, get_num_spec_layers,
9+
get_spec_decoder, get_spec_drafter, get_spec_metadata,
10+
get_spec_resource_manager, get_spec_worker,
11+
update_spec_config_from_model_config)
1212

1313
__all__ = [
1414
"Eagle3SpecMetadata",
@@ -25,7 +25,6 @@
2525
"get_spec_drafter",
2626
"get_spec_metadata",
2727
"get_spec_resource_manager",
28-
"_get_spec_resource_manager",
2928
"get_spec_worker",
3029
"update_spec_config_from_model_config",
3130
"suggest_spec_config",

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -108,50 +108,47 @@ def get_spec_metadata(spec_config,
108108
return None
109109

110110

111-
def _get_spec_resource_manager(
112-
target_model_engine: "ModelEngine",
113-
max_seq_len: int,
114-
model_dtype: torch.dtype,
115-
hidden_size: int,
116-
draft_model_engine: Optional["PyTorchModelEngine"] = None):
117-
spec_config = target_model_engine.spec_config
111+
def get_spec_resource_manager(model_engine, draft_model_engine=None):
112+
spec_config = model_engine.spec_config
118113
if spec_config is None:
119114
return None
120-
max_num_requests = target_model_engine.batch_size
121-
max_num_tokens = target_model_engine.llm_args.max_num_tokens
115+
model_config = model_engine.model.config
116+
max_num_requests = model_engine.batch_size
117+
max_seq_len = model_engine.max_seq_len
118+
max_num_tokens = model_engine.max_num_tokens
122119
spec_dec_mode = spec_config.spec_dec_mode
123120
if spec_dec_mode.is_mtp_eagle_one_model():
124121
if spec_config.use_relaxed_acceptance_for_thinking:
125122
return MTPHiddenStatesManager(
126123
spec_config,
127-
model_dtype,
128-
hidden_size,
124+
model_config.torch_dtype,
125+
model_config.hidden_size,
129126
max_num_requests,
130127
)
131128
else:
132129
return None
133130
if spec_dec_mode.is_mtp_one_model():
134131
return MTPHiddenStatesManager(
135132
spec_config,
136-
model_dtype,
137-
hidden_size,
133+
model_config.torch_dtype,
134+
model_config.hidden_size,
138135
max_num_requests,
139136
)
140137
if spec_dec_mode.is_eagle3() or spec_dec_mode.is_mtp_eagle():
141138
assert draft_model_engine is not None, "Draft model engine is required for Eagle3 and MTP Eagle two model flow."
142139
return Eagle3ResourceManager(
143140
spec_config,
144141
draft_model_engine.model.config.torch_dtype,
145-
hidden_size,
142+
model_config.hidden_size,
146143
max_num_requests,
147144
max_seq_len,
148145
max_num_tokens,
149146
)
150147
if spec_dec_mode.is_save_hidden_states():
151148
return Eagle3ResourceManager(
152149
spec_config,
153-
model_dtype,
154-
hidden_size,
150+
model_engine.model.config.torch_dtype,
151+
model_config.hidden_size,
155152
max_num_requests,
156153
max_seq_len,
157154
max_num_tokens,
@@ -163,17 +160,6 @@ def _get_spec_resource_manager(
163160
return None
164161

165162

166-
def get_spec_resource_manager(
167-
model_engine: "PyTorchModelEngine",
168-
draft_model_engine: Optional["PyTorchModelEngine"] = None):
169-
return _get_spec_resource_manager(
170-
target_model_engine=model_engine,
171-
max_seq_len=model_engine.max_seq_len,
172-
model_dtype=model_engine.model.config.torch_dtype,
173-
hidden_size=model_engine.model.config.hidden_size,
174-
draft_model_engine=draft_model_engine)
175-
176-
177163
def get_spec_decoder(sampler_args: TorchSampler.Args,
178164
spec_config: "DecodingBaseConfig"):
179165
if spec_config.spec_dec_mode.is_mtp_one_model():

0 commit comments

Comments
 (0)