Commit 0adcb16

Commit message: update
Parent: 31e7b56

File tree: 3 files changed (+9, -8 lines)
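In short: the commit routes every CUTLASS scaled-mm call in w8a8_quant.py through a cutlass_scaled_mm alias now exported by lightllm/utils/vllm_utils.py (instead of reaching into torch.ops._C at each call site), drops the sgl_kernel requirement from BaseQuantizationMethod, and flips the LIGHTLLM_USE_VLLM_CUSTOM_ALLREDUCE default from "1" to "0".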

lightllm/common/quantization/w8a8_quant.py

Lines changed: 5 additions & 6 deletions
@@ -5,14 +5,13 @@
 import torch.nn.functional as F
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
 from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul
-from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops
-from lightllm.utils.sgl_utils import HAS_SGL_KERNEL, sgl_ops
+from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm


 class BaseQuantizationMethod(QuantizationMethod):
     def __init__(self):
         super().__init__()
-        assert HAS_VLLM and HAS_SGL_KERNEL, "vllm and sgl_kernel are not installed, you can't use quant api of them."
+        assert HAS_VLLM, "vllm are not installed, you can't use quant api of them."
         from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

         self.cache_manager = g_cache_manager
@@ -59,7 +58,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
             )
         else:
             out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-            torch.ops._C.cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias)
+            cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias)
         return out

@@ -127,7 +126,7 @@ def apply_scaled_mm_fp8(
             )
         else:
             out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-            torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
+            cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
         return out

     def apply_pingpong_fp8(
@@ -195,5 +194,5 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
             )
         else:
             input_scale = input_scale.t().contiguous().t()
-            torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
+            cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
         return out
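Taken together, the four hunks above replace every direct torch.ops._C.cutlass_scaled_mm lookup with the module-level alias imported from vllm_utils. A minimal sketch of the resulting call pattern; the wrapper function scaled_mm_or_fallback is illustrative and not part of the commit:

from lightllm.utils.vllm_utils import HAS_VLLM, cutlass_scaled_mm

def scaled_mm_or_fallback(out, x_q, qweight, x_scale, weight_scale, bias=None):
    # When vllm imports cleanly, cutlass_scaled_mm is torch.ops._C.cutlass_scaled_mm;
    # otherwise vllm_utils sets it to None and the caller must take another path.
    if cutlass_scaled_mm is None:
        raise RuntimeError("vllm is not installed; no CUTLASS scaled-mm kernel available")
    cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias)
    return out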

lightllm/distributed/custom_all_reduce.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@

 logger = init_logger(__name__)

-use_vllm_custom_allreduce = os.getenv("LIGHTLLM_USE_VLLM_CUSTOM_ALLREDUCE", "1").upper() in ["ON", "TRUE", "1"]
+use_vllm_custom_allreduce = os.getenv("LIGHTLLM_USE_VLLM_CUSTOM_ALLREDUCE", "0").upper() in ["ON", "TRUE", "1"]
 if use_vllm_custom_allreduce:
     # Use vllm custom allreduce
     ops = vllm_ops
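This flips the default: the vLLM custom allreduce is now opt-in rather than opt-out. A small illustration of the same truthy parsing the module uses (the launcher context is hypothetical):

import os

# Must be set before lightllm.distributed.custom_all_reduce is imported,
# since the flag is read at module import time.
os.environ["LIGHTLLM_USE_VLLM_CUSTOM_ALLREDUCE"] = "1"  # "ON", "TRUE", and "1" all enable it

enabled = os.getenv("LIGHTLLM_USE_VLLM_CUSTOM_ALLREDUCE", "0").upper() in ["ON", "TRUE", "1"]
assert enabled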

lightllm/utils/vllm_utils.py

Lines changed: 3 additions & 1 deletion
@@ -1,3 +1,4 @@
+import torch
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)
@@ -6,9 +7,10 @@

     vllm_ops = ops
     HAS_VLLM = True
+    cutlass_scaled_mm = torch.ops._C.cutlass_scaled_mm
 except:
     HAS_VLLM = False
-    sgl_allreduce_ops = None
+    cutlass_scaled_mm = None
     logger.warning(
         "vllm is not installed, you can't use the api of it. \
         You can solve it by running `pip install vllm`."
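For orientation, the module after this commit looks roughly like the sketch below. It is reconstructed from the two hunks above; the "from vllm import _custom_ops as ops" line and the closing parenthesis of the warning are assumptions, since the diff does not show them.

import torch
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)

try:
    from vllm import _custom_ops as ops  # assumed; the import inside the try block is not shown in the diff

    vllm_ops = ops
    HAS_VLLM = True
    cutlass_scaled_mm = torch.ops._C.cutlass_scaled_mm
except:
    HAS_VLLM = False
    cutlass_scaled_mm = None
    logger.warning(
        "vllm is not installed, you can't use the api of it. \
        You can solve it by running `pip install vllm`."
    )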
