Commit 7a18054

sgl_kernel && vllm ops (#890)
Co-authored-by: baishihao <[email protected]>
Co-authored-by: hiworldwzj <[email protected]>
Co-authored-by: wangzaijun <[email protected]>
1 parent f790b30 commit 7a18054

22 files changed: +148, -1276 lines changed

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 2 additions & 2 deletions

@@ -23,7 +23,7 @@
 import triton.language as tl
 from typing import Any, Callable, Dict, Optional, Tuple
 from lightllm.utils.log_utils import init_logger
-from lightllm.common.vllm_kernel import _custom_ops as ops
+from lightllm.utils.vllm_utils import vllm_ops
 from lightllm.utils.device_utils import (
     get_device_sm_count,
     get_device_sm_regs_num,

@@ -446,7 +446,7 @@ def grouped_matmul(
     if use_fp8_w8a8:
         # When the weights use block-wise quantization, the activations also use per-token, group-size quantization
         if block_size_k == 0:
-            token_inputs, token_input_scale = ops.scaled_fp8_quant(token_inputs, token_input_scale)
+            token_inputs, token_input_scale = vllm_ops.scaled_fp8_quant(token_inputs, token_input_scale)
         else:
             _m, _k = token_inputs.shape
             assert _k % block_size_k == 0
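
The new import, `from lightllm.utils.vllm_utils import vllm_ops`, replaces the in-tree `lightllm.common.vllm_kernel._custom_ops` shim that this commit deletes (see further down). The wrapper module itself is not among the hunks shown in this excerpt, so the following is only a minimal sketch of the pattern it presumably implements: guard the optional vllm dependency once and let call sites import a single name. The module path comes from the import above; the body is an assumption.

```python
# Hypothetical sketch of lightllm/utils/vllm_utils.py (not shown in this diff).
# Assumption: it wraps vllm's optional custom ops behind one import so callers
# can write `from lightllm.utils.vllm_utils import vllm_ops` everywhere.
import torch

try:
    from vllm import _custom_ops as vllm_ops  # scaled_fp8_quant, scaled_int8_quant, ...

    HAS_VLLM = True
except ImportError:
    vllm_ops = None
    HAS_VLLM = False


def cutlass_scaled_mm(out, x_q, w_q, x_scale, w_scale, bias=None):
    # Assumption: thin wrapper around the op that vllm_quant.py used to call
    # directly as torch.ops._C.cutlass_scaled_mm(...) before this commit.
    assert HAS_VLLM, "vllm is not installed"
    torch.ops._C.cutlass_scaled_mm(out, x_q, w_q, x_scale, w_scale, bias)
    return out
```

`HAS_VLLM` and `cutlass_scaled_mm` are also imported from this module in the w8a8_quant.py hunks below, which is why they appear in the sketch.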

lightllm/common/fused_moe/topk_select.py

Lines changed: 10 additions & 3 deletions

@@ -19,7 +19,8 @@

 import os
 import torch
-from lightllm.common.vllm_kernel import _custom_ops as ops
+from lightllm.utils.sgl_utils import sgl_ops
+from lightllm.utils.light_utils import light_ops
 from typing import Callable, List, Optional, Tuple

 use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"]

@@ -32,14 +33,18 @@ def fused_topk(
     renormalize: bool,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    assert (
+        sgl_ops is not None
+    ), "sgl_kernel is not installed, you can't use the cuda fused_topk. \
+        You can solve it by running `pip install sgl_kernel`."

     M, _ = hidden_states.shape

     topk_weights = torch.empty(M, topk, dtype=torch.float32, device=hidden_states.device)
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
     token_expert_indicies = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)

-    ops.topk_softmax(
+    sgl_ops.topk_softmax(
         topk_weights,
         topk_ids,
         token_expert_indicies,

@@ -142,14 +147,16 @@ def cuda_grouped_topk(
 ):

     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    assert light_ops is not None, "lightllm_kernel is not installed."
+
     num_tokens = gating_output.shape[0]
     topk_weights = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.float32)
     topk_indices = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.int32)
     token_expert_indices = torch.empty(num_tokens, topk_group, device=hidden_states.device, dtype=torch.int32)
     group_scores = torch.empty(num_tokens, num_expert_group, device=hidden_states.device, dtype=torch.float32)
     if correction_bias is None:
         correction_bias = torch.zeros_like(gating_output, dtype=torch.float32)
-    ops.grouped_topk(
+    light_ops.grouped_topk(
         topk_weights,
         correction_bias,
         topk_indices,
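
Neither `lightllm.utils.sgl_utils` nor `lightllm.utils.light_utils` appears in the hunks shown in this excerpt; only their exported names (`sgl_ops`, `light_ops`, and later `HAS_SGL_KERNEL`) do. Judging from how they are used above, each is presumably a small optional-import guard around an external kernel package. The sketch below is an assumption about their contents; only the module paths and the `sgl_kernel` / `lightllm_kernel` package names are taken from the diff itself.

```python
# Hypothetical sketches -- neither file is shown in this excerpt.

# lightllm/utils/sgl_utils.py: guard the optional sgl_kernel package.
try:
    import sgl_kernel as sgl_ops  # used above for topk_softmax, later for fp8 group quant

    HAS_SGL_KERNEL = True
except ImportError:
    sgl_ops = None
    HAS_SGL_KERNEL = False

# lightllm/utils/light_utils.py: guard the optional lightllm_kernel package.
try:
    import lightllm_kernel as light_ops  # used above for grouped_topk
except ImportError:
    light_ops = None
```

Call sites then assert `sgl_ops is not None` / `light_ops is not None` before dispatching, as `fused_topk` and `cuda_grouped_topk` do above.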

lightllm/common/quantization/__init__.py

Lines changed: 1 addition & 2 deletions

@@ -1,9 +1,8 @@
 import yaml
 import collections
 from .registry import QUANTMETHODS
-from .ppl_quant import *
 from .torchao_quant import *
-from .vllm_quant import *
+from .w8a8_quant import *
 from .triton_quant.triton_quant import *
 from .deepgemm_quant import *
 from lightllm.utils.log_utils import init_logger

lightllm/common/quantization/ppl_quant.py

Lines changed: 0 additions & 90 deletions
This file was deleted.

lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py

Lines changed: 3 additions & 8 deletions

@@ -3,16 +3,11 @@
 import triton.language as tl

 from lightllm.common.kernel_config import KernelConfigs
+from lightllm.utils.sgl_utils import HAS_SGL_KERNEL, sgl_ops
 from frozendict import frozendict
 from functools import lru_cache
 from typing import Any, Dict, List, Optional, Tuple

-try:
-    HAS_SGLANG_KERNEL = True
-    from sgl_kernel import sgl_per_token_group_quant_fp8
-except:
-    HAS_SGLANG_KERNEL = False
-
 try:
     from deep_gemm import ceil_div
 except:

@@ -118,10 +113,10 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     dtype: torch.dtype = torch.float8_e4m3fn,
 ):
-    if HAS_SGLANG_KERNEL:
+    if HAS_SGL_KERNEL:
         finfo = torch.finfo(dtype)
         fp8_max, fp8_min = finfo.max, finfo.min
-        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max)
+        sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max)
     else:
         lightllm_per_token_group_quant_fp8(x, group_size, x_q, x_s, eps=1e-10, dtype=torch.float8_e4m3fn)
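
For context on what this dispatch selects between: per-token group FP8 quantization splits each row of activations into groups of `group_size` elements and gives every group its own scale, so `x_s` holds one entry per (token, group). The snippet below is a slow pure-PyTorch reference of that max-abs scaling, written only as a sketch of the intended semantics; it is not the sgl_kernel op or the Triton kernel, and the helper name is made up for illustration.

```python
import torch


def per_token_group_quant_fp8_reference(
    x: torch.Tensor, group_size: int, eps: float = 1e-10, dtype=torch.float8_e4m3fn
):
    """Reference semantics only: one FP8 scale per (token, group).
    The real kernels write into preallocated x_q / x_s buffers instead."""
    finfo = torch.finfo(dtype)
    m, k = x.shape
    assert k % group_size == 0
    groups = x.float().reshape(m, k // group_size, group_size)
    # Max-abs per group, clamped by eps, mapped onto the FP8 dynamic range.
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    x_s = amax / finfo.max
    x_q = (groups / x_s).clamp(finfo.min, finfo.max).to(dtype).reshape(m, k)
    return x_q, x_s.squeeze(-1)  # x_s: (m, k // group_size)
```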

lightllm/common/quantization/vllm_quant.py renamed to lightllm/common/quantization/w8a8_quant.py

Lines changed: 16 additions & 65 deletions

@@ -5,18 +5,13 @@
 import torch.nn.functional as F
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
 from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul
+from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm

-try:
-    HAS_VLLM = True
-    from lightllm.common.vllm_kernel import _custom_ops as ops
-except:
-    HAS_VLLM = False

-
-class vLLMBaseQuantizationMethod(QuantizationMethod):
+class BaseQuantizationMethod(QuantizationMethod):
     def __init__(self):
         super().__init__()
-        assert HAS_VLLM, "vllm is not installed, you can't use quant api of it"
+        assert HAS_VLLM, "vllm are not installed, you can't use quant api of them."
         from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

         self.cache_manager = g_cache_manager

@@ -30,8 +25,8 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
         pass


-@QUANTMETHODS.register(["vllm-w8a8"])
-class vLLMw8a8QuantizationMethod(vLLMBaseQuantizationMethod):
+@QUANTMETHODS.register(["vllm-w8a8", "w8a8"])
+class w8a8QuantizationMethod(BaseQuantizationMethod):
     def __init__(self):
         super().__init__()

@@ -53,7 +48,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
         else:
             raise ValueError("vllm-quant Weights must be a tuple of length 2 or 3.")

-        x_q, x_scale, x_zp = ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True)
+        x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True)
         m = input_tensor.shape[0]
         n = qweight.shape[1]
         if out is None:

@@ -63,51 +58,31 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
             )
         else:
             out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-        torch.ops._C.cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias)
+        cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias)
         return out


-@QUANTMETHODS.register(["vllm-fp8w8a8"])
-class vLLMFP8w8a8QuantizationMethod(vLLMBaseQuantizationMethod):
+@QUANTMETHODS.register(["vllm-fp8w8a8", "fp8w8a8"])
+class FP8w8a8QuantizationMethod(BaseQuantizationMethod):
     def __init__(self):
         super().__init__()
         self.is_moe = False
-        # PINGPONG_FP8_GEMM is per tensor quant way.
-        self.use_pingpong_fp8_gemm = os.getenv("ENABLE_PINGPONG_FP8_GEMM", "0").upper() in ["ON", "TRUE", "1"]
-
-        if self.use_pingpong_fp8_gemm:
-            self.quantize = self.quantize_pingpong_fp8
-            self.apply = self.apply_pingpong_fp8
-        else:
-            self.quantize = self.quantize_scaled_mm_fp8
-            self.apply = self.apply_scaled_mm_fp8

     def quantize(self, weight: torch.Tensor):
-        raise Exception("This function needs to be bound.")
-
-    def quantize_scaled_mm_fp8(self, weight: torch.Tensor):
         if self.is_moe:
             return self.quantize_moe(weight)
-        qweight, weight_scale = ops.scaled_fp8_quant(
+        qweight, weight_scale = vllm_ops.scaled_fp8_quant(
            weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
         )
         return qweight.transpose(0, 1), weight_scale

-    def quantize_pingpong_fp8(self, weight: torch.Tensor):
-        if self.is_moe:
-            return self.quantize_moe(weight)
-        qweight, weight_scale = ops.scaled_fp8_quant(
-            weight.contiguous().cuda(), scale=None, use_per_token_if_dynamic=False
-        )
-        return qweight.transpose(0, 1), weight_scale
-
     def quantize_moe(self, weight):
         num_experts = weight.shape[0]
         qweights = []
         weight_scales = []
         qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_)
         for i in range(num_experts):
-            qweight, weight_scale = ops.scaled_fp8_quant(
+            qweight, weight_scale = vllm_ops.scaled_fp8_quant(
                 weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False
             )
             qweights[i] = qweight

@@ -116,12 +91,7 @@ def quantize_moe(self, weight):
         return qweights, weight_scale

     def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
-        raise Exception("This function needs to be bound.")
-
-    def apply_scaled_mm_fp8(
-        self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True
-    ):
-        x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
+        x_q, x_scale = vllm_ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
         m = input_tensor.shape[0]
         n = weights[0].shape[1]
         if out is None:

@@ -131,31 +101,12 @@ def apply_scaled_mm_fp8(
             )
         else:
             out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-        torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
+        cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
         return out

-    def apply_pingpong_fp8(
-        self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True
-    ):
-        x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=False)
-        assert bias is None
-        m = input_tensor.shape[0]
-        n = weights[0].shape[1]
-        if out is None:
-            if use_custom_tensor_mananger:
-                out = self.cache_manager.alloc_tensor(
-                    (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
-                )
-            else:
-                out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-
-        from fp8_pingpong_gemm import cutlass_scaled_mm
-
-        return cutlass_scaled_mm(x_q, weights[0], x_scale, weights[1], out)
-

-@QUANTMETHODS.register(["vllm-fp8w8a8-b128"])
-class vLLMFP8w8a8B128QuantizationMethod(vLLMBaseQuantizationMethod):
+@QUANTMETHODS.register(["vllm-fp8w8a8-b128, fp8w8a8-b128"])
+class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod):
     def __init__(self):
         super().__init__()
         self.block_size = 128

@@ -197,5 +148,5 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
             )
         else:
             input_scale = input_scale.t().contiguous().t()
-        torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
+        cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
         return out
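
Every GEMM in this file now goes through the single `cutlass_scaled_mm` helper imported from `lightllm.utils.vllm_utils` instead of calling `torch.ops._C.cutlass_scaled_mm` directly. The contract the call sites rely on is: quantized activations with per-token scales, quantized weights with per-output-channel scales, optional bias, and the dequantized result written into `out`. A slow PyTorch reference of that assumed contract is sketched below; the reference name is made up, and the fp8w8a8-b128 path passes block-wise scales, which this simple per-token/per-channel version does not model.

```python
import torch


def cutlass_scaled_mm_reference(out, x_q, w_q, x_scale, w_scale, bias=None):
    """Assumed semantics: out = (x_scale * x_q) @ (w_q * w_scale) + bias.
    x_q: (m, k) int8/fp8 activations, x_scale broadcastable to (m, 1);
    w_q: (k, n) quantized weights, w_scale broadcastable to (1, n)."""
    a = x_q.to(torch.float32) * x_scale.to(torch.float32).reshape(-1, 1)
    b = w_q.to(torch.float32) * w_scale.to(torch.float32).reshape(1, -1)
    y = a @ b
    if bias is not None:
        y = y + bias.to(torch.float32)
    out.copy_(y.to(out.dtype))
    return out
```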

lightllm/common/vllm_kernel/__init__.py

Whitespace-only changes.

lightllm/common/vllm_kernel/_custom_ops.py

Lines changed: 0 additions & 30 deletions
This file was deleted.
