
Commit 9898953

small docstring changes and getting rid of unused functions/variables
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent c7403ce commit 9898953

File tree

1 file changed (+10, -12 lines)


tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py

Lines changed: 10 additions & 12 deletions
@@ -15,16 +15,16 @@

 """Eagle3 model implementation for AutoDeploy.

-Eagle is a speculative decoding draft model that predicts next tokens based on
+Eagle3 is a speculative decoding draft model that predicts next tokens based on
 hidden states from a target model (e.g., Llama-3.1-8B-Instruct).

 This implementation:
-- Defines EagleConfig extending LlamaConfig with model_type="eagle3"
-- Wraps EagleModel in a HuggingFace-compatible interface
+- Defines Eagle3Config with model_type="eagle3"
+- Wraps Eagle3Model in a HuggingFace-compatible interface
 - Registers with AutoDeploy's custom model mechanism

-Note: Eagle uses the same tokenizer as its target model (Llama), so when using
-this model, you must explicitly specify the tokenizer path pointing to the
+Note: Eagle3 uses the same tokenizer as its target model (e.g., Llama), so when
+using this model, you must explicitly specify the tokenizer path pointing to the
 target model.
 """
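
Note on usage: the tokenizer requirement in the updated docstring is the operative constraint for users. A minimal sketch of what it implies, with hypothetical checkpoint paths (AutoTokenizer is the standard transformers API; nothing below is specific to this file):

from transformers import AutoTokenizer

# Hypothetical paths: an Eagle3 draft checkpoint ships without its own
# tokenizer, so the tokenizer must be loaded from the target model instead.
draft_ckpt = "path/to/eagle3-draft"  # hypothetical draft checkpoint
target_ckpt = "meta-llama/Llama-3.1-8B-Instruct"  # target model named in the docstring

# Load the tokenizer explicitly from the target model, not the draft.
tokenizer = AutoTokenizer.from_pretrained(target_ckpt)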

@@ -317,13 +317,6 @@ class Eagle3ModelForCausalLM(PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.model = Eagle3Model(config)
-        self._hidden_size = config.hidden_size
-        self._dtype = config.dtype
-
-    def _init_weights(self, module):
-        """Initialize weights - called by PreTrainedModel."""
-        # Default initialization - weights will be loaded from checkpoint
-        pass

     def forward(
         self,
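
Why the deleted _init_weights override was safe to drop (a hedged reading of the diff, restating its deleted comment rather than adding behavior): transformers' PreTrainedModel already defines _init_weights as a no-op default, and from_pretrained() overwrites parameters with checkpoint values, so an empty override just repeats the inherited behavior. A quick way to inspect the inherited default:

import inspect

from transformers import PreTrainedModel

# Print the base-class implementation; it should amount to a pass-through.
print(inspect.getsource(PreTrainedModel._init_weights))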
@@ -419,6 +412,11 @@ class MockEagle3ModelForCausalLM(Eagle3ModelForCausalLM):

     config_class = MockEagle3Config

+    def __init__(self, config):
+        super().__init__(config)
+        self._hidden_size = config.hidden_size
+        self._dtype = config.dtype
+
     def forward(self, input_ids, **kwargs):
         # Inject mock hidden states if not provided
         if "target_hidden_states" not in kwargs:
