
Commit 7c05dec

import vllm cutlass fp8w8a8-block gemm
1 parent: 5ab69f2

File tree: 1 file changed (+19 −2 lines)


lightllm/common/quantization/vllm_quant.py

Lines changed: 19 additions & 2 deletions
```diff
@@ -3,6 +3,7 @@
 from .quantize_method import QuantizationMethod
 from .registry import QUANTMETHODS
 import torch.nn.functional as F
+from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
 
 try:
     HAS_VLLM = True
```
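The added import is lightllm's Triton kernel for per-token-group FP8 activation quantization. As a rough, non-Triton reference for what that operation computes, here is a minimal PyTorch sketch; the epsilon clamp and rounding behavior are assumptions, not taken from the kernel itself:

```python
import torch

def per_token_group_quant_fp8_ref(x, group_size, x_q, x_scale, eps=1e-10):
    """Reference sketch: split each row of x into contiguous groups of
    `group_size` elements and give every group its own FP8 scale."""
    m, k = x.shape
    fp8_max = torch.finfo(x_q.dtype).max  # 448.0 for float8_e4m3fn
    groups = x.view(m, k // group_size, group_size).to(torch.float32)
    # One scale per (token, group): amax / fp8_max, clamped away from zero
    amax = groups.abs().amax(dim=-1).clamp(min=eps)
    scale = amax / fp8_max                        # (m, k // group_size)
    q = (groups / scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
    x_q.copy_(q.view(m, k).to(x_q.dtype))
    x_scale.copy_(scale)
```

With `block_size = 128` this produces exactly the `(m, k // block_size)` float32 scale tensor that the block-wise branch below allocates from the cache manager.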
```diff
@@ -119,7 +120,23 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
     def apply_scaled_mm_fp8(
         self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True
     ):
-        x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
+        qweight, weight_scale, input_scale = weights
+        if len(weight_scale.shape) == 2:
+            # block-wise weight scales: quantize activations per token group
+            m, k = input_tensor.shape
+            if input_scale is None:
+                input_scale = self.cache_manager.alloc_tensor(
+                    (m, k // self.block_size), torch.float32, device=input_tensor.device, is_graph_out=False
+                )
+            qinput_tensor = self.cache_manager.alloc_tensor(
+                (m, k), qweight.dtype, device=qweight.device, is_graph_out=False
+            )
+            per_token_group_quant_fp8(input_tensor, self.block_size, qinput_tensor, input_scale)
+        else:
+            # per-tensor weight scale: dynamic per-token activation quantization
+            qinput_tensor, input_scale = ops.scaled_fp8_quant(
+                input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True
+            )
         m = input_tensor.shape[0]
         n = weights[0].shape[1]
         if out is None:
```
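The new branch keys off the dimensionality of the weight scale: a 2-D `weight_scale` signals block-quantized weights, anything else falls through to dynamic per-token quantization. A sketch of the shapes the two cases imply, with hypothetical sizes; the exact block-scale layout is an assumption, and only `qweight` being `(k, n)` follows from `n = weights[0].shape[1]` in the surrounding code:

```python
import torch

k, n, block_size = 4096, 2048, 128  # hypothetical sizes; block_size assumed 128

# Block-wise FP8 weights: a 2-D scale, one float32 entry per
# (block_size x block_size) tile of the (k, n) weight (layout assumed).
qweight      = torch.empty(k, n, dtype=torch.float8_e4m3fn, device="cuda")
weight_scale = torch.empty(k // block_size, n // block_size,
                           dtype=torch.float32, device="cuda")
weights_blockwise = (qweight, weight_scale, None)  # input_scale allocated lazily

# Per-tensor FP8 weights: a 0-/1-D scale, so len(weight_scale.shape) != 2
# and apply_scaled_mm_fp8 uses dynamic per-token activation quantization.
weights_per_tensor = (qweight, torch.tensor([1.0], device="cuda"), None)
```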
```diff
@@ -129,7 +146,7 @@ def apply_scaled_mm_fp8(
             )
         else:
             out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-        torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
+        torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
         return out
 
     def apply_pingpong_fp8(
```
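Both quantization paths converge on vLLM's registered `cutlass_scaled_mm` custom op. The call signature in the sketch below is taken from the diff; the shape and layout notes are assumptions (vLLM's cutlass FP8 kernels expect the weight operand in column-major storage), and it only runs with vLLM's C extension loaded:

```python
import torch

def scaled_mm_fp8_sketch(a_q, b_q, a_scales, b_scales, bias=None,
                         out_dtype=torch.bfloat16):
    """Sketch of the final dispatch into vLLM's cutlass_scaled_mm custom op.

    Assumed shapes/layouts (not stated in the diff):
      a_q:      (m, k) FP8 activations, row-major
      b_q:      (k, n) FP8 weights, column-major as vLLM's kernels expect
      a_scales: float32, per-token or (m, k // block_size) block-wise
      b_scales: float32, per-tensor scalar or 2-D block-wise
    """
    m, n = a_q.shape[0], b_q.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a_q.device)
    torch.ops._C.cutlass_scaled_mm(out, a_q, b_q, a_scales, b_scales, bias)
    return out
```

In the method itself, `out` may instead come from the custom tensor manager and the output dtype is `input_tensor.dtype`.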
