Skip to content

Commit 36d162a

Browse files
committed
update vllmquant
1 parent 3d9c1bc commit 36d162a

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

lightllm/common/quantization/vllm_quant.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,30 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
2929
class vLLMw8a8QuantizationMethod(vLLMBaseQuantizationMethod):
    """W8A8 quantization (int8 weights, int8 activations) backed by vLLM's
    custom CUDA kernels (`ops.scaled_int8_quant` + `cutlass_scaled_mm`).

    Weights may arrive either as a raw floating-point tensor (quantized here
    on the fly) or as a pre-quantized tuple loaded from a checkpoint.
    """

    def __init__(self):
        super().__init__()

    def quantize(self, weight: torch.Tensor):
        """Quantize ``weight`` to int8 with symmetric per-output-channel scales.

        If ``weight`` is already a tuple — ``(qweight, weight_scale)`` or
        ``(qweight, weight_scale, input_scale)`` — it is treated as
        pre-quantized: only the qweight is transposed and moved to CUDA,
        and any trailing scale tensors are passed through unchanged.

        Returns:
            Tuple of ``(int8 qweight [in, out] on CUDA, per-channel scale, ...)``.
        """
        if isinstance(weight, tuple):
            # Pre-quantized checkpoint path: keep the trailing scale
            # tensor(s) as-is so a static input_scale (3-tuple) survives.
            return (weight[0].transpose(0, 1).cuda(),) + weight[1:]
        weight = weight.float()
        # Symmetric quantization: map max |w| per output channel onto 127.
        # NOTE(review): a zero row would make scale 0 and divide by zero —
        # assumed not to occur for real model weights; confirm upstream.
        scale = weight.abs().max(dim=-1)[0] / 127
        weight = weight.transpose(0, 1) / scale.reshape(1, -1)
        weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8)
        return weight.cuda(), scale.cuda()

    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
        """Run the int8 GEMM ``out = input_tensor @ qweight`` with rescaling.

        Args:
            input_tensor: Activation tensor of shape (m, k); quantized to int8
                here via ``ops.scaled_int8_quant`` (dynamic per-token scale
                unless a static ``input_scale`` is bundled with ``weights``).
            weights: ``(qweight, weight_scale)`` or
                ``(qweight, weight_scale, input_scale)`` as produced by
                :meth:`quantize`.
            bias: Optional bias passed through to the CUTLASS kernel.
            out: Optional pre-allocated output; allocated from the cache
                manager when ``None``.
            workspace: Unused; kept for interface compatibility.

        Returns:
            The (m, n) output tensor in the input dtype.

        Raises:
            ValueError: If ``weights`` does not have 2 or 3 elements.
        """
        input_scale = None
        if len(weights) == 3:
            qweight, weight_scale, input_scale = weights
        elif len(weights) == 2:
            qweight, weight_scale = weights
        else:
            # Fix: previously an unexpected tuple length fell through and
            # surfaced later as a confusing NameError on `qweight`.
            raise ValueError(
                f"weights must be (qweight, weight_scale[, input_scale]), got {len(weights)} elements"
            )
        x_q, x_scale, x_zp = ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True)
        m = input_tensor.shape[0]
        n = qweight.shape[1]
        if out is None:
            out = g_cache_manager.alloc_tensor(
                (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
            )
        torch.ops._C.cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias)
        return out

0 commit comments

Comments
 (0)