Skip to content

Commit 429b650

Browse files
committed
Fix Eagle3 detection to check draft_vocab_size attribute
Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
1 parent 98fbaac commit 429b650

File tree

2 files changed

+11
-10
lines changed

tensorrt_llm/_torch/models/modeling_auto.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,17 @@ def from_config(
2424
vision_encoder_cls, vlm_base_model = vision_encoder_info
2525
return vision_encoder_cls(config, vlm_base_model)
2626

27-
if "Eagle3" in model_arch:
28-
# Below is a hack to detect eagle3 checkpoints.
29-
# Why it exists:
30-
# - Community checkpoints append "Eagle3" to architecture names ("LlamaForCausalLMEagle3").
31-
# - Even NVIDIA official checkpoints (nvidia/Llama-4-Maverick-17B-128E-Eagle3) use the appended convention.
32-
# - But TensorRT-LLM's MODEL_CLASS_MAPPING expects prefixed names like EAGLE3LlamaForCausalLM
33-
# - Hence, LlamaForCausalLMEagle3 -> EAGLE3LlamaForCausalLM.
34-
# TODO: should we provide our own checkpoints with the correct arch? It would let us avoid nasty stuff like this.
27+
# Hack to detect eagle3 checkpoints.
28+
# Why it exists:
29+
# - Eagle3 checkpoints have draft_vocab_size in config.json (even if None)
30+
# - Some community checkpoints append "Eagle3" to architecture names ("LlamaForCausalLMEagle3")
31+
# - Some checkpoints don't include "Eagle3" in arch name at all ("LlamaForCausalLM")
32+
# - TensorRT-LLM's MODEL_CLASS_MAPPING expects prefixed names like EAGLE3LlamaForCausalLM
33+
# - Hence: LlamaForCausalLMEagle3 -> EAGLE3LlamaForCausalLM
34+
# LlamaForCausalLM (with draft_vocab_size) -> EAGLE3LlamaForCausalLM
35+
# TODO: should we provide our own checkpoints with the correct arch? It would let us avoid nasty stuff like this.
36+
if hasattr(config.pretrained_config, "draft_vocab_size"):
37+
# It's an Eagle3 checkpoint - strip "Eagle3" suffix if present, then add prefix
3538
model_arch = model_arch.replace("Eagle3", "")
3639
model_arch = "EAGLE3" + model_arch
3740
if model_arch in (

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,6 @@ def __init__(
279279
False)
280280
self._use_mla = use_mla
281281

282-
_ensure_draft_vocab_size(config)
283-
284282
if hasattr(config, "target_hidden_size"):
285283
self.hidden_size_in = config.target_hidden_size
286284
else:

0 commit comments

Comments (0)