@@ -13,6 +13,8 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -33,7 +35,7 @@ def __init__(self,
         super().__init__(vllm_config, prefix=prefix, config=config)
 
         config = config or vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
+        quant_config = self.get_quant_config(vllm_config)
 
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
@@ -53,6 +55,16 @@ def __init__(self,
         else:
             self._residual_norm = self._norm_after_residual
 
+    def get_quant_config(
+            self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
+        """Use drafter's quantization config instead of verifier's."""
+        draft_model_config = vllm_config.speculative_config.draft_model_config
+        draft_load_config = vllm_config.load_config
+
+        return VllmConfig.get_quantization_config(
+            draft_model_config,
+            draft_load_config) if draft_model_config else None
+
     def _norm_before_residual(
             self,
             hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
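For context, this hook matters when the verifier and the drafter are quantized differently, e.g. an FP8 verifier paired with an unquantized EAGLE3 drafter. A minimal usage sketch of that setup follows; the model names are illustrative placeholders, and the `speculative_config` dict shown is the interface in recent vLLM releases, so check it against your version:

```python
# Hedged sketch: FP8-quantized verifier with a separately loaded EAGLE3
# drafter. With the change above, the drafter resolves its own quantization
# config via get_quant_config() instead of inheriting the verifier's
# vllm_config.quant_config. Model names below are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # verifier
    quantization="fp8",                        # applies to the verifier only
    speculative_config={
        "method": "eagle3",                    # drafter architecture
        "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",  # drafter weights
        "num_speculative_tokens": 3,
    },
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```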