
Commit e02a362

fix quant
1 parent fc59674 commit e02a362

File tree: 11 files changed, +81 −34 lines changed


lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py

Lines changed: 4 additions & 2 deletions

@@ -11,7 +11,7 @@
 from .mm_slicer import ColSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin


-class UnquantizedCOLMMWeight(MMWeightTpl):
+class StandardCOLMMWeight(MMWeightTpl):
     def __init__(
         self,
         weight_names: Union[str, List[str]],
@@ -72,7 +72,9 @@ def __init__(
             tp_world_size=tp_world_size,
         )
         # Note: this is not a mistake, because AWQ weights are stored as in x out
-        self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
+        self.param_slicer = QuantizedRowSliceMixin(
+            tp_rank=tp_rank, tp_world_size=tp_world_size, bias_div_world_size=True
+        )


 class AWQMARLINCOLMMWeight(AWQCOLMMWeight):
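The surviving comment explains the apparent mismatch: AWQ stores its qweight as in x out rather than the usual out x in, so a column-parallel weight is sliced with a row slicer. A toy sketch of that shape reasoning (stand-in shapes, not LightLLM code):

import torch

# Toy shape check: an AWQ qweight is stored as [in, out], while the
# unquantized convention in this repo is [out, in]. Splitting a
# column-parallel weight therefore slices dim=0 of the AWQ tensor.
in_features, out_features, tp_world_size, tp_rank = 8, 4, 2, 0
awq_qweight = torch.arange(in_features * out_features).reshape(in_features, out_features)

# RowSliceMixin logic: split along dim=0, which for the AWQ layout is the input dim.
tp_size = awq_qweight.shape[0] // tp_world_size
shard = awq_qweight[tp_size * tp_rank : tp_size * (tp_rank + 1)]
assert shard.shape == (in_features // tp_world_size, out_features)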

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py

Lines changed: 12 additions & 6 deletions

@@ -6,12 +6,12 @@
     BMMWeightTpl,
 )
 from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import (
-    UnquantizedROWMMWeight,
+    StandardROWMMWeight,
     UnquantizedROWBMMWeight,
     ROWMM_WEIGHT_CLS_MAP,
 )
 from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.colmm_weight import (
-    UnquantizedCOLMMWeight,
+    StandardCOLMMWeight,
     COLMM_WEIGHT_CLS_MAP,
 )

@@ -61,9 +61,12 @@ class ROWMMWeight(MMWeight):
     @classmethod
     def _get_mmcls(cls, quant_method: QuantizationMethod):
         if quant_method is None:
-            return UnquantizedROWMMWeight
+            return StandardROWMMWeight

-        return ROWMM_WEIGHT_CLS_MAP[quant_method.method_name]
+        return ROWMM_WEIGHT_CLS_MAP.get(
+            quant_method.method_name,
+            StandardROWMMWeight,
+        )


 class ROWBMMWeight(MMWeight):
@@ -80,5 +83,8 @@ class COLMMWeight(MMWeight):
     @classmethod
     def _get_mmcls(cls, quant_method: QuantizationMethod):
         if quant_method is None:
-            return UnquantizedCOLMMWeight
-            return COLMM_WEIGHT_CLS_MAP[quant_method.method_name]
+            return StandardCOLMMWeight
+        return COLMM_WEIGHT_CLS_MAP.get(
+            quant_method.method_name,
+            StandardCOLMMWeight,
+        )
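The switch from direct indexing to dict.get changes behavior for quant methods with no specialized MM class registered: instead of raising KeyError, the factory now falls back to the standard implementation. A toy sketch of the fallback (stand-in classes, not LightLLM code):

# Toy illustration of the new fallback:
class StandardROWMMWeight: ...
class AWQROWMMWeight: ...

ROWMM_WEIGHT_CLS_MAP = {"awq": AWQROWMMWeight}

# Before this commit an unregistered method_name raised KeyError;
# now it falls back to the standard weight class.
mmcls = ROWMM_WEIGHT_CLS_MAP.get("some-new-quant-method", StandardROWMMWeight)
assert mmcls is StandardROWMMWeight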

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py

Lines changed: 16 additions & 11 deletions

@@ -7,9 +7,10 @@
 class SliceMixinBase(ABC):
     """Base mixin class for slicing operations."""

-    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
         self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp()
         self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size()
+        self.bias_div_world_size_ = bias_div_world_size

     @abstractmethod
     def _slice_weight(self, weight: torch.Tensor):
@@ -21,8 +22,8 @@ def _slice_bias(self, bias):


 class SliceMixinTpl(SliceMixinBase):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
-        super().__init__(tp_rank, tp_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
+        super().__init__(tp_rank, tp_world_size, bias_div_world_size)

     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError("slice_weight must implement this method")
@@ -40,8 +41,8 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
 # By default the weight shape is out x in, which is currently the most common convention,
 # so row-wise slicing is along dim=0 and col-wise slicing is along dim=1.
 class RowSliceMixin(SliceMixinTpl):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
-        super().__init__(tp_rank, tp_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
+        super().__init__(tp_rank, tp_world_size, bias_div_world_size)

     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
         assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}"
@@ -51,14 +52,16 @@ def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
     def _slice_bias(self, bias) -> torch.Tensor:
         assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
         tp_size = bias.shape[0] // self.tp_world_size_
+        if self.bias_div_world_size_:
+            return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
         return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]


 # The default quantized slicing assumes group-wise quantization, so weight_scale and weight_zero_point
 # have the same ndims as weight. Per-tensor and per-channel quantization can be added later as needed.
 class QuantizedRowSliceMixin(RowSliceMixin):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
-        super().__init__(tp_rank, tp_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
+        super().__init__(tp_rank, tp_world_size, bias_div_world_size)

     def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
         assert (
@@ -80,8 +83,8 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten


 class ColSliceMixin(SliceMixinTpl):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
-        super().__init__(tp_rank, tp_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = True):
+        super().__init__(tp_rank, tp_world_size, bias_div_world_size)

     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
         assert weight.shape[1] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[1]} % {self.tp_world_size_}"
@@ -91,12 +94,14 @@ def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
     def _slice_bias(self, bias) -> torch.Tensor:
         assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
         tp_size = bias.shape[0] // self.tp_world_size_
+        if self.bias_div_world_size_:
+            return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
         return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]


 class QuantizedColSliceMixin(ColSliceMixin):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
-        super().__init__(tp_rank, tp_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = True):
+        super().__init__(tp_rank, tp_world_size, bias_div_world_size)

     def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
         assert (
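The new bias_div_world_size flag appears to address the usual tensor-parallel bias accounting: when each rank computes a partial matmul over its slice of the reduction dimension and the results are summed by an all-reduce, a bias added in full on every rank would be counted tp_world_size times, so each rank contributes bias / tp_world_size instead. A toy numeric check of that invariant (plain tensors, no distributed setup; the real slicer also shards the bias, which this sketch skips):

import torch

tp_world_size = 2
x = torch.randn(3, 8)   # activations
w = torch.randn(8, 4)   # weight, in x out
b = torch.randn(4)      # bias
full = x @ w + b

# Split the reduction (input) dimension across ranks, as a col-wise weight split does.
k = x.shape[1] // tp_world_size
partials = []
for rank in range(tp_world_size):
    x_shard = x[:, k * rank : k * (rank + 1)]
    w_shard = w[k * rank : k * (rank + 1), :]
    partials.append(x_shard @ w_shard + b / tp_world_size)  # bias_div_world_size=True

assert torch.allclose(sum(partials), full, atol=1e-5)  # sum(...) plays the all-reduce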

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py

Lines changed: 13 additions & 13 deletions

@@ -68,8 +68,6 @@ def __init__(
         quant_method: QuantizationMethod = None,
         tp_rank: int = None,
         tp_world_size: int = None,
-        has_weight_scale: bool = False,
-        has_weight_zero_point: bool = False,
     ) -> None:
         super().__init__(tp_rank, tp_world_size, data_type)
         self.lock = threading.Lock()
@@ -84,13 +82,20 @@ def __init__(
         if bias_names[0] is None:
             bias_names = None

+        if quant_method is not None:
+            has_weight_scale = quant_method.has_weight_scale
+            has_weight_zero_point = quant_method.has_weight_zero_point
+        else:
+            has_weight_scale = False
+            has_weight_zero_point = False
+
         # Both weight_names and quanted_weight_names exist to support the online and offline loading schemes
         self.weight_names = weight_names

         self.bias_names = bias_names
         has_bias = self.bias_names is not None

-        self.gen_weight_quant_param_names(quant_method=quant_method, has_weight_zero_point=has_weight_zero_point)
+        self.gen_weight_quant_param_names(quant_method=quant_method)
         self.quant_method = quant_method
         self.sub_child_mm_params: List[MMWeightPack] = [
             MMWeightPack(
@@ -132,7 +137,7 @@ def mm(
             return torch.mm(input_tensor, self.mm_param.weight, out=out)
         return torch.addmm(self.mm_param.bias, input_tensor, self.mm_param.weight, out=out)

-    def gen_weight_quant_param_names(self, quant_method: QuantizationMethod, has_weight_zero_point: bool):
+    def gen_weight_quant_param_names(self, quant_method: Optional[QuantizationMethod]):
         if quant_method is None:
             self.quanted_weight_names = None
             self.weight_zero_point_names = None
@@ -144,11 +149,10 @@ def gen_weight_quant_param_names(self, quant_method: QuantizationMethod, has_wei
         weight_zero_point_names = []

         for weight_name in self.weight_names:
-            assert quant_method.weight_scale_suffix is not None, "weight_scale_suffix is not set"
-            weight_scale_name = weight_name.replace("weight", quant_method.weight_scale_suffix)
-            weight_scale_names.append(weight_scale_name)
-            if has_weight_zero_point:
-                assert quant_method.weight_zero_point_suffix is not None, "weight_zero_point_suffix is not set"
+            if quant_method.weight_scale_suffix is not None:
+                weight_scale_name = weight_name.replace("weight", quant_method.weight_scale_suffix)
+                weight_scale_names.append(weight_scale_name)
+            if quant_method.weight_zero_point_suffix is not None:
                 weight_zero_point_name = weight_name.replace("weight", quant_method.weight_zero_point_suffix)
                 weight_zero_point_names.append(weight_zero_point_name)
             if quant_method.weight_suffix is not None:
@@ -410,8 +414,6 @@ def __init__(
             quant_method=quant_method,
             tp_rank=tp_rank,
             tp_world_size=tp_world_size,
-            has_weight_scale=True,
-            has_weight_zero_point=False,
         )

     def _to_gpu_device(self) -> None:
@@ -445,8 +447,6 @@ def __init__(
             quant_method=quant_method,
             tp_rank=tp_rank,
             tp_world_size=tp_world_size,
-            has_weight_scale=True,
-            has_weight_zero_point=True,
        )
         self.weight_fused_dim = 1
         self.bias_fused_dim = 0
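With the flags gone from the constructor, the auxiliary tensor names are now derived purely from the suffixes each quant method declares. A minimal sketch of that naming rule (hypothetical helper and example names):

# Suffix-driven naming, as gen_weight_quant_param_names now does it:
def quant_param_names(weight_name, weight_scale_suffix, weight_zero_point_suffix):
    names = {}
    if weight_scale_suffix is not None:
        names["scale"] = weight_name.replace("weight", weight_scale_suffix)
    if weight_zero_point_suffix is not None:
        names["zero_point"] = weight_name.replace("weight", weight_zero_point_suffix)
    return names

# e.g. with the AWQ suffixes from awq_quant.py ("scales", "qzeros"):
print(quant_param_names("model.layers.0.q_proj.weight", "scales", "qzeros"))
# {'scale': 'model.layers.0.q_proj.scales', 'zero_point': 'model.layers.0.q_proj.qzeros'}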

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py

Lines changed: 4 additions & 2 deletions

@@ -12,7 +12,7 @@
 from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin


-class UnquantizedROWMMWeight(MMWeightTpl):
+class StandardROWMMWeight(MMWeightTpl):
     def __init__(
         self,
         weight_names: Union[str, List[str]],
@@ -95,7 +95,9 @@ def __init__(
             tp_world_size=tp_world_size,
         )
         # Note: this is not a mistake, because AWQ weights are stored as in x out
-        self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
+        self.param_slicer = QuantizedColSliceMixin(
+            tp_rank=tp_rank, tp_world_size=tp_world_size, bias_div_world_size=False
+        )


 class AWQMARLINROWMMWeight(AWQROWMMWeight):

lightllm/common/quantization/awq_quant.py

Lines changed: 4 additions & 0 deletions

@@ -65,6 +65,8 @@ def __init__(self):
         self.weight_scale_suffix = "scales"
         self.weight_zero_point_suffix = "qzeros"
         self.weight_suffix = "qweight"
+        self.has_weight_scale = True
+        self.has_weight_zero_point = True

     @property
     def method_name(self):
@@ -111,6 +113,8 @@ def __init__(self):
         self.g_idx_sort_indices = marlin_make_empty_g_idx(torch.device("cuda"))
         self.workspace = marlin_make_workspace_new(torch.device("cuda"))
         self.vllm_quant_type = TYPE_MAP[self.nbits]
+        self.has_weight_scale = True
+        self.has_weight_zero_point = True

     @property
     def method_name(self):

lightllm/common/quantization/deepgemm_quant.py

Lines changed: 2 additions & 0 deletions

@@ -53,6 +53,8 @@ def __init__(self):
         self.weight_suffix = None
         self.weight_zero_point_suffix = None
         self.weight_scale_suffix = "weight_scale_inv"
+        self.has_weight_scale = True
+        self.has_weight_zero_point = False

     @property
     def method_name(self):

lightllm/common/quantization/quantize_method.py

Lines changed: 2 additions & 0 deletions

@@ -15,6 +15,8 @@ def __init__(self):
         self.weight_scale_suffix = None
         self.weight_zero_point_suffix = None
         self.act_scale_suffix = None
+        self.has_weight_scale: bool = None
+        self.has_weight_zero_point: bool = None
         # Extra quantization parameters required by some quantization modes, e.g. AWQ
         self.hf_quantization_config = None
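(The committed lines read `self.has_weight_scale: bool = (None,)`, a stray-trailing-comma bug that assigns a tuple; the intent from every subclass in this commit is plainly `None`, as shown above.) The base class defaults both flags to None so that a method that forgets to declare them surfaces as None rather than silently reading as False. A minimal sketch of the contract a subclass is expected to satisfy (modeled on deepgemm_quant.py above; class name and import path are assumptions):

from lightllm.common.quantization.quantize_method import QuantizationMethod  # assumed import path

class ExampleFp8BlockQuantizationMethod(QuantizationMethod):  # hypothetical subclass
    def __init__(self):
        super().__init__()
        self.weight_suffix = None                      # weights keep their original name
        self.weight_zero_point_suffix = None           # symmetric scheme: no zero point
        self.weight_scale_suffix = "weight_scale_inv"  # offline checkpoints store inverse scales
        self.has_weight_scale = True
        self.has_weight_zero_point = False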

lightllm/common/quantization/torchao_quant.py

Lines changed: 16 additions & 0 deletions

@@ -64,6 +64,8 @@ def __init__(self):
         super().__init__()
         self.group_size = 256
         self.quant_func = int4_weight_only(group_size=self.group_size)
+        self.has_weight_scale = False
+        self.has_weight_zero_point = False

     @property
     def method_name(self):
@@ -76,6 +78,8 @@ def __init__(self):
         super().__init__()
         self.group_size = 128
         self.quant_func = int4_weight_only(group_size=self.group_size)
+        self.has_weight_scale = False
+        self.has_weight_zero_point = False

     @property
     def method_name(self):
@@ -88,6 +92,8 @@ def __init__(self):
         super().__init__()
         self.group_size = 64
         self.quant_func = int4_weight_only(group_size=self.group_size)
+        self.has_weight_scale = False
+        self.has_weight_zero_point = False

     @property
     def method_name(self):
@@ -100,6 +106,8 @@ def __init__(self):
         super().__init__()
         self.group_size = 32
         self.quant_func = int4_weight_only(group_size=self.group_size)
+        self.has_weight_scale = False
+        self.has_weight_zero_point = False

     @property
     def method_name(self):
@@ -111,6 +119,8 @@ class AOW8A8QuantizationMethod(AOBaseQuantizationMethod):
     def __init__(self):
         super().__init__()
         self.quant_func = int8_dynamic_activation_int8_weight()
+        self.has_weight_scale = False
+        self.has_weight_zero_point = False

     @property
     def method_name(self):
@@ -122,6 +132,8 @@ class AOW8A16QuantizationMethod(AOBaseQuantizationMethod):
     def __init__(self):
         super().__init__()
         self.quant_func = int8_weight_only()
+        self.has_weight_scale = False
+        self.has_weight_zero_point = False

     @property
     def method_name(self):
@@ -135,6 +147,8 @@ def __init__(self):
         is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
         assert is_cuda_8_9, "FP8 requires GPU with compute capability >= 8.9"
         self.quant_func = float8_weight_only()
+        self.has_weight_scale = False
+        self.has_weight_zero_point = False

     @property
     def method_name(self):
@@ -147,6 +161,8 @@ def __init__(self):
         super().__init__()
         assert TORCH_VERSION_AT_LEAST_2_5, "torchao fp6 requires torch >=2.5"
         self.quant_func = fpx_weight_only(3, 2)
+        self.has_weight_scale = False
+        self.has_weight_zero_point = False

     @property
     def method_name(self):
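All of the torchao methods set both flags to False, presumably because torchao quantizes weights in-process and keeps scales inside the quantized tensor subclass, so the loader never needs separate scale or zero-point checkpoint entries. An illustrative use of the public torchao API (not LightLLM's loading path; exact entry points vary by torchao version):

import torch
from torchao.quantization import quantize_, int8_weight_only

linear = torch.nn.Linear(128, 128, dtype=torch.bfloat16, device="cuda")
quantize_(linear, int8_weight_only())
# The scale now lives inside the quantized weight tensor subclass itself, so
# there is no separate "*.scales" checkpoint entry to load: has_weight_scale=False.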

lightllm/common/quantization/triton_quant/triton_quant.py

Lines changed: 2 additions & 0 deletions

@@ -41,6 +41,8 @@ def __init__(self):
         self.weight_suffix = None
         self.weight_zero_point_suffix = None
         self.weight_scale_suffix = "weight_scale_inv"
+        self.has_weight_scale = True
+        self.has_weight_zero_point = False

     def quantize(self, weight: torch.Tensor):
         # TODO block-wise quant kernel
