Commit 31e7b56

sgl_kernel && vllm ops
1 parent f790b30 commit 31e7b56

File tree

17 files changed: +117 -1204 lines


lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
 import triton.language as tl
 from typing import Any, Callable, Dict, Optional, Tuple
 from lightllm.utils.log_utils import init_logger
-from lightllm.common.vllm_kernel import _custom_ops as ops
+from lightllm.utils.vllm_utils import vllm_ops
 from lightllm.utils.device_utils import (
     get_device_sm_count,
     get_device_sm_regs_num,
@@ -446,7 +446,7 @@ def grouped_matmul(
    if use_fp8_w8a8:
        # When the weights use block-wise quantization, the activations also use per-token, group-size quantization
        if block_size_k == 0:
-            token_inputs, token_input_scale = ops.scaled_fp8_quant(token_inputs, token_input_scale)
+            token_inputs, token_input_scale = vllm_ops.scaled_fp8_quant(token_inputs, token_input_scale)
        else:
            _m, _k = token_inputs.shape
            assert _k % block_size_k == 0

lightllm/common/fused_moe/topk_select.py

Lines changed: 10 additions & 3 deletions
@@ -19,7 +19,8 @@
 
 import os
 import torch
-from lightllm.common.vllm_kernel import _custom_ops as ops
+from lightllm.utils.sgl_utils import sgl_ops
+from lightllm.utils.light_utils import light_ops
 from typing import Callable, List, Optional, Tuple
 
 use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"]
@@ -32,14 +33,18 @@ def fused_topk(
     renormalize: bool,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    assert (
+        sgl_ops is not None
+    ), "sgl_kernel is not installed, you can't use the cuda fused_topk. \
+        You can solve it by running `pip install sgl_kernel`."
 
     M, _ = hidden_states.shape
 
     topk_weights = torch.empty(M, topk, dtype=torch.float32, device=hidden_states.device)
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
     token_expert_indicies = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
 
-    ops.topk_softmax(
+    sgl_ops.topk_softmax(
         topk_weights,
         topk_ids,
         token_expert_indicies,
@@ -142,14 +147,16 @@ def cuda_grouped_topk(
 ):
 
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    assert light_ops is not None, "lightllm_kernel is not installed."
+
     num_tokens = gating_output.shape[0]
     topk_weights = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.float32)
     topk_indices = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.int32)
     token_expert_indices = torch.empty(num_tokens, topk_group, device=hidden_states.device, dtype=torch.int32)
     group_scores = torch.empty(num_tokens, num_expert_group, device=hidden_states.device, dtype=torch.float32)
     if correction_bias is None:
         correction_bias = torch.zeros_like(gating_output, dtype=torch.float32)
-    ops.grouped_topk(
+    light_ops.grouped_topk(
         topk_weights,
         correction_bias,
         topk_indices,
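
The wrapper modules introduced by these imports (lightllm.utils.sgl_utils, lightllm.utils.light_utils, and lightllm.utils.vllm_utils) are not part of the diff shown here. Judging from how they are used — `sgl_ops`/`light_ops` may be None, and w8a8_quant.py below reads HAS_VLLM / HAS_SGL_KERNEL flags — they presumably guard the optional kernel packages behind a try/except import. A minimal sketch of what such a wrapper might look like (the module body, the alias, and the warning message are assumptions, not taken from this commit):

# Hypothetical sketch of lightllm/utils/sgl_utils.py; the analogous pattern would back
# lightllm.utils.vllm_utils (HAS_VLLM, vllm_ops) and lightllm.utils.light_utils (light_ops).
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)

try:
    import sgl_kernel as sgl_ops  # optional dependency: `pip install sgl_kernel`

    HAS_SGL_KERNEL = True
except ImportError:
    sgl_ops = None  # call sites assert `sgl_ops is not None` before using it
    HAS_SGL_KERNEL = False
    logger.warning("sgl_kernel is not installed; CUDA fused_topk will be unavailable.")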

lightllm/common/quantization/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -1,9 +1,8 @@
 import yaml
 import collections
 from .registry import QUANTMETHODS
-from .ppl_quant import *
 from .torchao_quant import *
-from .vllm_quant import *
+from .w8a8_quant import *
 from .triton_quant.triton_quant import *
 from .deepgemm_quant import *
 from lightllm.utils.log_utils import init_logger

lightllm/common/quantization/ppl_quant.py

Lines changed: 0 additions & 90 deletions
This file was deleted.

lightllm/common/quantization/vllm_quant.py renamed to lightllm/common/quantization/w8a8_quant.py

Lines changed: 18 additions & 20 deletions
@@ -5,18 +5,14 @@
 import torch.nn.functional as F
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
 from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul
+from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops
+from lightllm.utils.sgl_utils import HAS_SGL_KERNEL, sgl_ops
 
-try:
-    HAS_VLLM = True
-    from lightllm.common.vllm_kernel import _custom_ops as ops
-except:
-    HAS_VLLM = False
 
-
-class vLLMBaseQuantizationMethod(QuantizationMethod):
+class BaseQuantizationMethod(QuantizationMethod):
     def __init__(self):
         super().__init__()
-        assert HAS_VLLM, "vllm is not installed, you can't use quant api of it"
+        assert HAS_VLLM and HAS_SGL_KERNEL, "vllm and sgl_kernel are not installed, you can't use quant api of them."
         from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 
         self.cache_manager = g_cache_manager
@@ -30,8 +26,8 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
         pass
 
 
-@QUANTMETHODS.register(["vllm-w8a8"])
-class vLLMw8a8QuantizationMethod(vLLMBaseQuantizationMethod):
+@QUANTMETHODS.register(["vllm-w8a8", "w8a8"])
+class w8a8QuantizationMethod(BaseQuantizationMethod):
     def __init__(self):
         super().__init__()
 
@@ -53,7 +49,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
         else:
             raise ValueError("vllm-quant Weights must be a tuple of length 2 or 3.")
 
-        x_q, x_scale, x_zp = ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True)
+        x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True)
         m = input_tensor.shape[0]
         n = qweight.shape[1]
         if out is None:
@@ -67,8 +63,8 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
         return out
 
 
-@QUANTMETHODS.register(["vllm-fp8w8a8"])
-class vLLMFP8w8a8QuantizationMethod(vLLMBaseQuantizationMethod):
+@QUANTMETHODS.register(["vllm-fp8w8a8", "fp8w8a8"])
+class FP8w8a8QuantizationMethod(BaseQuantizationMethod):
     def __init__(self):
         super().__init__()
         self.is_moe = False
@@ -88,15 +84,15 @@ def quantize(self, weight: torch.Tensor):
     def quantize_scaled_mm_fp8(self, weight: torch.Tensor):
         if self.is_moe:
             return self.quantize_moe(weight)
-        qweight, weight_scale = ops.scaled_fp8_quant(
+        qweight, weight_scale = vllm_ops.scaled_fp8_quant(
             weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
         )
         return qweight.transpose(0, 1), weight_scale
 
     def quantize_pingpong_fp8(self, weight: torch.Tensor):
         if self.is_moe:
             return self.quantize_moe(weight)
-        qweight, weight_scale = ops.scaled_fp8_quant(
+        qweight, weight_scale = vllm_ops.scaled_fp8_quant(
             weight.contiguous().cuda(), scale=None, use_per_token_if_dynamic=False
         )
         return qweight.transpose(0, 1), weight_scale
@@ -107,7 +103,7 @@ def quantize_moe(self, weight):
         weight_scales = []
         qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_)
         for i in range(num_experts):
-            qweight, weight_scale = ops.scaled_fp8_quant(
+            qweight, weight_scale = vllm_ops.scaled_fp8_quant(
                 weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False
             )
             qweights[i] = qweight
@@ -121,7 +117,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
     def apply_scaled_mm_fp8(
         self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True
     ):
-        x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
+        x_q, x_scale = vllm_ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
         m = input_tensor.shape[0]
         n = weights[0].shape[1]
         if out is None:
@@ -137,7 +133,9 @@ def apply_scaled_mm_fp8(
     def apply_pingpong_fp8(
         self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True
     ):
-        x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=False)
+        x_q, x_scale = vllm_ops.scaled_fp8_quant(
+            input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=False
+        )
         assert bias is None
         m = input_tensor.shape[0]
         n = weights[0].shape[1]
@@ -154,8 +152,8 @@ def apply_pingpong_fp8(
         return cutlass_scaled_mm(x_q, weights[0], x_scale, weights[1], out)
 
 
-@QUANTMETHODS.register(["vllm-fp8w8a8-b128"])
-class vLLMFP8w8a8B128QuantizationMethod(vLLMBaseQuantizationMethod):
+@QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"])
+class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod):
     def __init__(self):
         super().__init__()
         self.block_size = 128
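
The renamed FP8 methods above keep quantizing activations dynamically through the vllm_ops wrapper: per token for the scaled_mm path and per tensor for the pingpong path. A small usage sketch mirroring those two call sites, assuming vllm_ops re-exports vLLM's scaled_fp8_quant as the diff implies (shapes are illustrative only):

import torch
from lightllm.utils.vllm_utils import vllm_ops  # wrapper assumed to expose vLLM's scaled_fp8_quant

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Per-token dynamic scales, as in apply_scaled_mm_fp8 / quantize_scaled_mm_fp8:
x_q, x_scale = vllm_ops.scaled_fp8_quant(x, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
# x_q is a float8_e4m3fn tensor of the same shape; x_scale should hold one scale per token row.

# A single per-tensor dynamic scale, as in apply_pingpong_fp8 / quantize_pingpong_fp8:
x_q2, x_scale2 = vllm_ops.scaled_fp8_quant(x, scale=None, scale_ub=None, use_per_token_if_dynamic=False)
# x_scale2 should be a single scale covering the whole tensor.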

lightllm/common/vllm_kernel/__init__.py

Whitespace-only changes.

lightllm/common/vllm_kernel/_custom_ops.py

Lines changed: 0 additions & 30 deletions
This file was deleted.
