
Commit 8203497

add triton_softmax_topk
1 parent 2d4f7d4 commit 8203497

File tree

2 files changed: +163 −19 lines changed
lightllm/common/fused_moe/softmax_topk.py

Lines changed: 161 additions & 0 deletions (new file)
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_topk_kernel(
    topk_weights_ptr,
    topk_indices_ptr,
    gating_output_ptr,
    input_row_stride,
    output_weights_row_stride,
    output_indices_row_stride,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    top_k: tl.constexpr,
):
    # One program per token row.
    row_idx = tl.program_id(0)

    row_input_ptr = gating_output_ptr + row_idx * input_row_stride
    row_weights_ptr = topk_weights_ptr + row_idx * output_weights_row_stride
    row_indices_ptr = topk_indices_ptr + row_idx * output_indices_row_stride

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    # Out-of-range lanes load -inf so they can never win a max.
    values = tl.load(row_input_ptr + offsets, mask=mask, other=-float("inf"))

    # Numerically stable softmax: shift by the row max, then compute the
    # denominator once; it stays valid while entries are peeled off below.
    current_max = tl.max(values, axis=0)
    values = values - current_max
    numerators = tl.exp(values)
    denom = tl.sum(numerators, axis=0)

    # Iterative top-k: take the current max, emit its probability and index,
    # then mask it to -inf so the next iteration finds the runner-up.
    for i in range(top_k):
        logit = tl.max(values, axis=0)
        idx = tl.argmax(values, axis=0)

        prob = tl.exp(logit) / denom

        # Scalar store via lane 0: `offsets * 0` broadcasts the pointer to
        # BLOCK_SIZE lanes, and the mask keeps only the first one.
        lane0 = offsets == 0
        ptr_w = row_weights_ptr + i + offsets * 0
        ptr_i = row_indices_ptr + i + offsets * 0
        tl.store(ptr_w, tl.where(lane0, prob, 0.0), mask=lane0)
        tl.store(ptr_i, tl.where(lane0, idx, 0), mask=lane0)

        values = tl.where(offsets == idx, -float("inf"), values)

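The kernel never sorts: it builds the softmax denominator once, then runs top_k select-and-mask passes over the row. A minimal PyTorch sketch of the same math (illustration only, not part of the commit):

    def softmax_topk_reference(gating, k):
        # Same shift-exp-sum as the kernel, vectorized over rows.
        shifted = gating - gating.max(dim=-1, keepdim=True).values
        denom = shifted.exp().sum(dim=-1)
        vals, ids = [], []
        work = shifted.clone()
        for _ in range(k):
            logit, idx = work.max(dim=-1)
            vals.append(logit.exp() / denom)
            ids.append(idx)
            # Knock out the winner so the next pass finds the runner-up.
            work.scatter_(-1, idx.unsqueeze(-1), float("-inf"))
        return torch.stack(vals, dim=-1), torch.stack(ids, dim=-1)
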
def softmax_topk(gating_output: torch.Tensor, topk: int):
    assert gating_output.dim() == 2, "gating_output must be 2-D (num_tokens, num_experts)."
    num_tokens, num_experts = gating_output.shape
    device = gating_output.device

    # The kernel computes in fp32; upcast lower-precision gating logits.
    if gating_output.dtype != torch.float32:
        gating_output = gating_output.to(torch.float32)

    topk_vals = torch.empty((num_tokens, topk), dtype=torch.float32, device=device)
    topk_idxs = torch.empty((num_tokens, topk), dtype=torch.int32, device=device)

    # One block spans a whole row, so it must be a power of two >= num_experts.
    BLOCK_SIZE = triton.next_power_of_2(num_experts)

    grid = (num_tokens,)
    softmax_topk_kernel[grid](
        topk_vals,
        topk_idxs,
        gating_output,
        gating_output.stride(0),
        topk_vals.stride(0),
        topk_idxs.stride(0),
        num_tokens,
        num_experts,
        BLOCK_SIZE=BLOCK_SIZE,
        top_k=topk,
        num_warps=8,
    )
    return topk_vals, topk_idxs

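A quick smoke test (my sketch, assuming a CUDA device is available; indices can differ from torch.topk only on exact ties):

    gating = torch.randn(16, 128, device="cuda")
    weights, ids = softmax_topk(gating, 4)
    ref = torch.topk(torch.softmax(gating, dim=-1), 4, dim=-1)
    assert torch.allclose(weights, ref.values, atol=1e-6)
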
import sgl_kernel as sgl_ops


# Benchmark the Triton kernel against sgl_kernel and native PyTorch.
def benchmark(M, N, K):
    gating = torch.randn(M, N, device="cuda", dtype=torch.float32)
    torch.cuda.synchronize()

    # 1. SGL kernel
    sgl_vals = torch.empty((M, K), dtype=torch.float32, device="cuda")
    sgl_ids = torch.empty((M, K), dtype=torch.int32, device="cuda")
    # Warm-up
    sgl_ops.topk_softmax(sgl_vals, sgl_ids, torch.empty_like(sgl_ids), gating)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    sgl_ops.topk_softmax(sgl_vals, sgl_ids, torch.empty_like(sgl_ids), gating)
    end.record()
    torch.cuda.synchronize()
    t_sgl = start.elapsed_time(end) / 1000.0

    # 2. Triton kernel
    t0 = torch.cuda.Event(enable_timing=True)
    t1 = torch.cuda.Event(enable_timing=True)
    # Warm-up
    softmax_topk(gating, K)
    t0.record()
    triton_vals, triton_ids = softmax_topk(gating, K)
    t1.record()
    torch.cuda.synchronize()
    t_triton = t0.elapsed_time(t1) / 1000.0

    # 3. Native PyTorch
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    probs = torch.softmax(gating, dim=-1)
    torch_vals, torch_ids = torch.topk(probs, K, dim=-1)
    end.record()
    torch.cuda.synchronize()
    t_torch = start.elapsed_time(end) / 1000.0

    # Compare indices and weights.
    # Count mismatches of ordered indices.
    diff_sgl_triton_ids = (sgl_ids != triton_ids).sum().item()
    diff_torch_triton_ids = (torch_ids != triton_ids).sum().item()
    # Max absolute difference of weights, aligned by position.
    max_err_triton_torch = (triton_vals - torch_vals).abs().max().item()
    max_err_sgl_torch = (sgl_vals - torch_vals).abs().max().item()
    max_err_triton_sgl = (triton_vals - sgl_vals).abs().max().item()

    results = {
        "time_sgl": t_sgl,
        "time_triton": t_triton,
        "time_torch": t_torch,
        "mismatch_sgl_triton_ids": diff_sgl_triton_ids,
        "mismatch_torch_triton_ids": diff_torch_triton_ids,
        "max_err_triton_torch": max_err_triton_torch,
        "max_err_triton_sgl": max_err_triton_sgl,
        "max_err_sgl_torch": max_err_sgl_torch,
        "sgl_ids": sgl_ids,
        "triton_ids": triton_ids,
        "torch_ids": torch_ids,
        "sgl_vals": sgl_vals,
        "triton_vals": triton_vals,
        "torch_vals": torch_vals,
    }
    return results


if __name__ == "__main__":
    # Example: 8192 tokens, 1024 experts, Top-4
    M, N, K = 8192, 1024, 4
    res = benchmark(M, N, K)
    print(f"SGL time:     {res['time_sgl']:.6f}s")
    print(f"Triton time:  {res['time_triton']:.6f}s")
    print(f"PyTorch time: {res['time_torch']:.6f}s")
    print("Mismatch SGL vs Triton ids:  ", res["mismatch_sgl_triton_ids"])
    print("Mismatch Torch vs Triton ids:", res["mismatch_torch_triton_ids"])
    print("Max err Triton vs Torch:", res["max_err_triton_torch"])
    print("Max err Triton vs SGL:  ", res["max_err_triton_sgl"])
    print("Max err SGL vs Torch:   ", res["max_err_sgl_torch"])

lightllm/common/fused_moe/topk_select.py

Lines changed: 2 additions & 19 deletions
@@ -22,6 +22,7 @@
 from lightllm.utils.sgl_utils import sgl_ops
 from lightllm.utils.light_utils import light_ops
 from typing import Callable, List, Optional, Tuple
+from lightllm.common.fused_moe.softmax_topk import softmax_topk

 use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"]

@@ -33,24 +34,8 @@ def fused_topk(
     renormalize: bool,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-    assert (
-        sgl_ops is not None
-    ), "sgl_kernel is not installed, you can't use the cuda fused_topk. \
-        You can solve it by running `pip install sgl_kernel`."

-    M, _ = hidden_states.shape
-
-    topk_weights = torch.empty(M, topk, dtype=torch.float32, device=hidden_states.device)
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-
-    sgl_ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
-    del token_expert_indicies  # Not used. Will be used in the future.
+    topk_weights, topk_ids = softmax_topk(gating_output, topk)

     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

@@ -69,7 +54,6 @@ def grouped_topk(
     topk_group: int = 0,
     scoring_func: str = "softmax",
 ):
-
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
     if scoring_func == "sigmoid":
         scores = torch.sigmoid(gating_output)

@@ -145,7 +129,6 @@ def cuda_grouped_topk(
     topk_group: int = 0,
     scoring_func: str = "softmax",
 ):
-
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
     assert light_ops is not None, "lightllm_kernel is not installed."
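The unchanged renormalize branch rescales the selected top-k weights to sum to 1, regardless of which kernel produced them. A tiny worked example (illustration only):

    w = torch.tensor([[0.5, 0.3]])
    w = w / w.sum(dim=-1, keepdim=True)  # tensor([[0.6250, 0.3750]])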
