Skip to content

Commit 6c80b20

Browse files
Initial commit for EagleDecodingConfig support in AutoDeploy. Adds a file for running AD with Eagle1 and modifies the existing TRTLLM example to use Eagle1 instead of Eagle3 for comparison. Eagle1 is chosen (last hidden layer only) because it resembles MTPEagle.
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent 2c0293c commit 6c80b20

File tree

8 files changed

+535
-43
lines changed

8 files changed

+535
-43
lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ transforms:
7575
stage: pattern_matcher
7676
quantize_mxfp4_moe:
7777
stage: pattern_matcher
78+
detect_hidden_states_for_capture:
79+
stage: pattern_matcher
7880
detect_sharding:
7981
stage: sharding
8082
simple_shard_only: false
@@ -160,6 +162,9 @@ transforms:
160162
insert_cached_delta_rule:
161163
stage: cache_init
162164
backend: fla_delta
165+
insert_cached_residual_add:
166+
stage: cache_init
167+
backend: cached_residual_add
163168
initialize_cache:
164169
stage: cache_init
165170
run_per_gm: false

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
from importlib.resources import files
22
from pathlib import Path
3-
from typing import Any, Dict, List, Literal, Optional, Type, Union
3+
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
44

55
import torch
66
from pydantic import Field, PrivateAttr, ValidationInfo, field_validator, model_validator
77
from pydantic_settings import BaseSettings, SettingsConfigDict
88

99
from tensorrt_llm.models.modeling_utils import QuantConfig
1010

