Skip to content

Commit 9853a71

Browse files
authored
[BENCH] [NFC] Update routing.py (#7503)
Minor refactor
1 parent 0560390 commit 9853a71

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

python/triton_kernels/triton_kernels/routing.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,19 @@ def compute_expt_data(expt_hist, n_expts_tot, n_gates):
276276
# --------------------------
277277

278278

279+
def routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):
280+
hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map = sort_tokens(
281+
expt_scal, expt_indx, n_expts_tot, bitmatrix)
282+
token_offs_pad = _unpack_into_dict(token_offs_pad)
283+
block_pid_map = _unpack_into_dict(block_pid_map)
284+
expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
285+
286+
# pack the matmul data structure
287+
gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
288+
scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
289+
return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx
290+
291+
279292
def routing(logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1, n_rows=None):
280293
from .topk import topk
281294
if sm_first:
@@ -286,16 +299,8 @@ def routing(logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1,
286299
# mutate bitmatrix
287300
if simulated_ep > 1:
288301
expt_scal, expt_indx, bitmatrix = prune_routing(expt_scal, expt_indx, bitmatrix, logits.shape[-1], simulated_ep)
289-
hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map = sort_tokens(
290-
expt_scal, expt_indx, n_expts_tot, bitmatrix)
291-
token_offs_pad = _unpack_into_dict(token_offs_pad)
292-
block_pid_map = _unpack_into_dict(block_pid_map)
293-
expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
294302

295-
# pack the matmul data structure
296-
gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
297-
scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
298-
return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx
303+
return routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act)
299304

300305

301306
# --------------------------

0 commit comments

Comments
 (0)