Skip to content

Commit 145ac73

Browse files
authored
[Bugfix][Speculative Decoding] Fix Eagle3 quantization config issue (vllm-project#25883)
Signed-off-by: Rahul Tuli <[email protected]>
1 parent d0d138b commit 145ac73

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

tests/speculative_decoding/speculators/test_eagle3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
pytest.param(
1515
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
1616
id="qwen3-eagle3-speculator"),
17+
pytest.param(
18+
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
19+
id="qwen3-eagle3-speculator-w4a16-verifier"),
1720
])
1821
def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path,
1922
monkeypatch):

vllm/model_executor/models/llama.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def __init__(self,
248248

249249
config = config or vllm_config.model_config.hf_config
250250
cache_config = vllm_config.cache_config
251-
quant_config = vllm_config.quant_config
251+
quant_config = self.get_quant_config(vllm_config)
252252

253253
self.hidden_size = config.hidden_size
254254
rope_theta = getattr(config, "rope_theta", 10000)
@@ -328,6 +328,11 @@ def forward(
328328
hidden_states = self.mlp(hidden_states)
329329
return hidden_states, residual
330330

331+
def get_quant_config(
332+
self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
333+
"""Get quantization config for this layer. Override in subclasses."""
334+
return vllm_config.quant_config
335+
331336

332337
@support_torch_compile
333338
class LlamaModel(nn.Module):

vllm/model_executor/models/llama_eagle3.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from vllm.model_executor.layers.layernorm import RMSNorm
1414
from vllm.model_executor.layers.linear import QKVParallelLinear
1515
from vllm.model_executor.layers.logits_processor import LogitsProcessor
16+
from vllm.model_executor.layers.quantization.base_config import (
17+
QuantizationConfig)
1618
from vllm.model_executor.layers.vocab_parallel_embedding import (
1719
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
1820
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -33,7 +35,7 @@ def __init__(self,
3335
super().__init__(vllm_config, prefix=prefix, config=config)
3436

3537
config = config or vllm_config.model_config.hf_config
36-
quant_config = vllm_config.quant_config
38+
quant_config = self.get_quant_config(vllm_config)
3739

3840
# override qkv
3941
self.self_attn.qkv_proj = QKVParallelLinear(
@@ -53,6 +55,16 @@ def __init__(self,
5355
else:
5456
self._residual_norm = self._norm_after_residual
5557

58+
def get_quant_config(
59+
self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
60+
"""Use drafter's quantization config instead of verifier's."""
61+
draft_model_config = vllm_config.speculative_config.draft_model_config
62+
draft_load_config = vllm_config.load_config
63+
64+
return VllmConfig.get_quantization_config(
65+
draft_model_config,
66+
draft_load_config) if draft_model_config else None
67+
5668
def _norm_before_residual(
5769
self,
5870
hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:

0 commit comments

Comments (0)