Skip to content

Commit 59e1a5d

Browse files
author
niushengxiao
committed
fix: continue fix
1 parent 250b80c commit 59e1a5d

File tree

4 files changed

+16
-16
lines changed

4 files changed

+16
-16
lines changed

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def __init__(
2222
split_inter_size: int,
2323
data_type: torch.dtype,
2424
network_config: Dict[str, Any],
25-
weight_scale_suffix: Optional[str] = None,
26-
act_scale_suffix: Optional[str] = None,
25+
layer_num: int,
26+
quant_cfg = None,
2727
) -> None:
2828
super().__init__(
2929
gate_proj_name,
@@ -35,8 +35,8 @@ def __init__(
3535
split_inter_size,
3636
data_type,
3737
network_config,
38-
weight_scale_suffix,
39-
act_scale_suffix
38+
layer_num,
39+
quant_cfg,
4040
)
4141
self.expert_gate_up_proj_etp = None
4242
self.expert_down_proj_etp = None

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from lightllm.common.quantization.quantize_method import QuantizationMethod
88
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id
99
from lightllm.common.vllm_kernel import _custom_ops as ops
10+
from .mm_weight.mm_weight import MMWeight
1011

1112

1213
class FusedMoeWeightTP(BaseWeight):
@@ -21,17 +22,16 @@ def __init__(
2122
split_inter_size: int,
2223
data_type: torch.dtype,
2324
network_config: Dict[str, Any],
24-
weight_scale_suffix: Optional[str] = None,
25-
act_scale_suffix: Optional[str] = None,
25+
layer_num: int,
26+
quant_cfg = None,
2627
) -> None:
2728
super().__init__()
29+
self.quant_method, self.quantized_weight = MMWeight._get_quant_method(quant_cfg, layer_num, weight_prefix)
30+
if quant_cfg is not None and quant_cfg.quantized_weight:
31+
self.weight_scale_suffix = "weight_scale_inv"
2832
self.w1_weight_name = gate_proj_name
2933
self.w2_weight_name = down_proj_name
3034
self.w3_weight_name = up_proj_name
31-
self.weight_scale_suffix = weight_scale_suffix
32-
self.act_scale_suffix = act_scale_suffix
33-
self.quantized_weight = weight_scale_suffix is not None
34-
self.static_activation = act_scale_suffix is not None
3535

3636
self.e_score_correction_bias_name = e_score_correction_bias_name
3737
self.weight_prefix = weight_prefix
@@ -46,7 +46,6 @@ def __init__(
4646
self.e_score_correction_bias = None
4747
self.w2_list = [None] * self.n_routed_experts
4848
self.w2_scale_list = [None] * self.n_routed_experts
49-
self.quant_method = None
5049
self.scoring_func = network_config["scoring_func"]
5150
self.w1 = [None, None] # weight, weight_scale
5251
self.w2 = [None, None] # weight, weight_scale

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,15 +164,17 @@ def __new__(cls, **kwargs):
164164
name = kwargs.pop("name", None)
165165
quant_method, quantized_weight = cls._get_quant_method(quant_cfg, layer_num_, name)
166166
kwargs["quant_method"] = quant_method
167-
if quant_cfg.static_activation:
167+
if quant_cfg is not None and quant_cfg.static_activation:
168168
kwargs["act_scale_suffix"] = "input_scale"
169-
if quant_cfg.quantized_weight:
169+
if quant_cfg is not None and quant_cfg.quantized_weight:
170170
kwargs["weight_scale_suffix"] = "weight_scale_inv"
171171
mmcls = cls._get_mmcls(quant_method, quantized_weight)
172172
return mmcls(**kwargs)
173173

174174
@classmethod
175175
def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> QuantizationMethod:
176+
if quant_cfg is None:
177+
return None, False
176178
quant_method = quant_cfg.get_quant_method(layer_num_, name)
177179
quant_type = quant_cfg.get_quant_type(layer_num_, name)
178180
quantized_weight = quant_cfg.quantized_weight

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,6 @@ def _init_moe(self):
321321
self.moe_gate = ROWMMWeight(
322322
weight_name=f"model.layers.{self.layer_num_}.mlp.gate.weight",
323323
data_type=self.data_type_,
324-
quant_cfg=self.quant_cfg,
325324
layer_num=self.layer_num_,
326325
name="moe_gate",
327326
tp_rank=0,
@@ -342,8 +341,8 @@ def _init_moe(self):
342341
split_inter_size=moe_intermediate_size // self.tp_world_size_,
343342
data_type=self.data_type_,
344343
network_config=self.network_config_,
345-
weight_scale_suffix=self.weight_scale_suffix,
346-
act_scale_suffix=self.act_scale_suffix,
344+
layer_num=self.layer_num_,
345+
quant_cfg=self.quant_cfg,
347346
)
348347

349348
def _init_ffn(self):

0 commit comments

Comments
 (0)