Skip to content

Commit 3463ab7

Browse files
authored
[Python] Update top-k kernel with TIR meta-programming (#3079)
This PR introduces a meta-programmed top-k kernel for expert selection in MoE, generated via TIR meta-programming for arbitrary k. With this change, we can deprecate the previous ad-hoc kernels that were specialized for k=2 and k=4.
1 parent 1825fed commit 3463ab7

File tree

1 file changed

+75
-150
lines changed

1 file changed

+75
-150
lines changed

python/mlc_llm/op/moe_misc.py

Lines changed: 75 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -56,162 +56,87 @@ def gating_softmax_topk( # pylint: disable=too-many-statements
5656
index_dtype = "int32"
5757

5858
TX = 1024
59-
SCAN_LEN_2 = 2
60-
SCAN_LEN_4 = 4
6159

62-
# specialized kernel for top 2 case
63-
@T.prim_func(private=True)
64-
def top2_softmax_norm_func(
65-
var_x: T.handle,
66-
var_out: T.handle,
67-
var_out_index: T.handle,
68-
) -> None:
69-
T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
70-
batch_size = T.int64()
71-
x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)
72-
out = T.match_buffer(var_out, (batch_size, SCAN_LEN_2), dtype)
73-
out_index = T.match_buffer(var_out_index, (batch_size, SCAN_LEN_2), index_dtype)
74-
local_top_k = T.alloc_buffer((SCAN_LEN_2,), dtype=dtype, scope="local")
75-
local_top_k_index = T.alloc_buffer((SCAN_LEN_2,), dtype=index_dtype, scope="local")
76-
local_top_k_f32 = T.alloc_buffer((SCAN_LEN_2,), dtype="float32", scope="local")
77-
local_top_k_max = T.alloc_buffer((1,), dtype="float32", scope="local")
78-
for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"):
79-
for ii in T.thread_binding(0, TX, "threadIdx.x"):
80-
with T.block("top_k"):
81-
vi = T.axis.spatial(batch_size, io * TX + ii)
82-
T.where(io * TX + ii < batch_size)
83-
with T.block("init"):
84-
local_top_k[0] = T.min_value(dtype)
85-
local_top_k[1] = T.min_value(dtype)
86-
local_top_k_index[0] = 0
87-
local_top_k_index[1] = 1
88-
for k in range(num_local_experts):
89-
with T.block("update"):
90-
vk = T.axis.remap("S", [k])
91-
# N.B. This snippet is specialized for k = 2
92-
if x[vi, vk] > local_top_k[0]:
93-
local_top_k[1] = local_top_k[0]
94-
local_top_k_index[1] = local_top_k_index[0]
95-
local_top_k[0] = x[vi, vk]
96-
local_top_k_index[0] = vk
97-
elif x[vi, vk] > local_top_k[1]:
98-
local_top_k[1] = x[vi, vk]
99-
local_top_k_index[1] = vk
100-
for j in T.unroll(SCAN_LEN_2):
101-
with T.block("cast"):
102-
vj = T.axis.remap("S", [j])
103-
local_top_k_f32[vj] = T.cast(local_top_k[vj], "float32")
104-
with T.block("max"):
105-
local_top_k_max[0] = T.max(local_top_k_f32[0], local_top_k_f32[1])
106-
for j in T.unroll(SCAN_LEN_2):
107-
with T.block("output"):
108-
vj = T.axis.remap("S", [j])
109-
out[vi, vj] = T.cast(
110-
T.exp(local_top_k_f32[vj] - local_top_k_max[0])
111-
/ (
112-
T.exp(local_top_k_f32[0] - local_top_k_max[0])
113-
+ T.exp(local_top_k_f32[1] - local_top_k_max[0])
114-
),
115-
dtype,
116-
)
117-
out_index[vi, vj] = local_top_k_index[vj]
118-
119-
# specialized kernel for top 4 case
120-
@T.prim_func(private=True)
121-
def top4_softmax_norm_func(
122-
var_x: T.handle,
123-
var_out: T.handle,
124-
var_out_index: T.handle,
125-
) -> None:
126-
T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
127-
batch_size = T.int64()
128-
x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)
129-
out = T.match_buffer(var_out, (batch_size, SCAN_LEN_4), dtype)
130-
out_index = T.match_buffer(var_out_index, (batch_size, SCAN_LEN_4), index_dtype)
131-
local_top_k = T.alloc_buffer((SCAN_LEN_4,), dtype=dtype, scope="local")
132-
local_top_k_index = T.alloc_buffer((SCAN_LEN_4,), dtype=index_dtype, scope="local")
133-
for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"):
134-
for ii in T.thread_binding(0, TX, "threadIdx.x"):
135-
with T.block("top_k"):
136-
vi = T.axis.spatial(batch_size, io * TX + ii)
137-
T.where(io * TX + ii < batch_size)
138-
with T.block("init"):
139-
local_top_k[0] = T.min_value(dtype)
140-
local_top_k[1] = T.min_value(dtype)
141-
local_top_k[2] = T.min_value(dtype)
142-
local_top_k[3] = T.min_value(dtype)
143-
local_top_k_index[0] = 0
144-
local_top_k_index[1] = 1
145-
local_top_k_index[2] = 2
146-
local_top_k_index[3] = 3
147-
for k in range(num_local_experts):
148-
with T.block("update"):
149-
vk = T.axis.remap("S", [k])
150-
# N.B. This snippet is specialized for k = 4
151-
if x[vi, vk] > local_top_k[0]:
152-
local_top_k[3] = local_top_k[2]
153-
local_top_k_index[3] = local_top_k_index[2]
154-
local_top_k[2] = local_top_k[1]
155-
local_top_k_index[2] = local_top_k_index[1]
156-
local_top_k[1] = local_top_k[0]
157-
local_top_k_index[1] = local_top_k_index[0]
158-
local_top_k[0] = x[vi, vk]
159-
local_top_k_index[0] = vk
160-
elif x[vi, vk] > local_top_k[1]:
161-
local_top_k[3] = local_top_k[2]
162-
local_top_k_index[3] = local_top_k_index[2]
163-
local_top_k[2] = local_top_k[1]
164-
local_top_k_index[2] = local_top_k_index[1]
165-
local_top_k[1] = x[vi, vk]
166-
local_top_k_index[1] = vk
167-
elif x[vi, vk] > local_top_k[2]:
168-
local_top_k[3] = local_top_k[2]
169-
local_top_k_index[3] = local_top_k_index[2]
170-
local_top_k[2] = x[vi, vk]
171-
local_top_k_index[2] = vk
172-
elif x[vi, vk] > local_top_k[3]:
173-
local_top_k[3] = x[vi, vk]
174-
local_top_k_index[3] = vk
175-
for j in T.unroll(SCAN_LEN_4):
176-
with T.block("output"):
177-
vj = T.axis.remap("S", [j])
178-
out[vi, vj] = local_top_k[vj]
179-
out_index[vi, vj] = local_top_k_index[vj]
180-
181-
# fast path for Mixtral
182-
if k == 2 and norm_topk_prob:
60+
def _get_topk_softmax_norm_func(k_val: int):
61+
def _init_local_top_k(local_top_k, local_top_k_index):
62+
for t in range(k_val):
63+
T.buffer_store(local_top_k, T.min_value(dtype), indices=[t])
64+
for t in range(k_val):
65+
T.buffer_store(local_top_k_index, t, indices=[t])
66+
67+
def _process_value(x, local_top_k, local_top_k_index, vi, vk):
68+
if_frames = [T.If(x[vi, vk] > local_top_k[i]) for i in range(k_val)]
69+
then_frames = [T.Then() for _ in range(k_val)]
70+
else_frames = [T.Else() for _ in range(k_val - 1)]
71+
for i in range(k_val):
72+
if_frames[i].__enter__() # pylint: disable=unnecessary-dunder-call
73+
with then_frames[i]:
74+
for j in range(k_val - 1, i, -1):
75+
T.buffer_store(local_top_k, local_top_k[j - 1], indices=[j])
76+
T.buffer_store(local_top_k_index, local_top_k_index[j - 1], indices=[j])
77+
T.buffer_store(local_top_k, x[vi, vk], indices=[i])
78+
T.buffer_store(local_top_k_index, vk, indices=[i])
79+
if i != k_val - 1:
80+
else_frames[i].__enter__() # pylint: disable=unnecessary-dunder-call
81+
82+
for i in range(k_val - 1, -1, -1):
83+
if i != k_val - 1:
84+
else_frames[i].__exit__(None, None, None)
85+
if_frames[i].__exit__(None, None, None)
86+
87+
@T.prim_func(private=True)
88+
def topk_softmax_norm_func(
89+
var_x: T.handle,
90+
var_out: T.handle,
91+
var_out_index: T.handle,
92+
) -> None:
93+
T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
94+
batch_size = T.int64()
95+
x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)
96+
out = T.match_buffer(var_out, (batch_size, k_val), dtype)
97+
out_index = T.match_buffer(var_out_index, (batch_size, k_val), index_dtype)
98+
local_top_k = T.alloc_buffer((k_val,), dtype=dtype, scope="local")
99+
local_top_k_index = T.alloc_buffer((k_val,), dtype=index_dtype, scope="local")
100+
for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"):
101+
for ii in T.thread_binding(0, TX, "threadIdx.x"):
102+
with T.block("top_k"):
103+
vi = T.axis.spatial(batch_size, io * TX + ii)
104+
T.where(io * TX + ii < batch_size)
105+
with T.block("init"):
106+
_init_local_top_k(local_top_k, local_top_k_index)
107+
for k in range(num_local_experts):
108+
with T.block("update"):
109+
vk = T.axis.remap("S", [k])
110+
_process_value(x, local_top_k, local_top_k_index, vi, vk)
111+
for j in T.unroll(k_val):
112+
with T.block("output"):
113+
vj = T.axis.remap("S", [j])
114+
out[vi, vj] = local_top_k[vj]
115+
out_index[vi, vj] = local_top_k_index[vj]
116+
117+
return topk_softmax_norm_func
118+
119+
if norm_topk_prob:
183120
return op.tensor_ir_op(
184-
top2_softmax_norm_func,
185-
"top2_softmax",
121+
_get_topk_softmax_norm_func(k),
122+
f"top{k}_softmax",
186123
args=[x],
187124
out=(
188-
Tensor.placeholder([batch_size, 2], dtype),
189-
Tensor.placeholder([batch_size, 2], index_dtype),
190-
),
191-
)
192-
if k == 4 and not norm_topk_prob:
193-
expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype)
194-
return op.tensor_ir_op(
195-
top4_softmax_norm_func,
196-
"top4_softmax",
197-
args=[expert_score],
198-
out=(
199-
Tensor.placeholder([batch_size, 4], dtype),
200-
Tensor.placeholder([batch_size, 4], index_dtype),
125+
Tensor.placeholder([batch_size, k], dtype),
126+
Tensor.placeholder([batch_size, k], index_dtype),
201127
),
202128
)
203-
if norm_topk_prob:
204-
# Compute topk first and then softmax to avoid extra re-normalize
205-
expert_score, expert_indices = op.topk(
206-
x, k, axis=-1, ret_type="both", largest=True, dtype=index_dtype
207-
)
208-
expert_score = op.softmax(expert_score.astype("float32"), axis=-1).astype(dtype)
209-
else:
210-
expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype)
211-
expert_score, expert_indices = op.topk(
212-
expert_score, k, axis=-1, ret_type="both", largest=True, dtype=index_dtype
213-
)
214-
return expert_score, expert_indices
129+
130+
expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype)
131+
return op.tensor_ir_op(
132+
_get_topk_softmax_norm_func(k),
133+
f"top{k}_softmax",
134+
args=[expert_score],
135+
out=(
136+
Tensor.placeholder([batch_size, k], dtype),
137+
Tensor.placeholder([batch_size, k], index_dtype),
138+
),
139+
)
215140

216141

217142
def moe_cumsum(expert_indices: Tensor, num_local_experts: int) -> Tensor:

0 commit comments

Comments
 (0)