
Commit a5265c4

add triton_softmax_topk (#912)
Co-authored-by: wangzaijun <wzjhelloworld@qq.com>
1 parent 6360293 commit a5265c4

3 files changed: +218 -6 lines changed

3 files changed

+218
-6
lines changed
lightllm/common/fused_moe/softmax_topk.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_topk_kernel(
    topk_weights_ptr,
    topk_indices_ptr,
    gating_output_ptr,
    input_row_stride,
    output_weights_row_stride,
    output_indices_row_stride,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_TOPK: tl.constexpr,
    top_k: tl.constexpr,
    NEED_MASK: tl.constexpr,
    RENORM: tl.constexpr,
):
    # One program instance handles one row of the gating output.
    row_idx = tl.program_id(0)

    row_input_ptr = gating_output_ptr + row_idx * input_row_stride
    row_weights_ptr = topk_weights_ptr + row_idx * output_weights_row_stride
    row_indices_ptr = topk_indices_ptr + row_idx * output_indices_row_stride

    offsets = tl.arange(0, BLOCK_SIZE)
    if NEED_MASK:
        # BLOCK_SIZE is padded to a power of two; mask off the padded tail.
        mask = offsets < n_cols
        values = tl.load(row_input_ptr + offsets, mask=mask, other=-float("inf"))
    else:
        values = tl.load(row_input_ptr + offsets)

    # Numerically stable softmax denominator: subtract the row max first.
    current_max = tl.max(values, axis=0)
    values = values - current_max
    numerators = tl.exp(values)
    denom = tl.sum(numerators, axis=0)

    # Select the top-k entries one at a time: take the current argmax,
    # store its softmax probability and index, then mask it out with -inf.
    sum_prob = 0.0
    for i in range(top_k):
        logit = tl.max(values, axis=0)
        idx = tl.argmax(values, axis=0)

        prob = tl.exp(logit) / denom
        sum_prob += prob

        ptr_w = row_weights_ptr + i
        ptr_i = row_indices_ptr + i

        tl.store(ptr_w, prob)
        tl.store(ptr_i, idx)

        values = tl.where(offsets == idx, -float("inf"), values)

    if RENORM:
        # Renormalize the selected probabilities so each row sums to 1.
        sum_prob = tl.where(sum_prob < 1e-8, 1e-8, sum_prob)
        topk_offd = tl.arange(0, BLOCK_TOPK)
        topk_mask = topk_offd < top_k
        prob = tl.load(row_weights_ptr + topk_offd, mask=topk_mask, other=0.0)
        prob = prob / sum_prob
        tl.store(row_weights_ptr + topk_offd, prob, mask=topk_mask)
    return


def softmax_topk(gating_output: torch.Tensor, topk: int, renorm: bool = False):
    assert gating_output.dim() == 2, "The dim of gating_output must be 2."
    num_tokens, num_experts = gating_output.shape
    device = gating_output.device

    if gating_output.dtype != torch.float32:
        gating_output = gating_output.to(torch.float32)

    topk_vals = torch.empty((num_tokens, topk), dtype=torch.float32, device=device)
    topk_idxs = torch.empty((num_tokens, topk), dtype=torch.int32, device=device)

    # Pad the expert dimension to the next power of two for tl.arange.
    BLOCK_SIZE = triton.next_power_of_2(num_experts)
    NEED_MASK = BLOCK_SIZE != num_experts

    # Warp-count heuristic: roughly one warp per 256 elements, clamped to [1, 16].
    num_warps = min(max(1, (BLOCK_SIZE // 8 // 32)), 16)

    grid = (num_tokens,)
    softmax_topk_kernel[grid](
        topk_vals,
        topk_idxs,
        gating_output,
        gating_output.stride(0),
        topk_vals.stride(0),
        topk_idxs.stride(0),
        num_tokens,
        num_experts,
        BLOCK_SIZE=BLOCK_SIZE,
        BLOCK_TOPK=triton.next_power_of_2(topk),
        top_k=topk,
        NEED_MASK=NEED_MASK,
        RENORM=renorm,
        num_warps=num_warps,
    )

    return topk_vals, topk_idxs
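
For orientation, here is a minimal usage sketch of the new softmax_topk entry point (not part of the commit; it assumes a CUDA device, and the index check mirrors the native-PyTorch reference used in the test file below):

import torch
from lightllm.common.fused_moe.softmax_topk import softmax_topk

# 16 tokens routed over 8 experts; keep the top 2 experts per token
gating = torch.randn(16, 8, device="cuda")
weights, ids = softmax_topk(gating, 2, renorm=True)  # float32 (16, 2), int32 (16, 2)

# indices should agree with a plain softmax + topk reference
ref = torch.softmax(gating, dim=-1).topk(2, dim=-1)
assert torch.equal(ids.long(), ref.indices)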

lightllm/common/fused_moe/topk_select.py

Lines changed: 3 additions & 6 deletions
@@ -22,6 +22,7 @@
 from lightllm.utils.sgl_utils import sgl_ops
 from lightllm.utils.light_utils import light_ops
 from typing import Callable, List, Optional, Tuple
+from lightllm.common.fused_moe.softmax_topk import softmax_topk

 use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"]

@@ -33,11 +34,9 @@ def fused_topk(
     renormalize: bool,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-    assert (
-        sgl_ops is not None
-    ), "sgl_kernel is not installed, you can't use the cuda fused_topk. \
-        You can solve it by running `pip install sgl_kernel`."

+    if sgl_ops is None:
+        return softmax_topk(gating_output, topk, renorm=renormalize)
     M, _ = hidden_states.shape

     topk_weights = torch.empty(M, topk, dtype=torch.float32, device=hidden_states.device)

@@ -69,7 +68,6 @@ def grouped_topk(
     topk_group: int = 0,
     scoring_func: str = "softmax",
 ):
-
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
     if scoring_func == "sigmoid":
         scores = torch.sigmoid(gating_output)

@@ -145,7 +143,6 @@ def cuda_grouped_topk(
     topk_group: int = 0,
     scoring_func: str = "softmax",
 ):
-
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
     assert light_ops is not None, "lightllm_kernel is not installed."
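
The net effect of the fused_topk hunk, sketched below. This is not code from the commit; the shapes are illustrative and the keyword names are taken from the signature fragments visible in the hunk above:

import torch
from lightllm.common.fused_moe.topk_select import fused_topk

hidden_states = torch.randn(16, 128, device="cuda")
gating_output = torch.randn(16, 8, device="cuda")

# With sgl_kernel installed this takes the CUDA topk_softmax path as before;
# without it, the call now falls back to the Triton softmax_topk kernel
# instead of raising an AssertionError.
weights, ids = fused_topk(hidden_states, gating_output, topk=2, renormalize=True)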
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
import torch
import time
import pytest
import numpy as np
from lightllm.common.fused_moe.softmax_topk import softmax_topk
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


def benchmark(M, N, K, renorm, runs):
    import sgl_kernel as sgl_ops

    gating = torch.randn(M, N, device="cuda", dtype=torch.float32)
    torch.cuda.synchronize()

    # 1. SGL kernel
    sgl_vals = torch.empty((M, K), dtype=torch.float32, device="cuda")
    sgl_ids = torch.empty((M, K), dtype=torch.int32, device="cuda")
    # Warm-up
    sgl_ops.topk_softmax(sgl_vals, sgl_ids, torch.empty_like(sgl_ids), gating)
    torch.cuda.synchronize()
    start = torch.cuda.Event(True)
    end = torch.cuda.Event(True)
    start.record()
    for _ in range(runs):
        sgl_ops.topk_softmax(sgl_vals, sgl_ids, torch.empty_like(sgl_ids), gating)
        if renorm:
            sgl_vals.div_(sgl_vals.sum(-1, keepdim=True).clamp_min(1e-8))
    end.record()
    torch.cuda.synchronize()
    t_sgl = start.elapsed_time(end) / runs

    # 2. Triton kernel
    t0 = torch.cuda.Event(True)
    t1 = torch.cuda.Event(True)
    # Warm-up
    softmax_topk(gating, K)
    torch.cuda.synchronize()
    t0.record()
    for _ in range(runs):
        triton_vals, triton_ids = softmax_topk(gating, K, renorm)
    t1.record()
    torch.cuda.synchronize()
    t_triton = t0.elapsed_time(t1) / runs

    # 3. Native PyTorch
    # Warm-up
    probs = torch.softmax(gating, dim=-1)
    torch.topk(probs, K, dim=-1)
    torch.cuda.synchronize()

    start, end = torch.cuda.Event(True), torch.cuda.Event(True)
    start.record()
    for _ in range(runs):
        probs = torch.softmax(gating, dim=-1)
        torch_vals, torch_ids = torch.topk(probs, K, dim=-1)
        if renorm:
            torch_vals.div_(torch_vals.sum(-1, keepdim=True).clamp_min(1e-8))
    end.record()
    torch.cuda.synchronize()
    t_torch = start.elapsed_time(end) / runs

    # Compare indices and weights
    # Count mismatches of ordered indices
    diff_sgl_triton_ids = (sgl_ids != triton_ids).sum().item()
    diff_torch_triton_ids = (torch_ids != triton_ids).sum().item()
    # Max absolute difference of weights aligned by position
    max_err_triton_torch = (triton_vals - torch_vals).abs().max().item()
    max_err_sgl_torch = (sgl_vals - torch_vals).abs().max().item()
    max_err_triton_sgl = (triton_vals - sgl_vals).abs().max().item()

    assert diff_sgl_triton_ids == 0, f"Mismatch SGL vs Triton ids: {diff_sgl_triton_ids}"
    assert diff_torch_triton_ids == 0, f"Mismatch Torch vs Triton ids: {diff_torch_triton_ids}"
    assert max_err_triton_torch < 1e-3, f"Max err Triton vs Torch: {max_err_triton_torch}"
    assert max_err_sgl_torch < 1e-3, f"Max err SGL vs Torch: {max_err_sgl_torch}"
    assert max_err_triton_sgl < 1e-3, f"Max err Triton vs SGL: {max_err_triton_sgl}"

    results = {
        "time_sgl": t_sgl,
        "time_triton": t_triton,
        "time_torch": t_torch,
        "mismatch_sgl_triton_ids": diff_sgl_triton_ids,
        "mismatch_torch_triton_ids": diff_torch_triton_ids,
        "max_err_triton_torch": max_err_triton_torch,
        "max_err_triton_sgl": max_err_triton_sgl,
        "max_err_sgl_torch": max_err_sgl_torch,
        "sgl_ids": sgl_ids,
        "triton_ids": triton_ids,
        "torch_ids": torch_ids,
        "sgl_vals": sgl_vals,
        "triton_vals": triton_vals,
        "torch_vals": torch_vals,
    }
    return results


def test_softmax_topk():
    M, N, K = 8192, 1024, 8
    res = benchmark(M, N, K, False, 1000)
    print(f"SGL time: {res['time_sgl']:.6f}ms")
    print(f"Triton time: {res['time_triton']:.6f}ms")
    print(f"PyTorch time: {res['time_torch']:.6f}ms")
    print("Mismatch SGL vs Triton ids:", res["mismatch_sgl_triton_ids"])
    print("Mismatch Torch vs Triton ids:", res["mismatch_torch_triton_ids"])
    print("Max err Triton vs Torch :", res["max_err_triton_torch"])
    print("Max err Triton vs SGL :", res["max_err_triton_sgl"])
    print("Max err SGL vs Torch :", res["max_err_sgl_torch"])
    benchmark(M, N, K, True, 10)
    benchmark(M, 256, 5, True, 10)
    benchmark(M, 127, 5, True, 10)


if __name__ == "__main__":
    pytest.main()
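
For readers skimming the test, the semantics under test boil down to a few lines of plain PyTorch. The reference sketch below is distilled from the native-PyTorch branch of benchmark above; it is not code from the commit:

import torch

def softmax_topk_reference(gating: torch.Tensor, k: int, renorm: bool = False):
    # full softmax over the expert dimension, then the k largest probabilities per row
    probs = torch.softmax(gating.float(), dim=-1)
    vals, ids = torch.topk(probs, k, dim=-1)
    if renorm:
        # renormalize the kept probabilities so each row sums to 1
        vals = vals / vals.sum(-1, keepdim=True).clamp_min(1e-8)
    return vals, ids.to(torch.int32)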
