@@ -15,16 +15,16 @@
 
 """Eagle3 model implementation for AutoDeploy.
 
-Eagle is a speculative decoding draft model that predicts next tokens based on
+Eagle3 is a speculative decoding draft model that predicts next tokens based on
 hidden states from a target model (e.g., Llama-3.1-8B-Instruct).
 
 This implementation:
-- Defines EagleConfig extending LlamaConfig with model_type="eagle3"
-- Wraps EagleModel in a HuggingFace-compatible interface
+- Defines Eagle3Config with model_type="eagle3"
+- Wraps Eagle3Model in a HuggingFace-compatible interface
 - Registers with AutoDeploy's custom model mechanism
 
-Note: Eagle uses the same tokenizer as its target model (Llama), so when using
-this model, you must explicitly specify the tokenizer path pointing to the
+Note: Eagle3 uses the same tokenizer as its target model (e.g., Llama), so when
+using this model, you must explicitly specify the tokenizer path pointing to the
 target model.
 """
 
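The docstring's tokenizer note is the one operational gotcha here: the Eagle3 draft checkpoint ships no tokenizer of its own. A minimal sketch of honoring that, assuming Hugging Face `AutoTokenizer` and placeholder checkpoint paths (the AutoDeploy entry point itself isn't shown in this diff):

```python
from transformers import AutoTokenizer

# Placeholder paths -- both are assumptions for illustration.
draft_ckpt = "path/to/eagle3-draft"               # Eagle3 weights, no tokenizer files
target_ckpt = "meta-llama/Llama-3.1-8B-Instruct"  # target model the draft follows

# The draft model would load from draft_ckpt; the tokenizer must come
# from the target model, exactly as the docstring requires.
tokenizer = AutoTokenizer.from_pretrained(target_ckpt)
input_ids = tokenizer("Draft me some tokens.", return_tensors="pt").input_ids
```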
@@ -317,13 +317,6 @@ class Eagle3ModelForCausalLM(PreTrainedModel): |
     def __init__(self, config):
         super().__init__(config)
         self.model = Eagle3Model(config)
-        self._hidden_size = config.hidden_size
-        self._dtype = config.dtype
-
-    def _init_weights(self, module):
-        """Initialize weights - called by PreTrainedModel."""
-        # Default initialization - weights will be loaded from checkpoint
-        pass
 
     def forward(
         self,
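Dropping `_init_weights` is safe here: the base `PreTrainedModel._init_weights` is already a no-op, and the weights are loaded from a checkpoint regardless, so the override added nothing. For context on the (truncated) `forward` that follows, Eagle-style drafts condition on hidden states captured from the target model rather than on token ids alone. A rough sketch of that fusion pattern, where every name is illustrative rather than taken from this file:

```python
import torch
import torch.nn as nn

# Illustrative sketch only (names and shapes are assumptions, not this PR's
# signatures): an Eagle-style draft fuses token embeddings with hidden states
# captured from the target model before predicting draft-token logits.
class DraftHeadSketch(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        # Project [embedding ; target hidden state] back to the model width.
        self.fuse = nn.Linear(2 * hidden_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids: torch.Tensor, target_hidden_states: torch.Tensor):
        x = self.embed(input_ids)                                    # [B, T, H]
        x = self.fuse(torch.cat([x, target_hidden_states], dim=-1))  # [B, T, 2H] -> [B, T, H]
        return self.lm_head(x)                                       # draft logits
```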
@@ -419,6 +412,11 @@ class MockEagle3ModelForCausalLM(Eagle3ModelForCausalLM): |
 
     config_class = MockEagle3Config
 
+    def __init__(self, config):
+        super().__init__(config)
+        self._hidden_size = config.hidden_size
+        self._dtype = config.dtype
+
     def forward(self, input_ids, **kwargs):
         # Inject mock hidden states if not provided
         if "target_hidden_states" not in kwargs:
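Moving the `_hidden_size`/`_dtype` caching into the mock makes sense: the mock is their only consumer, using them to fabricate `target_hidden_states` when a test omits them, so the production class stays free of test-only state. A sketch of what that injection might look like; the `[batch, seq_len, hidden_size]` shape is an assumption, since the real contract isn't visible in this excerpt:

```python
import torch

# Hypothetical helper mirroring what the mock's forward presumably does:
# synthesize target_hidden_states with the cached width and dtype.
def mock_target_hidden_states(input_ids: torch.Tensor, hidden_size: int,
                              dtype: torch.dtype) -> torch.Tensor:
    batch, seq_len = input_ids.shape
    return torch.randn(batch, seq_len, hidden_size, dtype=dtype)
```

Inside the truncated `if` branch this would amount to something like `kwargs["target_hidden_states"] = mock_target_hidden_states(input_ids, self._hidden_size, self._dtype)`.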