
Commit 59abd25

fix block mm
1 parent ffaee7d

1 file changed: +16 -2

lightllm/common/quantization/vllm_quant.py

Lines changed: 16 additions & 2 deletions
@@ -4,6 +4,7 @@
 from .registry import QUANTMETHODS
 import torch.nn.functional as F
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
+from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul
 
 try:
     HAS_VLLM = True
@@ -171,7 +172,7 @@ def apply_pingpong_fp8(
 class vLLMFP8w8a8B128QuantizationMethod(vLLMBaseQuantizationMethod):
     def __init__(self):
         super().__init__()
-        self.blcok_size = 128
+        self.block_size = 128
 
     def quantize(self, weight: torch.Tensor):
         if self.is_moe:
@@ -200,5 +201,18 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
             )
         else:
             out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-            torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
+            if n == 576:
+                w8a8_block_fp8_matmul(
+                    qinput_tensor,
+                    qweight,
+                    input_scale,
+                    weight_scale,
+                    out,
+                    (self.block_size, self.block_size),
+                    dtype=input_tensor.dtype,
+                )
+            else:
+                qweight = qweight.t().contiguous().t()
+                input_scale = input_scale.t().contiguous().t()
+                torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
         return out
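
A note on the two branches in the last hunk (a reading of the change, not text from the commit): the n == 576 case is routed to the Triton kernel w8a8_block_fp8_matmul, presumably because 576 is not a multiple of the 128x128 scale block (576 / 128 = 4.5) and so cannot go through the cutlass block-scaled path; every other shape keeps using torch.ops._C.cutlass_scaled_mm, with qweight and input_scale first rewritten into column-major layout via the .t().contiguous().t() round trip, matching the operand layout that kernel expects. A minimal, self-contained sketch of that layout trick in plain PyTorch (tensor name and sizes are illustrative only, not taken from the repository):

    import torch

    # A freshly allocated 2-D tensor is row-major: the last stride is 1.
    w = torch.randn(128, 576)
    print(w.stride())          # (576, 1)

    # .t() gives a transposed view, .contiguous() materializes that view in
    # memory, and the final .t() views it back to the original shape. Shape
    # and values are unchanged; only the strides (the memory layout) differ.
    w_col = w.t().contiguous().t()
    print(w_col.stride())      # (1, 128)
    assert torch.equal(w, w_col)

The round trip costs one copy (the .contiguous() call), so it is purely a layout fix: the logical matrix handed to cutlass_scaled_mm is unchanged, it is just stored column-major.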
