Commit ca8007f

[Feature] Enable inference support for Deepseekr1-w8a8-MTP (vllm-project#1994)
Support inference of the Deepseekr1-w8a8-mtp model with a statically quantized shared_head in the MTP layers.

- vLLM version: v0.9.2
- vLLM main: vllm-project/vllm@6eca337

Signed-off-by: curryliu <[email protected]>
1 parent 98cadc2 commit ca8007f
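
For context, a minimal usage sketch (not part of this commit) of what the feature enables: serving a statically quantized w8a8 DeepSeek-R1 checkpoint with its MTP layers used for speculative decoding, via vLLM's offline API. The checkpoint path, parallelism degree, and option values below are illustrative assumptions.

```python
# Hypothetical usage sketch -- the checkpoint path, tensor_parallel_size,
# and option values are illustrative assumptions, not taken from this commit.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/DeepSeek-R1-w8a8",   # statically quantized checkpoint
    quantization="ascend",               # route layers through AscendQuantConfig
    speculative_config={
        "method": "deepseek_mtp",        # draft tokens with the MTP layers
        "num_speculative_tokens": 1,
    },
    tensor_parallel_size=16,             # illustrative; size to your cluster
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=16)))
```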

File tree

3 files changed: 46 additions, 4 deletions


vllm_ascend/models/deepseek_mtp.py

Lines changed: 20 additions & 3 deletions
```diff
@@ -28,8 +28,8 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import \
-    VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_mtp import (
     DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
     SharedHead)
@@ -40,6 +40,20 @@
 from .deepseek_v2 import CustomDeepseekV2DecoderLayer
 
 
+class CustomDeepSeekShareHead(SharedHead):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
+        nn.Module.__init__(self)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.head = ParallelLMHead(config.vocab_size,
+                                   config.hidden_size,
+                                   quant_config=quant_config,
+                                   prefix=maybe_prefix(prefix, "head"))
+
+
 class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
 
     def __init__(
@@ -61,7 +75,10 @@ def __init__(
         self.eh_proj = nn.Linear(config.hidden_size * 2,
                                  config.hidden_size,
                                  bias=False)
-        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.shared_head = CustomDeepSeekShareHead(config=config,
+                                                   quant_config=quant_config,
+                                                   prefix=maybe_prefix(
+                                                       prefix, "shared_head"))
         self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
                                                       model_config,
                                                       cache_config,
```
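
CustomDeepSeekShareHead deliberately calls nn.Module.__init__ rather than super().__init__(), so it can rebuild SharedHead's two submodules while forwarding the quant_config and weight-name prefix that the upstream constructor does not pass to ParallelLMHead. A stripped-down sketch of that override pattern, with stand-in classes rather than vLLM's own:

```python
# Stand-in sketch of the pattern above: bypass the parent constructor to
# rebuild the same submodules with extra arguments. Class names here are
# placeholders, not vLLM types.
import torch.nn as nn

class BaseHead(nn.Module):            # plays the role of SharedHead
    def __init__(self, hidden: int, vocab: int):
        super().__init__()
        self.norm = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, vocab, bias=False)

class PrefixedHead(BaseHead):         # plays the role of CustomDeepSeekShareHead
    def __init__(self, hidden: int, vocab: int, prefix: str = ""):
        nn.Module.__init__(self)      # skip BaseHead.__init__ on purpose
        self.norm = nn.LayerNorm(hidden)
        # The real code builds a ParallelLMHead here, forwarding quant_config
        # and prefix so the quantizer can key on the full weight name.
        self.head = nn.Linear(hidden, vocab, bias=False)
```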

vllm_ascend/models/deepseek_v2.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -868,7 +868,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(config.vocab_size,
                                           config.hidden_size,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=maybe_prefix(
+                                              prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
         self.logits_processor = LogitsProcessor(config.vocab_size)
```
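
Both call sites lean on vLLM's maybe_prefix helper so each layer's dotted weight path reaches the quantization config. Its behavior is effectively the following (a paraphrase of the helper in vLLM's model utils, shown here for reference):

```python
# Paraphrase of vLLM's maybe_prefix helper: join a module path and a child
# name, tolerating an empty root prefix.
def maybe_prefix(prefix: str, name: str) -> str:
    return name if not prefix else f"{prefix}.{name}"

assert maybe_prefix("", "lm_head") == "lm_head"
assert maybe_prefix("model.shared_head", "head") == "model.shared_head.head"
```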

vllm_ascend/quantization/quant_config.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -34,6 +34,8 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod, VocabParallelEmbedding)
 from vllm.model_executor.parameter import PerTensorScaleParameter
 from vllm.model_executor.utils import set_weight_attrs
 
@@ -107,6 +109,12 @@ def get_quant_method(self, layer: torch.nn.Module,
                 return AscendUnquantizedFusedMoEMethod()
             return AscendFusedMoEMethod(self, prefix,
                                         self.packed_modules_mapping)
+        elif isinstance(layer, VocabParallelEmbedding):
+            if self.is_layer_skipped_ascend(prefix,
+                                            self.packed_modules_mapping):
+                return UnquantizedEmbeddingMethod()
+            return AscendEmbeddingMethod(self, prefix,
+                                         self.packed_modules_mapping)
         return None
 
     def is_layer_skipped_ascend(
@@ -319,3 +327,18 @@ def apply(
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
             self.quant_method.process_weights_after_loading(layer)
+
+
+class AscendEmbeddingMethod(AscendLinearMethod):
+    """Embedding method for Ascend quantization.
+
+    This class calls AscendQuantizer to search for a specific quantization
+    implementation supported on Ascend hardware for embedding methods.
+
+    Args:
+        quant_config: The Ascend quantization config.
+    """
+
+    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
+                 packed_modules_mapping: Dict[str, Any]) -> None:
+        self.quantizer = AscendQuantizer.get_quantizer(
+            quant_config.quant_description, prefix, packed_modules_mapping)
+        self.quant_method = self.quantizer.build_linear_method()
```
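
The net effect: when vLLM instantiates a VocabParallelEmbedding (ParallelLMHead is a subclass), AscendQuantConfig now checks the layer's prefix against the checkpoint's per-weight quant description and either leaves the layer in floating point or attaches the quantized method; AscendEmbeddingMethod reuses build_linear_method() since the LM head is ultimately a linear projection over the vocabulary. A simplified, self-contained sketch of that decision follows; the description layout and "FLOAT" marker are assumptions modeled on modelslim-style checkpoints, not copied from this commit.

```python
# Simplified sketch of the skip-or-quantize decision for embedding layers.
# The quant-description keys and the "FLOAT"/"W8A8" markers are assumptions
# modeled on modelslim-style checkpoints, not copied from this commit.
from typing import Dict

def embedding_method_for(prefix: str, quant_description: Dict[str, str]) -> str:
    if quant_description.get(f"{prefix}.weight", "FLOAT") == "FLOAT":
        return "UnquantizedEmbeddingMethod"  # layer was left unquantized
    return "AscendEmbeddingMethod"           # build the w8a8 linear method

desc = {"model.layers.61.shared_head.head.weight": "W8A8"}  # MTP shared head
assert embedding_method_for("model.layers.61.shared_head.head", desc) \
    == "AscendEmbeddingMethod"
assert embedding_method_for("model.embed_tokens", desc) \
    == "UnquantizedEmbeddingMethod"
```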
