5 | 5 | import torch.nn.functional as F |
6 | 6 | from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 |
7 | 7 | from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul |
8 | | -from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops |
9 | | -from lightllm.utils.sgl_utils import HAS_SGL_KERNEL, sgl_ops |
| 8 | +from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm |
10 | 9 |
11 | 10 |
12 | 11 | class BaseQuantizationMethod(QuantizationMethod): |
13 | 12 | def __init__(self): |
14 | 13 | super().__init__() |
15 | | - assert HAS_VLLM and HAS_SGL_KERNEL, "vllm and sgl_kernel are not installed, you can't use quant api of them." |
| 14 | + assert HAS_VLLM, "vllm is not installed, you can't use its quant api."
16 | 15 | from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager |
17 | 16 |
18 | 17 | self.cache_manager = g_cache_manager |
@@ -59,7 +58,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ |
59 | 58 | ) |
60 | 59 | else: |
61 | 60 | out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) |
62 | | - torch.ops._C.cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) |
| 61 | + cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) |
63 | 62 | return out |
64 | 63 |
65 | 64 |
@@ -127,7 +126,7 @@ def apply_scaled_mm_fp8( |
127 | 126 | ) |
128 | 127 | else: |
129 | 128 | out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) |
130 | | - torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias) |
| 129 | + cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias) |
131 | 130 | return out |
132 | 131 |
133 | 132 | def apply_pingpong_fp8( |
@@ -195,5 +194,5 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ |
195 | 194 | ) |
196 | 195 | else: |
197 | 196 | input_scale = input_scale.t().contiguous().t() |
198 | | - torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias) |
| 197 | + cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias) |
199 | 198 | return out |
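
For context, this change routes every call through a `cutlass_scaled_mm` helper exported by `lightllm.utils.vllm_utils` instead of hitting vLLM's private `torch.ops._C.cutlass_scaled_mm` op directly, and drops the `sgl_kernel` requirement along the way. A minimal sketch of what such a wrapper could look like, assuming it simply forwards to the vLLM custom op (the real helper lives in `lightllm/utils/vllm_utils.py` and may differ):

```python
# Hypothetical sketch of the cutlass_scaled_mm helper this diff imports;
# the actual implementation in lightllm/utils/vllm_utils.py may differ.
import torch

try:
    import vllm  # importing vllm registers its _C custom ops with torch
    HAS_VLLM = True
except ImportError:
    HAS_VLLM = False


def cutlass_scaled_mm(out, x_q, weight_q, x_scale, weight_scale, bias=None):
    """Scaled GEMM via vLLM's CUTLASS kernel: writes
    (x_q * x_scale) @ (weight_q * weight_scale) [+ bias] into out."""
    assert HAS_VLLM, "vllm is not installed, you can't use its quant api."
    torch.ops._C.cutlass_scaled_mm(out, x_q, weight_q, x_scale, weight_scale, bias)
    return out
```

Funneling the `_C` access through one module means call sites no longer depend on vLLM's private op namespace, so a future vLLM-side rename only needs a fix in `vllm_utils`.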