
Commit fba788d

update fp8 quant

1 parent 7f92a33 commit fba788d

8 files changed: +68 -85 lines changed

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 5 additions & 5 deletions

@@ -4,7 +4,7 @@
 from typing import Optional, Tuple, List, Dict, Any
 from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id
 from .base_weight import BaseWeight
-from lightllm.common.fused_moe.grouped_fused_moe_ep import fused_experts_impl, masked_group_gemm, tma_aligned_quantize
+from lightllm.common.fused_moe.grouped_fused_moe_ep import fused_experts_impl, masked_group_gemm
 from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
 from lightllm.distributed import dist_group_manager
 from lightllm.common.fused_moe.topk_select import select_experts

@@ -228,9 +228,7 @@ def select_experts_and_quant_input(
         if w1.ndim == 3:
             block_size_k = w1.shape[2] // w1_scale.shape[2]
             assert block_size_k == 128, "block_size_k must be 128"
-            input_scale = torch.empty((M, K // block_size_k), dtype=torch.float32, device=hidden_states.device)
-            qinput_tensor = torch.empty((M, K), dtype=w1.dtype, device=hidden_states.device)
-            per_token_group_quant_fp8(hidden_states, block_size_k, qinput_tensor, input_scale)
+            input_scale, qinput_tensor = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w1.dtype)
         return topk_weights, topk_idx.to(torch.long), (qinput_tensor, input_scale)

     def dispatch(

@@ -340,7 +338,9 @@ def prefilled_group_gemm(
         silu_out = torch.empty((all_tokens, N // 2), device=device, dtype=hidden_dtype)

         silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out)
-        qsilu_out, qsilu_out_scale = tma_aligned_quantize(silu_out)
+        qsilu_out, qsilu_out_scale = per_token_group_quant_fp8(
+            silu_out, self.block_size, dtype=w1.dtype, column_major_scales=True, scale_tma_aligned=True
+        )

         # groupgemm (contiguous layout)
         gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype)
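
For context, a minimal usage sketch of the new call pattern: per_token_group_quant_fp8 now allocates its own outputs and returns the FP8 tensor together with the float32 per-group scales (see fp8act_quant_kernel.py below). Shapes here are illustrative and a CUDA device is assumed.

    import torch
    from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8

    # quantize a bf16 activation in groups of 128 along the hidden dimension
    hidden_states = torch.randn((16, 1024), dtype=torch.bfloat16, device="cuda")
    x_q, x_s = per_token_group_quant_fp8(hidden_states, 128, dtype=torch.float8_e4m3fn)
    print(x_q.shape, x_q.dtype)  # (16, 1024), torch.float8_e4m3fn
    print(x_s.shape, x_s.dtype)  # (16, 8), torch.float32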

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 11 additions & 8 deletions

@@ -526,10 +526,9 @@ def grouped_matmul(
     else:
         _m, _k = token_inputs.shape
         assert _k % block_size_k == 0
-        input_scale = alloc_tensor_func((_m, _k // block_size_k), dtype=torch.float32, device=token_inputs.device)
-        qinput_tensor = alloc_tensor_func((_m, _k), dtype=expert_weights.dtype, device=token_inputs.device)
-        per_token_group_quant_fp8(token_inputs, block_size_k, qinput_tensor, input_scale)
-        token_inputs, token_input_scale = qinput_tensor, input_scale
+        token_inputs, token_input_scale = per_token_group_quant_fp8(
+            token_inputs, block_size_k, dtype=expert_weights.dtype
+        )

     if reused_mblock_infos is None:
         mblocks_to_expert_id, mblocks_to_m_index = moe_align2(token_num_mul_topk_num, expert_to_token_num, BLOCK_SIZE_M)

@@ -627,13 +626,17 @@ def fused_experts_impl(
     CHUNK_SIZE = FFN_MOE_CHUNK_SIZE
     topk_num = topk_ids.shape[1]
     M = min(num_tokens, CHUNK_SIZE)
-
-    intermediate_cache13_shared = alloc_tensor_func((M, topk_num, max(N, w2.shape[1])), device=hidden_states.device, dtype=hidden_states.dtype)
-    intermediate_cache1 = intermediate_cache13_shared.view(-1)[:(M * topk_num * N)].view(M, topk_num, N)
+
+    intermediate_cache13_shared = alloc_tensor_func(
+        (M, topk_num, max(N, w2.shape[1])), device=hidden_states.device, dtype=hidden_states.dtype
+    )
+    intermediate_cache1 = intermediate_cache13_shared.view(-1)[: (M * topk_num * N)].view(M, topk_num, N)
     intermediate_cache2 = alloc_tensor_func(
         (M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
     )
-    intermediate_cache3 = intermediate_cache13_shared.view(-1)[:(M * topk_num * w2.shape[1])].view(M, topk_num, w2.shape[1])
+    intermediate_cache3 = intermediate_cache13_shared.view(-1)[: (M * topk_num * w2.shape[1])].view(
+        M, topk_num, w2.shape[1]
+    )

     if inplace:
         out_hidden_states = hidden_states
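
Aside from the call-site change, the reflowed allocation above keeps the trick of backing intermediate_cache1 and intermediate_cache3 with one shared buffer sized to the wider of the two. A minimal, self-contained sketch of that aliasing (illustrative sizes, not the kernel's real shapes):

    import torch

    M, topk_num, N, K_out = 8, 4, 512, 384  # illustrative sizes
    shared = torch.empty((M, topk_num, max(N, K_out)), dtype=torch.bfloat16)

    # both views reuse the same storage, so only one large workspace is allocated;
    # cache3 overwrites cache1's memory once cache1 is no longer needed
    cache1 = shared.view(-1)[: M * topk_num * N].view(M, topk_num, N)
    cache3 = shared.view(-1)[: M * topk_num * K_out].view(M, topk_num, K_out)
    assert cache1.data_ptr() == shared.data_ptr() == cache3.data_ptr()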

lightllm/common/fused_moe/grouped_fused_moe_ep.py

Lines changed: 4 additions & 15 deletions

@@ -26,17 +26,6 @@
     logger.warning("no deepep or deep_gemm")


-def tma_aligned_quantize(
-    input_tensor: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    m, k = input_tensor.shape
-    input_scale = torch.empty((m, k // block_size), dtype=torch.float32, device=input_tensor.device)
-    qinput_tensor = torch.empty((m, k), dtype=dtype, device=input_tensor.device)
-    per_token_group_quant_fp8(input_tensor, block_size, qinput_tensor, input_scale)
-    input_scale = tma_align_input_scale(input_scale)
-    return qinput_tensor, input_scale
-
-
 def masked_group_gemm(
     recv_x: Tuple[torch.Tensor],
     masked_m: torch.Tensor,

@@ -106,9 +95,7 @@ def fused_experts_impl(

     combined_x = None
     if is_prefill:
-        input_scale = torch.empty((M, K // block_size_k), dtype=torch.float32, device=hidden_states.device)
-        qinput_tensor = torch.empty((M, K), dtype=w1.dtype, device=hidden_states.device)
-        per_token_group_quant_fp8(hidden_states, block_size_k, qinput_tensor, input_scale)
+        qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w1.dtype)

         # get_dispatch_layout
         (

@@ -186,7 +173,9 @@ def fused_experts_impl(
         silu_out = torch.empty((all_tokens, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)

         silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out)
-        qsilu_out, qsilu_out_scale = tma_aligned_quantize(silu_out)
+        qsilu_out, qsilu_out_scale = per_token_group_quant_fp8(
+            silu_out, block_size_k, dtype=w1.dtype, column_major_scales=True, scale_tma_aligned=True
+        )

         # groupgemm (contiguous layout)
         gemm_out_b = torch.empty((all_tokens, K), device=hidden_states.device, dtype=hidden_states.dtype)
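
The removed tma_aligned_quantize helper is subsumed by the column_major_scales=True, scale_tma_aligned=True flags. A small sketch of the scale layout those flags request: the scale matrix is stored column-major with the token dimension padded to a multiple of 4 rows (4 * sizeof(float) = 16 bytes), matching what the deep_gemm TMA loads expect. The padding arithmetic and permute come from the kernel change below; the final [:m] slice here is only for illustration.

    import torch

    m, k, group_size = 10, 512, 128
    aligned_m = (m + 3) // 4 * 4  # token dim padded up to a multiple of 4

    # (k // group_size, aligned_m) storage, exposed as an (m, k // group_size) view
    x_s_storage = torch.empty((k // group_size, aligned_m), dtype=torch.float32)
    x_s = x_s_storage.permute(-1, -2)[:m]
    print(x_s.shape, x_s.stride())  # torch.Size([10, 4]), (1, 12)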

lightllm/common/quantization/deepgemm_quant.py

Lines changed: 11 additions & 11 deletions

@@ -51,22 +51,22 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
         else:
             qweight, weight_scale = weights
             input_scale = None
+        alloc_func = torch.empty
+        if use_custom_tensor_mananger:
+            alloc_func = self.cache_manager.alloc_tensor
         m, k = input_tensor.shape
         n = weights[0].shape[1]
         if input_scale is None:
-            qinput_tensor = self.cache_manager.alloc_tensor(
-                (m, k), qweight.dtype, device=qweight.device, is_graph_out=False
-            )
-            _, input_scale = per_token_group_quant_fp8(
-                input_tensor, self.block_size, qinput_tensor, column_major_scales=True, scale_tma_aligned=True
+            qinput_tensor, input_scale = per_token_group_quant_fp8(
+                input_tensor,
+                self.block_size,
+                dtype=qweight.dtype,
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                alloc_func=alloc_func,
             )

         if out is None:
-            if use_custom_tensor_mananger:
-                out = self.cache_manager.alloc_tensor(
-                    (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
-                )
-            else:
-                out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
+            out = alloc_func((m, n), input_tensor.dtype, device=input_tensor.device)
         deep_gemm.gemm_fp8_fp8_bf16_nt([qinput_tensor, input_scale], [qweight.t(), weight_scale.t()], out)
         return out
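
The same allocator-selection pattern (torch.empty vs. the cache manager, passed down as alloc_func) also appears in triton_quant.py and w8a8_quant.py below. A standalone sketch of the shape of that idea, using a stand-in cache manager (DummyCacheManager and pick_alloc_func are hypothetical names for illustration, not lightllm classes):

    import torch

    class DummyCacheManager:
        """Stand-in for the real cache manager; simply allocates with torch.empty."""

        def alloc_tensor(self, shape, dtype, device=None, is_graph_out=False):
            return torch.empty(shape, dtype=dtype, device=device)

    def pick_alloc_func(cache_manager, use_custom_tensor_mananger: bool):
        # mirror of the new branch in apply(): fall back to torch.empty unless
        # the custom tensor manager is requested
        if use_custom_tensor_mananger:
            return cache_manager.alloc_tensor
        return torch.empty

    alloc = pick_alloc_func(DummyCacheManager(), use_custom_tensor_mananger=True)
    out = alloc((4, 8), torch.bfloat16)  # the manager accepts dtype positionally
    print(out.shape, out.dtype)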

lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py

Lines changed: 23 additions & 15 deletions

@@ -108,22 +108,24 @@ def lightllm_per_token_group_quant_fp8(
 def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
-    x_q: torch.Tensor,
-    x_s: torch.Tensor = None,
     eps: float = 1e-10,
     dtype: torch.dtype = torch.float8_e4m3fn,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     alloc_func: Callable = torch.empty,
 ):
+    x_q = alloc_func(x.shape, device=x.device, dtype=dtype)
+    x_s = None
     # Adapted from
     # https://github.com/sgl-project/sglang/blob/7e257cd666c0d639626487987ea8e590da1e9395/python/sglang/srt/layers/quantization/fp8_kernel.py#L290
     if HAS_SGL_KERNEL:
         finfo = torch.finfo(dtype)
         fp8_max, fp8_min = finfo.max, finfo.min
+
+        # create the scale tensor
         if column_major_scales:
             if scale_tma_aligned:
-                # aligned to 4 * sizeof(float)
+                # aligned to 4 * sizeof(float)
                 aligned_size = (x.shape[-2] + 3) // 4 * 4
                 x_s = alloc_func(
                     x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),

@@ -137,16 +139,24 @@ def per_token_group_quant_fp8(
                     dtype=torch.float32,
                 ).permute(-1, -2)
         else:
-            if x_s is None:
-                x_s = alloc_func(
-                    x.shape[:-1] + (x.shape[-1] // group_size,),
-                    device=x.device,
-                    dtype=torch.float32,
-                )
+            x_s = alloc_func(
+                x.shape[:-1] + (x.shape[-1] // group_size,),
+                device=x.device,
+                dtype=torch.float32,
+            )
+
+        # quantize with the SGL kernel
         sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max, False)
     else:
+        # quantize with the LightLLM kernel
+        x_s = alloc_func(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )
         lightllm_per_token_group_quant_fp8(x, group_size, x_q, x_s, eps=1e-10, dtype=torch.float8_e4m3fn)
-
+    if column_major_scales and scale_tma_aligned:
+        x_s = tma_align_input_scale(x_s)
     return x_q, x_s


@@ -237,9 +247,9 @@ def test_tma_align():
     m = 576
     k = 8192
     x = torch.randn((m, k // 128), dtype=torch.float32).cuda()
+
     for _ in range(10):
         x_padded = tma_align_input_scale(x)
-        print(x_padded.shape)
     import time

     torch.cuda.synchronize()

@@ -255,11 +265,9 @@ def test_tma_align():

 def test_per_token_group_quant_fp8():
     group_size = 128
     x = torch.randn((1024, 8192), dtype=torch.bfloat16).cuda()
-
-    x_q = torch.randn((1024, 8192)).cuda().to(torch.float8_e4m3fn)
     # x_s = torch.randn((1024, 8192 // group_size), dtype=torch.float32).cuda()
     # x_s = torch.randn((8192 // group_size, 1024 + 10), dtype=torch.float32).cuda().t()
-    _, x_s = per_token_group_quant_fp8(x, group_size, x_q, None, column_major_scales=True)
+    x_q, x_s = per_token_group_quant_fp8(x, group_size, column_major_scales=True, scale_tma_aligned=True)
     x_s = x_s[:1024]
     th_x_q, th_x_s = torch_quant(x, group_size)
     print("th_x_s - x_s", torch.abs(th_x_s - x_s.reshape(-1)).max())

@@ -268,4 +276,4 @@ def test_per_token_group_quant_fp8():

 if __name__ == "__main__":
     test_per_token_group_quant_fp8()
-    # test_tma_align()
+    test_tma_align()
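
For intuition about what the kernel computes, here is a pure-PyTorch reference of per-token-group FP8 quantization with a round-trip check: one float32 scale per contiguous group of group_size values, chosen so the group's max maps to the FP8 e4m3 max. This is an independent sketch of the math, not the repo's torch_quant helper or the Triton/SGL kernels themselves.

    import torch

    def torch_per_token_group_quant_fp8(x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn):
        finfo = torch.finfo(dtype)
        m, k = x.shape
        xg = x.view(m, k // group_size, group_size).float()
        amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
        scale = amax / finfo.max                      # one float32 scale per (token, group)
        x_q = (xg / scale).clamp(finfo.min, finfo.max).to(dtype).view(m, k)
        return x_q, scale.squeeze(-1)

    def dequant(x_q, x_s, group_size):
        m, k = x_q.shape
        xg = x_q.float().view(m, k // group_size, group_size)
        return (xg * x_s.view(m, k // group_size, 1)).view(m, k)

    x = torch.randn((64, 256), dtype=torch.float32)
    x_q, x_s = torch_per_token_group_quant_fp8(x, 128)
    err = (dequant(x_q, x_s, 128) - x).abs().max()
    print("max round-trip error:", err.item())        # bounded by roughly group amax / 28 for e4m3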

lightllm/common/quantization/triton_quant/triton_quant.py

Lines changed: 6 additions & 12 deletions

@@ -38,26 +38,20 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
         qweight, weight_scale, input_scale = weights
         m, k = input_tensor.shape
         n = qweight.shape[1]
+        alloc_func = torch.empty
+        if use_custom_tensor_mananger:
+            alloc_func = self.cache_manager.alloc_tensor
         if input_scale is None:
-            input_scale = self.cache_manager.alloc_tensor(
-                (m, k // self.block_size), torch.float32, device=input_tensor.device, is_graph_out=False
+            input_tensor_q, input_scale = per_token_group_quant_fp8(
+                input_tensor, self.block_size, dtype=qweight.dtype, alloc_func=alloc_func
             )
-            input_tensor_q = self.cache_manager.alloc_tensor(
-                (m, k), qweight.dtype, device=qweight.device, is_graph_out=False
-            )
-            per_token_group_quant_fp8(input_tensor, self.block_size, input_tensor_q, input_scale)
         else:
             # TODO
             raise "statci input scale is not supported by triton fp8 block gemm kernel."
         m = input_tensor.shape[0]
         n = qweight.shape[1]
         if out is None:
-            if use_custom_tensor_mananger:
-                out = self.cache_manager.alloc_tensor(
-                    (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
-                )
-            else:
-                out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
+            out = alloc_func((m, n), input_tensor.dtype, device=input_tensor.device)
         w8a8_block_fp8_matmul(
             input_tensor_q,
             qweight,

lightllm/common/quantization/w8a8_quant.py

Lines changed: 6 additions & 12 deletions

@@ -131,21 +131,15 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
         qweight, weight_scale, input_scale = weights
         m, k = input_tensor.shape
         n = weights[0].shape[1]
+        alloc_func = torch.empty
+        if use_custom_tensor_mananger:
+            alloc_func = self.cache_manager.alloc_tensor
         if input_scale is None:
-            input_scale = self.cache_manager.alloc_tensor(
-                (m, k // self.block_size), torch.float32, device=input_tensor.device, is_graph_out=False
+            qinput_tensor, input_scale = per_token_group_quant_fp8(
+                input_tensor, self.block_size, dtype=qweight.dtype, alloc_func=alloc_func
            )
-            qinput_tensor = self.cache_manager.alloc_tensor(
-                (m, k), qweight.dtype, device=qweight.device, is_graph_out=False
-            )
-            per_token_group_quant_fp8(input_tensor, self.block_size, qinput_tensor, input_scale)
         if out is None:
-            if use_custom_tensor_mananger:
-                out = self.cache_manager.alloc_tensor(
-                    (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
-                )
-            else:
-                out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
+            out = alloc_func((m, n), input_tensor.dtype, device=input_tensor.device)
         if n % 128 != 0:
             w8a8_block_fp8_matmul(
                 qinput_tensor,

unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py

Lines changed: 2 additions & 7 deletions

@@ -32,20 +32,15 @@ def test_silu_and_mul_masked(expert_num, token_num, hidden_dim):
     )

     true_out_tensor_mid = torch.randn((expert_num, token_num, hidden_dim // 2), dtype=torch.float16, device="cuda")
-    true_out_tensor = torch.empty((expert_num, token_num, hidden_dim // 2), dtype=torch.float8_e4m3fn, device="cuda")
-    true_out_scale_tensor = torch.randn(
-        (expert_num, token_num, hidden_dim // 2 // quant_group_size), dtype=torch.float32, device="cuda"
-    )

     masked_m = [random.randint(0, token_num) for _ in range(expert_num)]
     masked_m = torch.tensor(masked_m, dtype=torch.int32, device="cuda")

     silu_and_mul_fwd(in_tensor.view(-1, hidden_dim), true_out_tensor_mid.view(-1, hidden_dim // 2))
-    per_token_group_quant_fp8(
+    true_out_tensor, true_out_scale_tensor = per_token_group_quant_fp8(
         true_out_tensor_mid.view(-1, hidden_dim // 2),
         quant_group_size,
-        true_out_tensor.view(-1, hidden_dim // 2),
-        true_out_scale_tensor.view(-1, hidden_dim // 2 // quant_group_size),
+        alloc_func=torch.empty,
     )

     silu_and_mul_masked_post_quant_fwd(in_tensor, out_tensor, out_scale_tensor, quant_group_size, masked_m)
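
With the reference tensors now produced by per_token_group_quant_fp8 directly, a comparison of this kind typically dequantizes both sides and checks only the rows each expert actually filled (masked_m). A hedged, self-contained sketch of that comparison pattern (tensor names and shapes below are illustrative, not the test's exact ones):

    import torch

    def dequant_groups(x_q, x_s, group_size):
        *lead, k = x_q.shape
        xg = x_q.float().reshape(*lead, k // group_size, group_size)
        return (xg * x_s.float().unsqueeze(-1)).reshape(*lead, k)

    expert_num, token_num, dim, group = 2, 6, 256, 128
    ref_q = torch.randn(expert_num, token_num, dim).to(torch.float8_e4m3fn)
    ref_s = torch.rand(expert_num, token_num, dim // group)
    out_q, out_s = ref_q.clone(), ref_s.clone()   # stand-in for the kernel output
    masked_m = torch.tensor([4, 2])               # valid rows per expert

    for e in range(expert_num):
        n = int(masked_m[e])
        torch.testing.assert_close(
            dequant_groups(out_q[e, :n], out_s[e, :n], group),
            dequant_groups(ref_q[e, :n], ref_s[e, :n], group),
        )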
