
Commit a9eb5af

[#9241][feat] AutoDeploy: Support Eagle3 Speculative Decoding (#9869)
Supports the two-model flow, without the overlap scheduler or the chain drafter. The drafting model runs on the PyTorch backend. Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent 1f8ed71 commit a9eb5af
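
For context, the two-model Eagle3 flow added here is driven through the existing LLM-API speculative config. Below is a minimal usage sketch, not part of this commit: the AutoDeploy LLM import path, the EagleDecodingConfig fields shown, and all checkpoint paths are assumptions for illustration only.

from tensorrt_llm._torch.auto_deploy import LLM  # assumed AutoDeploy entry point
from tensorrt_llm.llmapi import EagleDecodingConfig

# Two-model Eagle3: a separate draft checkpoint proposes tokens, the target model verifies them.
spec_config = EagleDecodingConfig(
    max_draft_len=3,                                  # illustrative draft length
    speculative_model_dir="<path/to/eagle3-draft>",   # placeholder draft checkpoint
    eagle3_one_model=False,                           # assumed switch selecting the two-model flow
)

llm = LLM(
    model="<path/to/target-model>",                   # placeholder target checkpoint
    speculative_config=spec_config,
)

for out in llm.generate(["The capital of France is"]):
    print(out.outputs[0].text)

With a config like this, the validator added in llm_args.py below switches on the detect_hidden_states_for_capture transform, and the draft engine is built on the PyTorch backend.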

File tree: 9 files changed (+671, -56 lines)

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 5 additions & 0 deletions
@@ -75,6 +75,8 @@ transforms:
     stage: pattern_matcher
   quantize_mxfp4_moe:
     stage: pattern_matcher
+  detect_hidden_states_for_capture:
+    stage: pattern_matcher
   detect_sharding:
     stage: sharding
     simple_shard_only: false
@@ -163,6 +165,9 @@ transforms:
   insert_cached_delta_rule:
     stage: cache_init
     backend: fla_delta
+  insert_cached_residual_add:
+    stage: cache_init
+    backend: cached_residual_add
   initialize_cache:
     stage: cache_init
     run_per_gm: false

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 26 additions & 6 deletions
@@ -8,7 +8,14 @@

 from tensorrt_llm.models.modeling_utils import QuantConfig

-from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, SamplerType, _ParallelConfig
+from ...llmapi.llm_args import (
+    BaseLlmArgs,
+    BuildConfig,
+    EagleDecodingConfig,
+    KvCacheConfig,
+    SamplerType,
+    _ParallelConfig,
+)
 from .models import ModelFactory, ModelFactoryRegistry
 from .utils._config import DynamicYamlMixInForSettings
 from .utils.logger import ad_logger
@@ -150,6 +157,11 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):

     enable_chunked_prefill: bool = Field(default=False, description="Enable chunked prefill.")

+    draft_checkpoint_loader: Optional[object] = Field(
+        default=None,
+        description="The checkpoint loader to use for the draft model when using speculative decoding with two models.",
+    )
+
     ### INFERENCE OPTIMIZER CONFIG #################################################################
     mode: Literal["graph", "transformers"] = Field(
         default="graph",
@@ -190,11 +202,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         ),
     )

-    draft_checkpoint_loader: Optional[object] = Field(
-        default=None,
-        description="The checkpoint loader to use for the draft model when using speculative decoding with two models.",
-    )
-
     ### SEQUENCE INTERFACE CONFIG ##################################################################
     max_input_len: int = Field(default=1024, description="The maximum input length.")
     max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.")
@@ -420,6 +427,19 @@ def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> A
         msg = "AutoDeploy only supports parallelization via the `world_size` argument."
         return _check_for_default_value_only(cls, value, info, msg)

+    @model_validator(mode="after")
+    def setup_hidden_state_capture(self):
+        if self.speculative_config is None or not isinstance(
+            self.speculative_config, EagleDecodingConfig
+        ):
+            return self
+
+        self.transforms["detect_hidden_states_for_capture"]["capture_hidden_states"] = True
+        self.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] = (
+            self.speculative_config.eagle3_layers_to_capture
+        )
+        return self
+
     @model_validator(mode="after")
     def validate_parallel_config(self):
         """Setup parallel config according to world_size.

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 192 additions & 10 deletions
@@ -13,14 +13,17 @@
 import types
 from collections import defaultdict
 from dataclasses import dataclass
-from types import SimpleNamespace
+from types import MethodType, SimpleNamespace
 from typing import Dict, List, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from strenum import StrEnum
 from torch._prims_common import DeviceLikeType

 from tensorrt_llm._torch.attention_backend.interface import AttentionRuntimeFeatures
+from tensorrt_llm._torch.auto_deploy.utils._graph import get_input_embeddings, get_lm_head_weights
+from tensorrt_llm._torch.models.modeling_speculative import Eagle3ForCausalLM
 from tensorrt_llm._torch.pyexecutor._util import (
     _create_kv_cache_manager,
     get_decoding_mode,
@@ -32,9 +35,11 @@
 from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
 from tensorrt_llm._torch.speculative import get_spec_drafter
+from tensorrt_llm._torch.speculative.eagle3 import Eagle3ResourceManager
 from tensorrt_llm._utils import nvtx_range
 from tensorrt_llm.llmapi.llm_args import (
     ContextChunkingPolicy,
+    EagleDecodingConfig,
     LoadFormat,
     SamplerType,
     TorchLlmArgs,
@@ -57,6 +62,7 @@
 from ...pyexecutor.scheduler import (
     BindCapacityScheduler,
     BindMicroBatchScheduler,
+    RequestList,
     ScheduledRequests,
     SimpleScheduler,
 )
@@ -113,6 +119,90 @@ def calculate_max_num_blocks(
         return self.num_blocks, 0


+class ADHiddenStateManager(Eagle3ResourceManager):
+    def __init__(
+        self,
+        cache_seq_interface: CachedSequenceInterface,
+        config: EagleDecodingConfig,
+        max_num_requests: int,
+        max_seq_len: int,
+        max_num_tokens: int,
+    ):
+        hidden_state_buffer = self._get_hidden_state_buffers(cache_seq_interface)[0]
+        dtype = hidden_state_buffer.dtype
+        hidden_size = hidden_state_buffer.shape[1]
+
+        super().__init__(config, dtype, hidden_size, max_num_requests, max_seq_len, max_num_tokens)
+
+        self.hidden_state_write_indices: torch.Tensor = torch.empty(
+            max_num_tokens, dtype=torch.long, device="cuda"
+        )
+
+    def _get_hidden_state_buffers(
+        self, cache_seq_interface: CachedSequenceInterface
+    ) -> List[torch.Tensor]:
+        hidden_state_buffers = []
+        for name, tensor in cache_seq_interface.named_args.items():
+            if "hidden_states_cache" in name:
+                hidden_state_buffers.append(tensor)
+
+        if not hidden_state_buffers:
+            raise ValueError(
+                "No hidden_state_buffers found in cache_seq_interface. Check if we are actually running Eagle3."
+            )
+        return hidden_state_buffers
+
+    def prepare_hidden_states_capture(
+        self, ordered_requests: RequestList, cache_seq_interface: CachedSequenceInterface
+    ) -> None:
+        """Prepare the hidden states for capture by establishing indices that the hidden states will be written to."""
+        seq_lens = cache_seq_interface.info.seq_len
+        num_tokens = sum(seq_lens)
+
+        start_idx = 0
+        hidden_states_write_indices = []
+        for request, seq_len in zip(ordered_requests, seq_lens):
+            request_id = request.request_id
+            slot_id = self.slot_manager.get_slot(request_id)
+            self.start_indices[slot_id] = start_idx
+            hidden_states_write_indices.extend(range(start_idx, start_idx + seq_len))
+            start_idx += max(seq_len, self.max_total_draft_tokens + 1)
+            assert start_idx < self.hidden_states.shape[0], (
+                f"start_idx {start_idx} exceeds hidden_states capacity {self.hidden_states.shape[0]}"
+            )
+
+        if len(hidden_states_write_indices) != num_tokens:
+            raise ValueError(
+                f"len(hidden_state_write_indices) ({len(hidden_states_write_indices)}) != num_tokens \
+                ({num_tokens}). Check whether ordered_requests matches up with seq_lens."
+            )
+
+        hidden_state_write_indices_host = torch.tensor(
+            hidden_states_write_indices, dtype=torch.long
+        )
+
+        self.hidden_state_write_indices[:num_tokens].copy_(
+            hidden_state_write_indices_host, non_blocking=True
+        )
+
+    def capture_hidden_states(self, cache_seq_interface: CachedSequenceInterface) -> None:
+        """Capture configured hidden states that have been written by the model,
+        in a format that can be used by the draft model.
+        """
+        full_hidden_states = self._get_hidden_state_buffers(cache_seq_interface)
+        if not full_hidden_states:
+            return
+
+        num_tokens = sum(cache_seq_interface.info.seq_len)
+
+        hidden_states = [hidden_state[:num_tokens] for hidden_state in full_hidden_states]
+        hidden_states = torch.cat(hidden_states, dim=1)
+        hidden_states = hidden_states.to(dtype=self.dtype)
+
+        token_idx = self.hidden_state_write_indices[:num_tokens]
+        self.hidden_states[:, : hidden_states.shape[1]].index_copy_(0, token_idx, hidden_states)
+
+
 def construct_draft_llm_args(
     ad_config: LlmArgs,
 ) -> TorchLlmArgs:
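
ADHiddenStateManager splits capture into a host-side step (prepare_hidden_states_capture computes where each scheduled token's hidden state should land, padding every request's region to at least max_total_draft_tokens + 1 rows) and a device-side step (capture_hidden_states scatters the activations into that buffer with index_copy_). Below is a self-contained sketch of the same indexing scheme, with made-up sizes and CPU tensors:

import torch

max_total_draft_tokens = 3     # illustrative; each request reserves draft tokens + 1 rows
seq_lens = [5, 1, 2]           # e.g. one context request and two generation requests
hidden_size = 8
num_tokens = sum(seq_lens)

# Host-side: per-request start offsets, padded so short requests still reserve
# room for draft-token hidden states on later iterations.
start_idx, write_indices, start_indices = 0, [], []
for seq_len in seq_lens:
    start_indices.append(start_idx)
    write_indices.extend(range(start_idx, start_idx + seq_len))
    start_idx += max(seq_len, max_total_draft_tokens + 1)

capacity = start_idx                                  # rows needed in the shared buffer
buffer = torch.zeros(capacity, hidden_size)

# Device-side: scatter the freshly computed hidden states row by row.
captured = torch.randn(num_tokens, hidden_size)       # stand-in for the captured activations
token_idx = torch.tensor(write_indices, dtype=torch.long)
buffer.index_copy_(0, token_idx, captured)

print(start_indices)   # [0, 5, 9]: the second request is padded to 4 rows despite seq_len == 1

This is also why prepare_hidden_states_capture insists that ordered_requests and seq_lens line up: the recorded start_indices and write indices only make sense if both follow the same request order.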
@@ -461,6 +551,10 @@ def _prepare_inputs(
         kv_cache_manager = resource_manager.get_resource_manager(
             ResourceManagerType.KV_CACHE_MANAGER
         )
+        # resource manager for hidden state capture
+        spec_resource_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.SPEC_RESOURCE_MANAGER
+        )

         # requests in order of context, generate
         context_requests = scheduled_requests.context_requests
@@ -471,6 +565,7 @@
             r for r in scheduled_requests.generation_requests if get_draft_token_length(r) == 0
         ]
         gen_requests = extend_requests + generation_requests
+        ordered_requests = context_requests + gen_requests
         # info to be extracted
         input_ids: List[List[int]] = []
         position_ids: List[List[int]] = []
@@ -670,6 +765,13 @@ def _build_input_ids(request) -> Tuple[List[int], List[int], bool]:

         self.cache_seq_interface.info.run_host_prepare_for_attention_forward()

+        if spec_resource_manager is not None and isinstance(
+            spec_resource_manager, ADHiddenStateManager
+        ):
+            spec_resource_manager.prepare_hidden_states_capture(
+                ordered_requests, self.cache_seq_interface
+            )
+
         self.iter_states["num_ctx_requests"] = num_ctx_requests
         self.iter_states["num_ctx_tokens"] = num_ctx_tokens
         # TODO: handle extend requests and draft requests for specdec
@@ -710,14 +812,74 @@ def forward(
         outputs = {
             "logits": self._compute_logits(),
         }
+
+        # save hidden states after running model.forward() in _compute_logits()
+        spec_resource_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.SPEC_RESOURCE_MANAGER
+        )
+        if spec_resource_manager is not None and isinstance(
+            spec_resource_manager, ADHiddenStateManager
+        ):
+            spec_resource_manager.capture_hidden_states(self.cache_seq_interface)
+
         if self.mapping is not None:
             self._execute_logit_post_processors(scheduled_requests, outputs)

         return outputs


+def share_target_weights_with_draft(
+    target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
+):
+    """
+    Certain speculative decoding methods (e.g. Eagle3) require sharing the target model's embedding and lm_head weights
+    with the draft model. This function does this sharing if necessary.
+    """
+
+    assert isinstance(draft_model_engine.model, Eagle3ForCausalLM), (
+        f"Expected draft_model_engine.model to be Eagle3ForCausalLM, got {type(draft_model_engine.model)}"
+    )
+
+    def share_embedding_weights_with_draft(
+        target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
+    ):
+        embedding_weight = get_input_embeddings(target_model_engine.model)
+
+        world_size = mpi_world_size()
+        assert world_size <= 1, f"This code assumes tp<=1. World size: {world_size}"
+
+        # Note: This simple forward function implementation assumes tp=1.
+        # TODO(govind): Handle the tp>1 case.
+        def new_embedding_forward(self, input_ids):
+            return F.embedding(input_ids, self.weight)
+
+        if draft_model_engine.model.model.embed_tokens is None:
+            submodule = torch.nn.Module()
+            submodule.forward = MethodType(new_embedding_forward, submodule)
+            submodule.weight = embedding_weight
+            draft_model_engine.model.model.embed_tokens = submodule
+
+    def share_lm_head_weights_with_draft(
+        target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
+    ):
+        vocab_size = target_model_engine.cache_seq_interface.info.vocab_size_padded
+
+        lm_head_weight = get_lm_head_weights(target_model_engine.model)
+
+        assert lm_head_weight.shape[0] == vocab_size, (
+            f"Expected lm_head weight first dimension to be vocab_size={vocab_size}, "
+            f"but got shape {lm_head_weight.shape}"
+        )
+
+        if draft_model_engine.model.load_lm_head_from_target:
+            draft_model_engine.model.lm_head.weight = lm_head_weight
+
+    share_embedding_weights_with_draft(target_model_engine, draft_model_engine)
+    share_lm_head_weights_with_draft(target_model_engine, draft_model_engine)
+
+
 def create_draft_model_engine_maybe(
-    ad_config: LlmArgs, engine, dist_mapping: Mapping, mpi_dist: MPIDist
+    ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, mpi_dist: MPIDist
 ) -> Optional[PyTorchModelEngine]:
     """Create a draft model engine for speculative decoding.
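
share_target_weights_with_draft ties the draft model to the target model's embedding and lm_head tensors instead of loading duplicates. When the draft checkpoint ships without embed_tokens, a small module is synthesized whose forward is bound with MethodType and simply calls F.embedding on the shared weight (tp=1 only, per the assert above). A stripped-down sketch of that binding, with toy modules and illustrative sizes:

from types import MethodType

import torch
import torch.nn.functional as F

vocab_size, hidden_size = 32, 4

# Stand-in for the target model's input embedding.
target_embedding = torch.nn.Embedding(vocab_size, hidden_size)

# Stand-in for a draft model whose checkpoint did not include embed_tokens.
class TinyDraftModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = None

draft = TinyDraftModel()

def new_embedding_forward(self, input_ids):
    # Same idea as the patch: look up rows of the shared weight directly.
    return F.embedding(input_ids, self.weight)

if draft.embed_tokens is None:
    shim = torch.nn.Module()
    shim.forward = MethodType(new_embedding_forward, shim)  # bind forward to this shim instance
    shim.weight = target_embedding.weight                   # share the tensor, do not copy it
    draft.embed_tokens = shim

tokens = torch.tensor([1, 2, 3])
assert torch.equal(draft.embed_tokens(tokens), target_embedding(tokens))

The lm_head half of the function is simpler: if the draft model declares load_lm_head_from_target, its lm_head.weight is pointed at the target's weights after checking that the vocab dimension matches.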
@@ -745,14 +907,18 @@ def create_draft_model_engine_maybe(
         chunked_prefill=ad_config.enable_chunked_prefill,
         cache_reuse=kv_cache_config.enable_block_reuse,
         has_speculative_draft_tokens=has_spec_drafter,
-        chunk_size=engine.llm_args.max_num_tokens,
+        chunk_size=target_engine.llm_args.max_num_tokens,
     )

     # Construct TorchLlmArgs for the draft model
     draft_llm_args = construct_draft_llm_args(
         ad_config=ad_config,
     )

+    # chain drafter is not supported currently for AutoDeploy.
+    # TODO(govind): Do this when we want to optimize 2-model spec dec performance.
+    drafting_loop_wrapper = None
+
     draft_model_engine = PyTorchModelEngine(
         model_path=draft_spec_config.speculative_model_dir,
         llm_args=draft_llm_args,
@@ -761,9 +927,14 @@
         dist=mpi_dist,
         spec_config=draft_spec_config,
         is_draft_model=True,
-        drafting_loop_wrapper=None,
+        drafting_loop_wrapper=drafting_loop_wrapper,
     )

+    if draft_spec_config.spec_dec_mode.is_eagle3():
+        share_target_weights_with_draft(
+            target_model_engine=target_engine, draft_model_engine=draft_model_engine
+        )
+
     draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER

     return draft_model_engine
@@ -855,21 +1026,32 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping)

     spec_config = ad_config.speculative_config
-    if spec_config is not None and not spec_config.spec_dec_mode.is_draft_target():
+    if spec_config is not None and not (
+        spec_config.spec_dec_mode.is_draft_target() or spec_config.spec_dec_mode.is_eagle3()
+    ):
         raise ValueError(
-            "Currently, AutoDeploy only supports speculative decoding in draft target mode."
+            "Currently, AutoDeploy only supports speculative decoding in draft target or eagle3 mode."
         )

     if spec_config is not None and ad_config.guided_decoding_backend is not None:
         raise ValueError(
             "Guided decoding is not currently supported for speculative decoding in AutoDeploy."
         )

-    # Speculative resource manager not needed for DraftTargetDecoding.
-    spec_resource_manager = None
-
     draft_model_engine = create_draft_model_engine_maybe(
-        ad_config=ad_config, engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
+        ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
+    )
+
+    spec_resource_manager = (
+        ADHiddenStateManager(
+            cache_seq_interface=engine.cache_seq_interface,
+            config=spec_config,
+            max_num_requests=ad_config.max_batch_size,
+            max_seq_len=engine.llm_args.max_seq_len,
+            max_num_tokens=engine.llm_args.max_num_tokens,
+        )
+        if isinstance(spec_config, EagleDecodingConfig)
+        else None
     )

     # check kvcache config for partial block reuse