11-
from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, SamplerType, _ParallelConfig
11+
from ...llmapi.llm_args import (
12+
BaseLlmArgs,
13+
BuildConfig,
14+
EagleDecodingConfig,
15+
KvCacheConfig,
16+
SamplerType,
17+
_ParallelConfig,
18+
)
1219
from .models import ModelFactory, ModelFactoryRegistry
1320
from .utils._config import DynamicYamlMixInForSettings
1421
from .utils.logger import ad_logger
@@ -38,6 +45,12 @@ def _check_for_default_value_only(
3845
return value
3946

4047

def default_eagle3_layers_to_capture(num_hidden_layers: int) -> Set[int]:
    """Return the default set of hidden-layer indices captured for EAGLE3.

    Picks an early, a middle, and a late layer of the target model; for very
    shallow models no sensible default exists.

    Raises:
        ValueError: if the model has five or fewer hidden layers.
    """
    if num_hidden_layers <= 5:
        raise ValueError("Not enough hidden layers for default EAGLE3 capture")
    early = 1
    middle = num_hidden_layers // 2 - 1
    late = num_hidden_layers - 4
    return {early, middle, late}
53+
4154
_TRANSFORMS_SHORTCUT_LOOKUP = {
4255
"attn_backend": ("insert_cached_attention.backend", "transformers_replace_cached_attn.backend"),
4356
"free_mem_ratio": ("resize_kv_cache.free_mem_ratio",),
@@ -150,6 +163,11 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
150163

151164
enable_chunked_prefill: bool = Field(default=False, description="Enable chunked prefill.")
152165

166+
draft_checkpoint_loader: Optional[object] = Field(
167+
default=None,
168+
description="The checkpoint loader to use for the draft model when using speculative decoding with two models.",
169+
)
170+
153171
### INFERENCE OPTIMIZER CONFIG #################################################################
154172
mode: Literal["graph", "transformers"] = Field(
155173
default="graph",
@@ -190,11 +208,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
190208
),
191209
)
192210

193-
draft_checkpoint_loader: Optional[object] = Field(
194-
default=None,
195-
description="The checkpoint loader to use for the draft model when using speculative decoding with two models.",
196-
)
197-
198211
### SEQUENCE INTERFACE CONFIG ##################################################################
199212
max_input_len: int = Field(default=1024, description="The maximum input length.")
200213
max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.")
@@ -401,6 +414,26 @@ def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> A
401414
msg = "AutoDeploy only supports parallelization via the `world_size` argument."
402415
return _check_for_default_value_only(cls, value, info, msg)
403416

417+
@model_validator(mode="after")
418+
def default_eagle3_layers_to_capture(self):
419+
if self.speculative_config is None or not isinstance(
420+
self.speculative_config, EagleDecodingConfig
421+
):
422+
return self
423+
424+
if self.speculative_config.eagle3_layers_to_capture is None:
425+
num_hidden_layers = self.create_factory()._get_model_config()[0].num_hidden_layers
426+
self.speculative_config.eagle3_layers_to_capture = default_eagle3_layers_to_capture(
427+
num_hidden_layers
428+
)
429+
430+
# insert the layers to capture into the transforms config.
431+
self.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] = (
432+
self.speculative_config.eagle3_layers_to_capture
433+
)
434+
435+
return self
436+
404437
@model_validator(mode="after")
405438
def validate_parallel_config(self):
406439
"""Setup parallel config according to world_size.

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 161 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
import copy
1313
from collections import defaultdict
1414
from dataclasses import dataclass
15-
from types import SimpleNamespace
15+
from types import MethodType, SimpleNamespace
1616
from typing import Dict, List, Optional, Tuple
1717

1818
import torch
19+
import torch.nn.functional as F
1920
from strenum import StrEnum
2021
from torch._prims_common import DeviceLikeType
2122

@@ -30,9 +31,11 @@
3031
from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config
3132
from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
3233
from tensorrt_llm._torch.speculative import get_spec_drafter
34+
from tensorrt_llm._torch.speculative.eagle3 import Eagle3ResourceManager
3335
from tensorrt_llm._utils import nvtx_range
3436
from tensorrt_llm.llmapi.llm_args import (
3537
ContextChunkingPolicy,
38+
EagleDecodingConfig,
3639
LoadFormat,
3740
SamplerType,
3841
SpeculativeConfig,
@@ -51,6 +54,7 @@
5154
from ...pyexecutor.scheduler import (
5255
BindCapacityScheduler,
5356
BindMicroBatchScheduler,
57+
RequestList,
5458
ScheduledRequests,
5559
SimpleScheduler,
5660
)
@@ -107,6 +111,90 @@ def calculate_max_num_blocks(
107111
return self.num_blocks, 0
108112

109113

114+
class ADHiddenStateManager(Eagle3ResourceManager):
    """Eagle3 resource manager that captures hidden states from an AutoDeploy target model.

    The capture transform writes per-layer hidden states into ``hidden_states_cache``
    buffers on the ``CachedSequenceInterface``; this manager scatters them into the
    packed ``self.hidden_states`` buffer consumed by the Eagle3 draft model.
    """

    def __init__(
        self,
        cache_seq_interface: CachedSequenceInterface,
        config: EagleDecodingConfig,
        max_num_requests: int,
        max_seq_len: int,
        max_num_tokens: int,
    ):
        # Infer dtype and per-layer hidden size from the first capture buffer.
        hidden_state_buffer = self._get_hidden_state_buffers(cache_seq_interface)[0]
        dtype = hidden_state_buffer.dtype
        hidden_size = hidden_state_buffer.shape[1]

        super().__init__(config, dtype, hidden_size, max_num_requests, max_seq_len, max_num_tokens)

        # Device-side scatter indices mapping each flattened input token to its
        # destination row in ``self.hidden_states``.
        self.hidden_state_write_indices: torch.Tensor = torch.empty(
            max_num_tokens, dtype=torch.long, device="cuda"
        )

    def _get_hidden_state_buffers(
        self, cache_seq_interface: CachedSequenceInterface
    ) -> List[torch.Tensor]:
        """Return all hidden-state capture buffers registered on the interface.

        Raises:
            ValueError: if no ``hidden_states_cache`` buffer exists, i.e. the
                capture transform did not run and we are not actually running Eagle3.
        """
        hidden_state_buffers = [
            tensor
            for name, tensor in cache_seq_interface.named_args.items()
            if "hidden_states_cache" in name
        ]
        if not hidden_state_buffers:
            raise ValueError(
                "No hidden_state_buffers found in cache_seq_interface. Check if we are actually running Eagle3."
            )
        return hidden_state_buffers

    def prepare_hidden_states_capture(
        self, ordered_requests: RequestList, cache_seq_interface: CachedSequenceInterface
    ) -> None:
        """Prepare the hidden states for capture by establishing indices that the hidden states will be written to."""
        seq_lens = cache_seq_interface.info.seq_len
        num_tokens = sum(seq_lens)

        start_idx = 0
        hidden_state_write_indices = []
        for request, seq_len in zip(ordered_requests, seq_lens):
            slot_id = self.slot_manager.get_slot(request.request_id)
            self.start_indices[slot_id] = start_idx
            hidden_state_write_indices.extend(range(start_idx, start_idx + seq_len))
            # Reserve at least enough room per request for the upcoming draft tokens.
            start_idx += max(seq_len, self.max_total_draft_tokens + 1)
        # ``start_idx`` is the exclusive end of the used region, so an exact fit
        # (start_idx == capacity) is valid; the previous ``<`` rejected it.
        assert start_idx <= self.hidden_states.shape[0], (
            f"start_idx {start_idx} exceeds hidden_states capacity {self.hidden_states.shape[0]}"
        )

        if len(hidden_state_write_indices) != num_tokens:
            raise ValueError(
                f"len(hidden_state_write_indices) ({len(hidden_state_write_indices)}) != "
                f"num_tokens ({num_tokens}). Check whether ordered_requests matches up with seq_lens."
            )

        hidden_state_write_indices_host = torch.tensor(hidden_state_write_indices, dtype=torch.long)

        self.hidden_state_write_indices[:num_tokens].copy_(
            hidden_state_write_indices_host, non_blocking=True
        )

    def capture_hidden_states(self, cache_seq_interface: CachedSequenceInterface) -> None:
        """Capture configured hidden states that have been written by the model,
        in a format that can be used by the draft model.
        """
        # _get_hidden_state_buffers raises when no buffer exists, so no emptiness
        # re-check is needed here (the previous early-return was unreachable and
        # its fall-through would have crashed on ``None.to(...)``).
        full_hidden_states = self._get_hidden_state_buffers(cache_seq_interface)

        num_tokens = sum(cache_seq_interface.info.seq_len)

        # Concatenate the per-layer captures along the hidden dimension and cast
        # to the draft model's dtype.
        hidden_states = torch.cat(
            [buffer[:num_tokens] for buffer in full_hidden_states], dim=1
        ).to(dtype=self.dtype)

        # Scatter each token's capture row to its reserved slot in the packed buffer.
        token_idx = self.hidden_state_write_indices[:num_tokens]
        self.hidden_states[:, : hidden_states.shape[1]].index_copy_(0, token_idx, hidden_states)
196+
197+
110198
def construct_draft_llm_args(
111199
ad_config: LlmArgs,
112200
) -> TorchLlmArgs:
@@ -331,6 +419,10 @@ def _prepare_inputs(
331419
kv_cache_manager = resource_manager.get_resource_manager(
332420
ResourceManagerType.KV_CACHE_MANAGER
333421
)
422+
# resource manager for hidden state capture
423+
spec_resource_manager = resource_manager.get_resource_manager(
424+
ResourceManagerType.SPEC_RESOURCE_MANAGER
425+
)
334426

335427
# requests in order of context, generate
336428
context_requests = scheduled_requests.context_requests
@@ -341,6 +433,7 @@ def _prepare_inputs(
341433
r for r in scheduled_requests.generation_requests if get_draft_token_length(r) == 0
342434
]
343435
gen_requests = extend_requests + generation_requests
436+
ordered_requests = context_requests + gen_requests
344437
# info to be extracted
345438
input_ids: List[List[int]] = []
346439
input_pos: List[int] = []
@@ -481,17 +574,32 @@ def _build_input_ids(request) -> Tuple[List[int], List[int]]:
481574
scatter_ref=dummy_token,
482575
)
483576

577+
if spec_resource_manager is not None and isinstance(
578+
spec_resource_manager, ADHiddenStateManager
579+
):
580+
spec_resource_manager.prepare_hidden_states_capture(
581+
ordered_requests, self.cache_seq_interface
582+
)
583+
484584
self.iter_states["num_ctx_requests"] = num_ctx_requests
485585
self.iter_states["num_ctx_tokens"] = num_ctx_tokens
486586
# TODO: handle extend requests and draft requests for specdec
487587
self.iter_states["num_generation_tokens"] = num_generation_tokens
488588
return last_logit_only
489589

490590
@nvtx_range("ad_compute_logits")
491-
def _compute_logits(self) -> List[torch.Tensor]:
591+
def _compute_logits(self, resource_manager: ResourceManager) -> List[torch.Tensor]:
492592
# run the model
493593
logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]
494594

595+
spec_resource_manager = resource_manager.get_resource_manager(
596+
ResourceManagerType.SPEC_RESOURCE_MANAGER
597+
)
598+
if spec_resource_manager is not None and isinstance(
599+
spec_resource_manager, ADHiddenStateManager
600+
):
601+
spec_resource_manager.capture_hidden_states(self.cache_seq_interface)
602+
495603
# TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
496604
logits = logits.float()
497605

@@ -518,7 +626,7 @@ def forward(
518626
self.iter_counter += 1
519627

520628
# compute all logits
521-
logits = self._compute_logits()
629+
logits = self._compute_logits(resource_manager)
522630

523631
# gather+cat logits
524632
logits_flat = torch.cat(
@@ -529,8 +637,30 @@ def forward(
529637
return {"logits": logits_flat}
530638

531639

640+
def share_embedding_weights(
    target_model_engine: "ADEngine", draft_model_engine: PyTorchModelEngine
):
    """Share the target model's token-embedding weights with the draft model.

    This function is necessary for supporting Eagle and other speculative decoding
    methods that copy the ``embed_tokens`` submodule. It is not necessary for MTP
    and other speculative decoding methods that use the draft model engine directly.
    """
    submodule = target_model_engine.model.model.embed_tokens

    # ``assert`` is stripped under ``python -O``; raise explicitly so the tp<=1
    # limitation is always enforced.
    world_size = mpi_world_size()
    if world_size > 1:
        raise NotImplementedError(f"This code assumes tp<=1. World size: {world_size}")

    # Note: This simple forward function implementation assumes tp=1, i.e. the
    # full (unsharded) embedding weight lives on this rank.
    # TODO(govind): Handle the tp>1 case.
    def new_embedding_forward(self, input_ids):
        return F.embedding(input_ids, self.weight)

    submodule.forward = MethodType(new_embedding_forward, submodule)

    draft_model_engine.load_weights_from_target_model(target_model_engine.model)
660+
661+
532662
def create_draft_model_engine_maybe(
533-
ad_config: LlmArgs, engine, dist_mapping: Mapping, mpi_dist: MPIDist
663+
ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, mpi_dist: MPIDist
534664
) -> Optional[PyTorchModelEngine]:
535665
"""Create a draft model engine for speculative decoding.
536666
@@ -558,14 +688,18 @@ def create_draft_model_engine_maybe(
558688
chunked_prefill=ad_config.enable_chunked_prefill,
559689
cache_reuse=kv_cache_config.enable_block_reuse,
560690
has_speculative_draft_tokens=has_spec_drafter,
561-
chunk_size=engine.llm_args.max_num_tokens,
691+
chunk_size=target_engine.llm_args.max_num_tokens,
562692
)
563693

