
Commit 358821e
add awq marlin
1 parent 18081a6 commit 358821e

5 files changed: +217 −7 lines changed

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py

Lines changed: 25 additions & 0 deletions

@@ -141,7 +141,32 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor):
         return weight_zero_point[zero_point_start:zero_point_end, :]
 
 
+class AWQMARLINCOLMMWeight(AWQCOLMMWeight):
+    def __init__(
+        self,
+        weight_name: str,
+        data_type: torch.dtype,
+        bias_name: Optional[str] = None,
+        quant_method: QuantizationMethod = None,
+        tp_rank: int = None,
+        tp_world_size: int = None,
+    ) -> None:
+        super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size)
+
+    def _process_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id()))
+
+    def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
+        return self.quant_method._process_weight_scale_after_loading(weight_scale.cuda(get_current_device_id()))
+
+    def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
+        return self.quant_method._process_weight_zero_point_after_loading(
+            weight_zero_point.cuda(get_current_device_id())
+        )
+
+
 COLBMM_WEIGHT_CLS_MAP = {
     "fp8w8a8b128": W8A8B128COLMMWeight,
     "awq": AWQCOLMMWeight,
+    "awq_marlin": AWQMARLINCOLMMWeight,
 }
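
The new class overrides only the three _process_* hooks; slicing and loading remain in AWQCOLMMWeight, and the actual Marlin repacking is delegated to the quant method. The extra "awq_marlin" entry in COLBMM_WEIGHT_CLS_MAP is what makes the class reachable. A rough usage sketch (illustrative only: the real lookup site and constructor arguments live elsewhere in lightllm, and the checkpoint key below is made up):

    # Illustrative dispatch: the map key mirrors quant_method.get_name().
    cls = COLBMM_WEIGHT_CLS_MAP["awq_marlin"]          # -> AWQMARLINCOLMMWeight
    w = cls(
        weight_name="layers.0.mlp.down_proj.qweight",  # hypothetical checkpoint key
        data_type=torch.float16,
        quant_method=quant_method,                     # an awq_marlin quantization method instance
    )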

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py

Lines changed: 16 additions & 7 deletions

@@ -184,11 +184,20 @@ def verify_load(self) -> bool:
         load_ok = load_ok and self.bias is not None
         return load_ok
 
+    def _process_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        return weight.cuda(get_current_device_id())
+
+    def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
+        return weight_scale.cuda(get_current_device_id())
+
+    def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
+        return weight_zero_point.cuda(get_current_device_id())
+
     def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
         if self.weight_name is not None and self.weight_name in weights:
             weight = weights[self.weight_name]
             weight = self._slice_weight(weight)
-            self.weight[0] = weight.cuda(get_current_device_id())
+            self.weight[0] = self._process_weight(weight)
         if self.bias_name is not None and self.bias_name in weights:
             bias = weights[self.bias_name]
             bias = self._slice_bias(bias)

@@ -198,13 +207,13 @@ def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
         if self.weight_scale_name is not None and self.weight_scale_name in weights:
             weight_scale = weights[self.weight_scale_name]
             weight_scale = self._slice_weight_scale(weight_scale)
-            self.weight[1] = weight_scale.cuda(get_current_device_id())
+            self.weight[1] = self._process_weight_scale(weight_scale)
 
     def _load_zero_points(self, weights: Dict[str, torch.Tensor]) -> None:
         if self.weight_zero_point_name is not None and self.weight_zero_point_name in weights:
             weight_zero_point = weights[self.weight_zero_point_name]
             weight_zero_point = self._slice_weight_zero_point(weight_zero_point)
-            self.weight[2] = weight_zero_point.cuda(get_current_device_id())
+            self.weight[2] = self._process_weight_zero_point(weight_zero_point)
 
 
 class AWQMultiMMWeightTpl(AWQMMWeightTpl):

@@ -239,18 +248,18 @@ def __init__(
     def _fuse(self) -> None:
         if self.weight[0] is None and (None not in self.weights):
             weight = torch.cat(self.weights, dim=1)
-            self.weight[0] = weight.cuda(get_current_device_id())
+            self.weight[0] = self._process_weight(weight)
             delattr(self, "weights")
 
         if self.weight[1] is None and (None not in self.weight_scales):
             # AWQ stores its quantized params with weight shape in x out, so the cat dim here is 1
             weight_scale = torch.cat(self.weight_scales, dim=1).cuda(get_current_device_id())
-            self.weight[1] = weight_scale.cuda(get_current_device_id())
+            self.weight[1] = self._process_weight_scale(weight_scale)
             delattr(self, "weight_scales")
 
         if self.weight[2] is None and (None not in self.weight_zero_points):
             weight_zero_point = torch.cat(self.weight_zero_points, dim=1)
-            self.weight[2] = weight_zero_point.cuda(get_current_device_id())
+            self.weight[2] = self._process_weight_zero_point(weight_zero_point)
             delattr(self, "weight_zero_points")
 
         if self.has_bias and self.bias is None and (None not in self.biases):

@@ -300,7 +309,7 @@ def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> Q
         if quant_cfg is None:
             return None, False
         quant_method = quant_cfg.get_quant_method(layer_num_, name)
-        quant_method.hf_quantization_method = quant_cfg.hf_quantization_method
+        quant_method.hf_quantization_config = quant_cfg.hf_quantization_config
         quantized_weight = quant_cfg.quantized_weight
         return quant_method, quantized_weight
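
The template class gains three _process_* hooks whose defaults keep the old behaviour (a plain .cuda() move), so the existing awq and fp8 paths are unchanged; the Marlin subclasses override only these hooks to repack at load time. A minimal standalone sketch of the pattern, using simplified stand-in names rather than the real lightllm classes:

    import torch

    class WeightTpl:
        def _process_weight(self, weight: torch.Tensor) -> torch.Tensor:
            return weight.cuda()  # default: device move only

        def load(self, weight: torch.Tensor) -> None:
            self.weight = self._process_weight(weight)  # the hook is the only variation point

    class MarlinWeightTpl(WeightTpl):
        def __init__(self, repack_fn):
            self.repack_fn = repack_fn  # e.g. a Marlin repack kernel

        def _process_weight(self, weight: torch.Tensor) -> torch.Tensor:
            return self.repack_fn(weight.cuda())  # repack once, at load time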

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py

Lines changed: 50 additions & 0 deletions

@@ -316,12 +316,62 @@ def __init__(
         super().__init__(weight_names, data_type, bias_names, quant_method, tp_rank, tp_world_size)
 
 
+class AWQMARLINROWMMWeight(AWQROWMMWeight):
+    def __init__(
+        self,
+        weight_name: str,
+        data_type: torch.dtype,
+        bias_name: Optional[str] = None,
+        quant_method: QuantizationMethod = None,
+        tp_rank: int = None,
+        tp_world_size: int = None,
+    ) -> None:
+        super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size)
+
+    def _process_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id()))
+
+    def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
+        return self.quant_method._process_weight_scale_after_loading(weight_scale.cuda(get_current_device_id()))
+
+    def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
+        return self.quant_method._process_weight_zero_point_after_loading(
+            weight_zero_point.cuda(get_current_device_id())
+        )
+
+
+class AWQMARLINMultiROWMMWeight(AWQMultiROWMMWeight):
+    def __init__(
+        self,
+        weight_names: List[str],
+        data_type: torch.dtype,
+        bias_names: Optional[List[str]] = None,
+        quant_method: QuantizationMethod = None,
+        tp_rank: int = None,
+        tp_world_size: int = None,
+    ) -> None:
+        super().__init__(weight_names, data_type, bias_names, quant_method, tp_rank, tp_world_size)
+
+    def _process_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id()))
+
+    def _process_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
+        return self.quant_method._process_weight_scale_after_loading(weight_scale.cuda(get_current_device_id()))
+
+    def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
+        return self.quant_method._process_weight_zero_point_after_loading(
+            weight_zero_point.cuda(get_current_device_id())
+        )
+
+
 ROWBMM_WEIGHT_CLS_MAP = {
     "fp8w8a8b128": W8A8B128ROWMMWeight,
     "awq": AWQROWMMWeight,
+    "awq_marlin": AWQMARLINROWMMWeight,
 }
 
 MULTI_ROWBMM_WEIGHT_CLS_MAP = {
     "fp8w8a8b128": W8A8B128MultiROWMMWeight,
     "awq": AWQMultiROWMMWeight,
+    "awq_marlin": AWQMARLINMultiROWMMWeight,
 }
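
As the comment in _fuse above notes, AWQ checkpoints store packed weights as in_features x (out_features // pack_factor), so multi-projection weights (e.g. fused gate/up) are concatenated along dim=1. A small shape sketch under illustrative layer sizes (not taken from this commit):

    import torch

    pack_factor = 8  # one int32 holds eight 4-bit values
    k, n = 4096, 11008  # illustrative in/out features of a single projection
    gate = torch.zeros(k, n // pack_factor, dtype=torch.int32)
    up = torch.zeros(k, n // pack_factor, dtype=torch.int32)
    fused = torch.cat([gate, up], dim=1)  # (4096, 2752): output dims stacked side by side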

lightllm/common/quantization/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -46,6 +46,8 @@ def _mapping_quant_method(self):
             logger.info(f"select fp8w8a8-b128 quant way: {self.quant_type}")
         elif self.hf_quantization_method == "awq":
             self.quant_type = "awq"
+            if is_awq_marlin_compatible(self.hf_quantization_config):
+                self.quant_type = "awq_marlin"
             logger.info(f"select awq quant way: {self.quant_type}")
         else:
             # TODO: more quant method
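
With this check in place, an AWQ checkpoint is transparently upgraded to the Marlin kernels whenever its quantization config qualifies; otherwise the plain "awq" path is kept. For example, a typical 4-bit AutoAWQ config like the sketch below (field values are illustrative) would yield quant_type == "awq_marlin" on a supported GPU:

    # Field names follow what is_awq_marlin_compatible reads; values are common AutoAWQ defaults.
    hf_quantization_config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 128,
        "zero_point": True,
    }
    # quant_type starts as "awq"; if is_awq_marlin_compatible(hf_quantization_config)
    # returns True, it is switched to "awq_marlin" and the Marlin weight classes are used.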

lightllm/common/quantization/awq_quant.py

Lines changed: 124 additions & 0 deletions

@@ -5,10 +5,25 @@
 import torch.nn.functional as F
 from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm
 from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
+from typing import Any
 
 if HAS_VLLM:
     awq_dequantize = vllm_ops.awq_dequantize
     awq_gemm = vllm_ops.awq_gemm
+    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+        check_marlin_supported,
+        marlin_permute_scales,
+        awq_to_marlin_zero_points,
+        should_use_atomic_add_reduce,
+        marlin_make_empty_g_idx,
+        marlin_make_workspace_new,
+    )
+    from vllm.scalar_type import scalar_types
+
+    TYPE_MAP = {
+        4: scalar_types.uint4,
+        8: scalar_types.uint8,
+    }
 
 
 class AWQBaseQuantizationMethod(QuantizationMethod):

@@ -56,3 +71,112 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
         if bias is not None:
             out.add_(bias)
         return out
+
+
+@QUANTMETHODS.register("awq_marlin")
+class AWQMARLINW4A16QuantizationMethod(AWQBaseQuantizationMethod):
+    def __init__(self):
+        super().__init__()
+        self.pack_factor = 8
+        self.weight_scale_suffix = "scales"
+        self.weight_zero_point_suffix = "qzeros"
+        self.weight_suffix = "qweight"
+        self.g_idx = marlin_make_empty_g_idx(torch.device("cuda"))
+        self.g_idx_sort_indices = marlin_make_empty_g_idx(torch.device("cuda"))
+        self.workspace = marlin_make_workspace_new(torch.device("cuda"))
+
+    def get_name(self):
+        return "awq_marlin"
+
+    def quantize(self, weight: torch.Tensor):
+        raise NotImplementedError("AWQ online quantization is not supported yet.")
+
+    def _process_weight_after_loading(self, weight: torch.Tensor) -> torch.Tensor:
+        assert self.hf_quantization_config is not None, "hf_quantization_config is not set"
+        self.k = weight.shape[0]
+        self.n = weight.shape[1] * self.pack_factor
+        return vllm_ops.awq_marlin_repack(
+            weight,
+            size_k=weight.shape[0],
+            size_n=weight.shape[1] * self.pack_factor,
+            num_bits=self.hf_quantization_config["bits"],
+        )
+
+    def _process_weight_scale_after_loading(self, weight_scale: torch.Tensor) -> torch.Tensor:
+        assert self.hf_quantization_config is not None, "hf_quantization_config is not set"
+        group_size = self.hf_quantization_config["group_size"]
+        return marlin_permute_scales(
+            weight_scale,
+            size_k=weight_scale.shape[0] * group_size,
+            size_n=weight_scale.shape[1],
+            group_size=self.hf_quantization_config["group_size"],
+        )
+
+    def _process_weight_zero_point_after_loading(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
+        return awq_to_marlin_zero_points(
+            weight_zero_point,
+            size_k=weight_zero_point.shape[0],
+            size_n=weight_zero_point.shape[1] * self.pack_factor,
+            num_bits=self.hf_quantization_config["bits"],
+        )
+
+    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
+        qweight, weight_scale, qzeros = weights
+        reshaped_x = input_tensor.reshape(-1, input_tensor.shape[-1])
+
+        use_atomic_add = should_use_atomic_add_reduce(
+            m=reshaped_x.size(0),
+            n=self.n,
+            k=self.k,
+            device=input_tensor.device,
+            dtype=input_tensor.dtype,
+        )
+
+        out = vllm_ops.gptq_marlin_gemm(
+            reshaped_x,
+            None,
+            qweight,
+            bias,
+            weight_scale,
+            None,
+            qzeros,
+            self.g_idx,
+            self.g_idx_sort_indices,
+            self.workspace,
+            TYPE_MAP[self.hf_quantization_config["bits"]],
+            size_m=reshaped_x.shape[0],
+            size_n=self.n,
+            size_k=self.k,
+            use_atomic_add=use_atomic_add,
+            use_fp32_reduce=True,
+            is_zp_float=False,
+        )
+
+        if bias is not None:
+            out.add_(bias)
+        return out
+
+
+# adapted from
+# https://github.com/vllm-project/vllm/blob/aef368aa08572505b820db01da82e2fbb3d43a72/vllm/model_executor/layers/quantization/awq_marlin.py#L211-L212
+def is_awq_marlin_compatible(quantization_config: dict[str, Any]):
+    # Extract data from quant config.
+    quant_method = quantization_config.get("quant_method", "").lower()
+    num_bits = quantization_config.get("bits")
+    group_size = quantization_config.get("group_size")
+    zero_point = quantization_config.get("zero_point")
+
+    if not torch.cuda.is_available():
+        return False
+
+    if quant_method != "awq":
+        return False
+
+    # If we cannot find the info needed in the config, cannot convert.
+    if num_bits is None or group_size is None or zero_point is None:
+        return False
+
+    if num_bits not in TYPE_MAP:
+        return False
+
+    return check_marlin_supported(quant_type=TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point)
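
_process_weight_after_loading records the GEMM dimensions from the packed tensor (pack_factor = 8 because one int32 holds eight 4-bit values), and apply later passes them to gptq_marlin_gemm as size_k/size_n. A CPU-only sketch of just that bookkeeping with illustrative sizes; the real method additionally runs vllm_ops.awq_marlin_repack on the GPU tensor:

    import torch

    bits = 4
    pack_factor = 32 // bits  # 8, matching self.pack_factor above
    qweight = torch.zeros(4096, 11008 // pack_factor, dtype=torch.int32)  # illustrative AWQ layout

    size_k = qweight.shape[0]                # 4096  -> reduction dim of the GEMM
    size_n = qweight.shape[1] * pack_factor  # 11008 -> output dim of the GEMM
    assert (size_k, size_n) == (4096, 11008)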
