Commit 99dfd53

Author: wangzaijun (committed)
improve weight impl
1 parent: cf5d7fc

6 files changed (+361, -592 lines)

lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py

Lines changed: 2 additions & 3 deletions
@@ -1,4 +1,3 @@
-from multiprocessing import parent_process
 import torch
 from abc import ABC, abstractmethod
 from typing import Dict
@@ -14,8 +13,8 @@ def load_hf_weights(self, weights):
         pass
 
     @abstractmethod
-    def verify_load(self):
-        parent_process
+    def verify_load(self) -> bool:
+        pass
 
 
 class BaseWeightTpl(BaseWeight):
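
The old body of `verify_load` was a stray `parent_process` reference left behind by the removed import; the new signature makes the contract explicit. As a minimal sketch of that contract (the `DummyWeight` class and its fields are illustrative, not part of the repo; only `BaseWeight`, `load_hf_weights`, and `verify_load` come from the diff):

import torch
from typing import Dict

class DummyWeight(BaseWeight):  # hypothetical subclass for illustration
    def __init__(self, weight_name: str):
        self.weight_name = weight_name
        self.weight = None

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
        # Pick this parameter out of the checkpoint dict when it appears.
        if self.weight_name in weights:
            self.weight = weights[self.weight_name]

    def verify_load(self) -> bool:
        # The corrected annotation lets callers assert on the result.
        return self.weight is not None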

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py

Lines changed: 19 additions & 38 deletions
@@ -1,30 +1,30 @@
 import torch
 from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import (
-    SingleMMWeightTpl,
+    MMWeightTpl,
     DeepGemmFP8W8A8B128MMWeight,
     AWQMMWeightTpl,
 )
 from lightllm.common.quantization import Quantcfg
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.common.quantization.quantize_method import QuantizationMethod
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 from .mm_slicer import ColSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin
 
 
-class UnquantizedCOLMMWeight(SingleMMWeightTpl):
+class UnquantizedCOLMMWeight(MMWeightTpl):
     def __init__(
         self,
-        weight_name: str,
+        weight_names: Union[str, List[str]],
         data_type: torch.dtype,
-        bias_name: Optional[str] = None,
+        bias_names: Optional[Union[str, List[str]]] = None,
         quant_method: QuantizationMethod = None,
         tp_rank: int = None,
         tp_world_size: int = None,
     ) -> None:
         super().__init__(
-            weight_name=weight_name,
+            weight_name=weight_names,
             data_type=data_type,
-            bias_name=bias_name,
+            bias_name=bias_names,
             quant_method=quant_method,
             tp_rank=tp_rank,
             tp_world_size=tp_world_size,
@@ -35,17 +35,17 @@ def __init__(
 class DeepGemmFP8W8A8B128COLMMWeight(DeepGemmFP8W8A8B128MMWeight):
     def __init__(
         self,
-        weight_name: str,
+        weight_names: Union[str, List[str]],
         data_type: torch.dtype,
-        bias_name: Optional[str] = None,
+        bias_names: Optional[Union[str, List[str]]] = None,
         quant_method: QuantizationMethod = None,
         tp_rank: int = None,
         tp_world_size: int = None,
     ) -> None:
         super().__init__(
-            weight_name=weight_name,
+            weight_names=weight_names,
             data_type=data_type,
-            bias_name=bias_name,
+            bias_names=bias_names,
             quant_method=quant_method,
             tp_rank=tp_rank,
             tp_world_size=tp_world_size,
@@ -56,17 +56,17 @@ def __init__(
 class AWQCOLMMWeight(AWQMMWeightTpl):
     def __init__(
         self,
-        weight_name: str,
+        weight_names: Union[str, List[str]],
         data_type: torch.dtype,
-        bias_name: Optional[str] = None,
+        bias_names: Optional[Union[str, List[str]]] = None,
         quant_method: QuantizationMethod = None,
         tp_rank: int = None,
         tp_world_size: int = None,
     ) -> None:
         super().__init__(
-            weight_name=weight_name,
+            weight_names=weight_names,
             data_type=data_type,
-            bias_name=bias_name,
+            bias_names=bias_names,
             quant_method=quant_method,
             tp_rank=tp_rank,
             tp_world_size=tp_world_size,
@@ -78,41 +78,22 @@ def __init__(
 class AWQMARLINCOLMMWeight(AWQCOLMMWeight):
     def __init__(
         self,
-        weight_name: str,
+        weight_names: Union[str, List[str]],
         data_type: torch.dtype,
-        bias_name: Optional[str] = None,
+        bias_names: Optional[Union[str, List[str]]] = None,
         quant_method: QuantizationMethod = None,
         tp_rank: int = None,
         tp_world_size: int = None,
     ) -> None:
         super().__init__(
-            weight_name=weight_name,
+            weight_names=weight_names,
             data_type=data_type,
-            bias_name=bias_name,
+            bias_names=bias_names,
             quant_method=quant_method,
             tp_rank=tp_rank,
             tp_world_size=tp_world_size,
         )
 
-    def _process_weight(self, weight: torch.Tensor) -> torch.Tensor:
-        new_weight = self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id()))
-        self.mm_param.weight = new_weight
-        return
-
-    def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
-        new_weight_scale = self.quant_method._process_weight_scale_after_loading(
-            weight_scale.cuda(get_current_device_id()).to(self.data_type_)
-        )
-        self.mm_param.weight_scale = new_weight_scale
-        return
-
-    def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
-        new_weight_zero_point = self.quant_method._process_weight_zero_point_after_loading(
-            weight_zero_point.cuda(get_current_device_id())
-        )
-        self.mm_param.weight_zero_point = new_weight_zero_point
-        return
-
 
 COLMM_WEIGHT_CLS_MAP = {
     "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128COLMMWeight,

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py

Lines changed: 24 additions & 23 deletions
@@ -3,15 +3,12 @@
 from typing import Type, Union, Dict
 from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import (
     MMWeightTpl,
-    MultiMMWeightTpl,
     BMMWeightTpl,
 )
 from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import (
     UnquantizedROWMMWeight,
     UnquantizedROWBMMWeight,
-    UnquantizedMultiROWMMWeight,
     ROWMM_WEIGHT_CLS_MAP,
-    MULTI_ROWMM_WEIGHT_CLS_MAP,
 )
 from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.colmm_weight import (
     UnquantizedCOLMMWeight,
@@ -21,54 +18,58 @@
 
 class MMWeight:
     def __new__(cls, **kwargs):
+        """
+        weight_names,
+        data_type,
+        bias_names,
+        quant_cfg,
+        layer_num,
+        name,
+        tp_rank,
+        tp_world_size,
+        ...
+        This class mainly binds a quantization method to the matching mm weight
+        class by overriding __new__; all other parameters are passed through unchanged.
+        """
+
         quant_cfg = kwargs.pop("quant_cfg", None)
         layer_num_ = kwargs.pop("layer_num", None)
         name = kwargs.pop("name", None)
         quant_method, quantized_weight = cls._get_quant_method(quant_cfg, layer_num_, name)
+        # quantized_weight marks whether the weight is stored in the file in already-quantized
+        # form; the flag is no longer used, as quantization is now detected automatically during loading.
         kwargs["quant_method"] = quant_method
-        mmcls = cls._get_mmcls(quant_method, quantized_weight)
+        mmcls = cls._get_mmcls(quant_method)
         return mmcls(**kwargs)
 
     @classmethod
     def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> QuantizationMethod:
         if quant_cfg is None:
             return None, False
-        quant_method = quant_cfg.get_quant_method(layer_num_, name)
+        quant_method: QuantizationMethod = quant_cfg.get_quant_method(layer_num_, name)
         if quant_method is None:
             return None, False
         quant_method.hf_quantization_config = quant_cfg.hf_quantization_config
         quantized_weight = quant_cfg.quantized_weight
         return quant_method, quantized_weight
 
     @classmethod
-    def _get_mmcls(
-        cls, quant_method: QuantizationMethod, quantized_weight: bool
-    ) -> Type[Union[MMWeightTpl, MultiMMWeightTpl, BMMWeightTpl]]:
+    def _get_mmcls(cls, quant_method: QuantizationMethod) -> Type[Union[MMWeightTpl, BMMWeightTpl]]:
         raise NotImplementedError("Subclasses must implement _get_mmcls method")
 
 
 class ROWMMWeight(MMWeight):
     @classmethod
-    def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
-        if quant_method is None or not quantized_weight:
+    def _get_mmcls(cls, quant_method: QuantizationMethod):
+        if quant_method is None:
            return UnquantizedROWMMWeight
 
         return ROWMM_WEIGHT_CLS_MAP[quant_method.method_name]
 
 
-class MultiROWMMWeight(MMWeight):
-    @classmethod
-    def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
-        if quant_method is None or not quantized_weight:
-            return UnquantizedMultiROWMMWeight
-
-        return MULTI_ROWMM_WEIGHT_CLS_MAP[quant_method.method_name]
-
-
 class ROWBMMWeight(MMWeight):
     @classmethod
-    def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
-        if quant_method is None or not quantized_weight:
+    def _get_mmcls(cls, quant_method: QuantizationMethod):
+        if quant_method is None:
             return UnquantizedROWBMMWeight
         else:
             # TODO: Implement more quantization weight
@@ -77,7 +78,7 @@ def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
 
 class COLMMWeight(MMWeight):
     @classmethod
-    def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
-        if quant_method is None or not quantized_weight:
+    def _get_mmcls(cls, quant_method: QuantizationMethod):
+        if quant_method is None:
             return UnquantizedCOLMMWeight
         return COLMM_WEIGHT_CLS_MAP[quant_method.method_name]
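
With `quantized_weight` gone from the dispatch, `_get_mmcls` chooses a class from the quant method alone. A sketch of the resulting flow, assuming the keyword arguments listed in the new docstring; the tensor name and argument values are placeholders:

import torch

# No quant_cfg: _get_quant_method returns (None, False), so _get_mmcls
# falls back to the unquantized class.
w = ROWMMWeight(
    weight_names=["model.layers.0.attn.q_proj.weight"],  # placeholder
    data_type=torch.bfloat16,
    bias_names=None,
    quant_cfg=None,
    layer_num=0,
    name="q_proj",
    tp_rank=0,
    tp_world_size=1,
)
assert isinstance(w, UnquantizedROWMMWeight)

# With a quant_cfg whose get_quant_method(...) returns a method whose
# method_name is e.g. "deepgemm-fp8w8a8-b128", the same call would pick the
# entry from ROWMM_WEIGHT_CLS_MAP instead, whether or not the checkpoint
# stores the weights pre-quantized.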