564694
# Construct TorchLlmArgs for the draft model
565695
draft_llm_args = construct_draft_llm_args(
566696
ad_config=ad_config,
567697
)
568698

699+
# chain drafter is not supported currently for AutoDeploy.
700+
# TODO(govind): Do this when we want to optimize 2-model spec dec performance.
701+
drafting_loop_wrapper = None
702+
569703
draft_model_engine = PyTorchModelEngine(
570704
model_path=draft_spec_config.speculative_model_dir,
571705
llm_args=draft_llm_args,
@@ -574,7 +708,11 @@ def create_draft_model_engine_maybe(
574708
dist=mpi_dist,
575709
spec_config=draft_spec_config,
576710
is_draft_model=True,
577-
drafting_loop_wrapper=None,
711+
drafting_loop_wrapper=drafting_loop_wrapper,
712+
)
713+
714+
share_embedding_weights(
715+
target_model_engine=target_engine, draft_model_engine=draft_model_engine
578716
)
579717

580718
draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
@@ -668,21 +806,32 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
668806
engine = ADEngine.build_from_config(ad_config=ad_config)
669807

670808
spec_config = ad_config.speculative_config
671-
if spec_config is not None and not spec_config.spec_dec_mode.is_draft_target():
809+
if spec_config is not None and not (
810+
spec_config.spec_dec_mode.is_draft_target() or spec_config.spec_dec_mode.is_eagle3()
811+
):
672812
raise ValueError(
673-
"Currently, AutoDeploy only supports speculative decoding in draft target mode."
813+
"Currently, AutoDeploy only supports speculative decoding in draft target or eagle3 mode."
674814
)
675815

676816
if spec_config is not None and ad_config.guided_decoding_backend is not None:
677817
raise ValueError(
678818
"Guided decoding is not currently supported for speculative decoding in AutoDeploy."
679819
)
680820

681-
# Speculative resource manager not needed for DraftTargetDecoding.
682-
spec_resource_manager = None
683-
684821
draft_model_engine = create_draft_model_engine_maybe(
685-
ad_config=ad_config, engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
822+
ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
823+
)
824+
825+
spec_resource_manager = (
826+
ADHiddenStateManager(
827+
cache_seq_interface=engine.cache_seq_interface,
828+
config=spec_config,
829+
max_num_requests=ad_config.max_batch_size,
830+
max_seq_len=engine.llm_args.max_seq_len,
831+
max_num_tokens=engine.llm_args.max_num_tokens,
832+
)
833+
if isinstance(spec_config, EagleDecodingConfig)
834+
else None
686835
)
687836

688837
# check kvcache config for partial block reuse

0 commit comments

Comments (0)