
Commit ceb4bc1

Commit message: update
1 parent 0adcb16 commit ceb4bc1

6 files changed: +18 additions, -66 deletions


lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py

Lines changed: 3 additions & 8 deletions
@@ -3,16 +3,11 @@
 import triton.language as tl
 
 from lightllm.common.kernel_config import KernelConfigs
+from lightllm.utils.sgl_utils import HAS_SGL_KERNEL, sgl_ops
 from frozendict import frozendict
 from functools import lru_cache
 from typing import Any, Dict, List, Optional, Tuple
 
-try:
-    HAS_SGLANG_KERNEL = True
-    from sgl_kernel import sgl_per_token_group_quant_fp8
-except:
-    HAS_SGLANG_KERNEL = False
-
 try:
     from deep_gemm import ceil_div
 except:
@@ -118,10 +113,10 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     dtype: torch.dtype = torch.float8_e4m3fn,
 ):
-    if HAS_SGLANG_KERNEL:
+    if HAS_SGL_KERNEL:
         finfo = torch.finfo(dtype)
         fp8_max, fp8_min = finfo.max, finfo.min
-        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max)
+        sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max)
     else:
         lightllm_per_token_group_quant_fp8(x, group_size, x_q, x_s, eps=1e-10, dtype=torch.float8_e4m3fn)
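
For reference, both branches of per_token_group_quant_fp8 compute the same thing: for each token and each contiguous group of group_size channels, scale by the group's absolute max so the values fit the fp8 e4m3 range, and keep one float32 scale per (token, group). The sketch below is a plain-PyTorch restatement under an assumed (num_tokens, hidden) input shape, not the Triton or sgl_kernel implementation; the helper name per_token_group_quant_fp8_ref is made up for illustration.

import torch

def per_token_group_quant_fp8_ref(x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn):
    # x: (num_tokens, hidden), hidden divisible by group_size.
    finfo = torch.finfo(dtype)
    m, n = x.shape
    xg = x.float().reshape(m, n // group_size, group_size)
    # One scale per (token, group): amax / fp8_max, floored at eps to avoid divide-by-zero.
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    x_s = amax / finfo.max
    x_q = (xg / x_s).clamp(finfo.min, finfo.max).to(dtype).reshape(m, n)
    return x_q, x_s.squeeze(-1)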

lightllm/common/quantization/w8a8_quant.py

Lines changed: 0 additions & 46 deletions
@@ -67,35 +67,15 @@ class FP8w8a8QuantizationMethod(BaseQuantizationMethod):
     def __init__(self):
         super().__init__()
         self.is_moe = False
-        # PINGPONG_FP8_GEMM is per tensor quant way.
-        self.use_pingpong_fp8_gemm = os.getenv("ENABLE_PINGPONG_FP8_GEMM", "0").upper() in ["ON", "TRUE", "1"]
-
-        if self.use_pingpong_fp8_gemm:
-            self.quantize = self.quantize_pingpong_fp8
-            self.apply = self.apply_pingpong_fp8
-        else:
-            self.quantize = self.quantize_scaled_mm_fp8
-            self.apply = self.apply_scaled_mm_fp8
 
     def quantize(self, weight: torch.Tensor):
-        raise Exception("This function needs to be bound.")
-
-    def quantize_scaled_mm_fp8(self, weight: torch.Tensor):
         if self.is_moe:
             return self.quantize_moe(weight)
         qweight, weight_scale = vllm_ops.scaled_fp8_quant(
             weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
         )
         return qweight.transpose(0, 1), weight_scale
 
-    def quantize_pingpong_fp8(self, weight: torch.Tensor):
-        if self.is_moe:
-            return self.quantize_moe(weight)
-        qweight, weight_scale = vllm_ops.scaled_fp8_quant(
-            weight.contiguous().cuda(), scale=None, use_per_token_if_dynamic=False
-        )
-        return qweight.transpose(0, 1), weight_scale
-
     def quantize_moe(self, weight):
         num_experts = weight.shape[0]
         qweights = []
@@ -111,11 +91,6 @@ def quantize_moe(self, weight):
         return qweights, weight_scale
 
     def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
-        raise Exception("This function needs to be bound.")
-
-    def apply_scaled_mm_fp8(
-        self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True
-    ):
         x_q, x_scale = vllm_ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
         m = input_tensor.shape[0]
         n = weights[0].shape[1]
@@ -129,27 +104,6 @@ def apply_scaled_mm_fp8(
         cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
         return out
 
-    def apply_pingpong_fp8(
-        self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True
-    ):
-        x_q, x_scale = vllm_ops.scaled_fp8_quant(
-            input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=False
-        )
-        assert bias is None
-        m = input_tensor.shape[0]
-        n = weights[0].shape[1]
-        if out is None:
-            if use_custom_tensor_mananger:
-                out = self.cache_manager.alloc_tensor(
-                    (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
-                )
-            else:
-                out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-
-        from fp8_pingpong_gemm import cutlass_scaled_mm
-
-        return cutlass_scaled_mm(x_q, weights[0], x_scale, weights[1], out)
-
 
 @QUANTMETHODS.register(["vllm-fp8w8a8-b128, fp8w8a8-b128"])
 class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod):
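
With the pingpong per-tensor path removed, quantize and apply are the per-token scaled_mm implementations directly, so the ENABLE_PINGPONG_FP8_GEMM switch and the runtime method binding disappear. For reference, the surviving path does dynamic per-token FP8 quantization of the activation, an fp8 matmul against the pre-quantized weight, and rescaling of the fp32 accumulator by the per-token and per-output-channel scales. The sketch below is a plain-PyTorch restatement of that math, not the cutlass kernel; scaled_mm_fp8_ref, the (K, N) weight layout, and the (N,) weight-scale shape are assumptions for illustration.

import torch

def scaled_mm_fp8_ref(x, w_q, w_scale, bias=None, dtype=torch.float8_e4m3fn):
    # x: (M, K) activations; w_q: (K, N) fp8 weight; w_scale: (N,) per-output-channel scales.
    finfo = torch.finfo(dtype)
    # Dynamic per-token activation quantization.
    x_scale = x.float().abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) / finfo.max
    x_q = (x.float() / x_scale).clamp(finfo.min, finfo.max).to(dtype)
    # Accumulate in fp32, then undo both scalings; out approximates x @ W.
    out = (x_q.float() @ w_q.float()) * x_scale * w_scale
    if bias is not None:
        out = out + bias
    return out.to(x.dtype)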

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 6 deletions
@@ -29,15 +29,10 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.dist_utils import get_global_world_size
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2
 
 logger = init_logger(__name__)
 
-try:
-    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
-    from sgl_kernel import merge_state_v2
-except:
-    logger.warning("sgl_kernel is not installed, or the installed version does not support fa3!")
-
 
 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
     def __init__(self, layer_num, network_config, mode=[]):
@@ -311,6 +306,7 @@ def _context_attention_flashattention_kernel_with_CC(
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
+        assert flash_attn_varlen_func is not None, "fa3 is not available. It requires sm90 and above."
         k_nope, k_rope, v = self._decompress_kv(
             kv,
             infer_state,

lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 4 deletions
@@ -29,10 +29,7 @@
 
 logger = init_logger(__name__)
 
-try:
-    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
-except:
-    logger.warning("sgl_kernel is not installed, or the installed version does not support fa3!")
+from lightllm.utils.sgl_utils import flash_attn_with_kvcache
 
 
 class LlamaTransformerLayerInfer(TransformerLayerInferTpl):
@@ -252,6 +249,7 @@ def _context_attention_kernel_ppl_int8kv(
         return o_tensor
 
     def _context_attention_flashattention(self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None):
+        assert flash_attn_with_kvcache is not None, "fa3 is not available. It requires sm90 and above."
         cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
             -1, 1, self.tp_k_head_num_, self.head_dim_
         )

lightllm/utils/sgl_utils.py

Lines changed: 10 additions & 1 deletion
@@ -3,9 +3,18 @@
 logger = init_logger(__name__)
 try:
     import sgl_kernel
-    import sgl_kernel.allreduce as sgl_allreduce_ops
 
     sgl_ops = sgl_kernel
+    sgl_allreduce_ops = sgl_ops.allreduce
+    if sgl_ops.flash_attn.is_fa3_supported():
+        flash_attn_varlen_func = sgl_ops.flash_attn.flash_attn_varlen_func
+        flash_attn_with_kvcache = sgl_ops.flash_attn.flash_attn_with_kvcache
+        merge_state_v2 = sgl_ops.flash_attn.merge_state_v2
+    else:
+        flash_attn_varlen_func = None
+        flash_attn_with_kvcache = None
+        merge_state_v2 = None
+        logger.warning("Fa3 is only supported on sm90 and above.")
     HAS_SGL_KERNEL = True
 except:
     HAS_SGL_KERNEL = False
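
The net effect across the deepseek2 and llama layers above is one shared pattern: optional sgl_kernel symbols are resolved once in this module and exported either as callables or as None, and call sites assert right before use instead of repeating try/except imports. A minimal self-contained sketch of that pattern, with a made-up consumer function name:

from typing import Callable, Optional

flash_attn_with_kvcache: Optional[Callable] = None
try:
    import sgl_kernel  # optional dependency

    if sgl_kernel.flash_attn.is_fa3_supported():
        flash_attn_with_kvcache = sgl_kernel.flash_attn.flash_attn_with_kvcache
    HAS_SGL_KERNEL = True
except Exception:
    HAS_SGL_KERNEL = False

# Consumer side (e.g. an attention layer): fail loudly at call time, not at import time.
def context_attention_flashattention(*args, **kwargs):
    assert flash_attn_with_kvcache is not None, "fa3 is not available. It requires sm90 and above."
    return flash_attn_with_kvcache(*args, **kwargs)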

test/model/model_infer.py

Lines changed: 1 addition & 1 deletion
@@ -362,7 +362,7 @@ def tppart_model_infer(args, model_kvargs, batch_size, input_len, output_len, an
                 total_token_num,
                 b_ready_cache_len,
             ),
-            log_dir=f"./logs_sglang_4k/forward_prefill_{model_kvargs['rank_id']}",
+            log_dir=f"./logs/forward_prefill_{model_kvargs['rank_id']}",
         )
     else:
         torch_profile(
