@@ -276,6 +276,19 @@ def compute_expt_data(expt_hist, n_expts_tot, n_gates):
276
276
# --------------------------
277
277
278
278
279
+ def routing_from_bitmatrix (bitmatrix , expt_scal , expt_indx , n_expts_tot , n_expts_act ):
280
+ hist , topk_indx , gate_indx , gate_scal , token_offs_raw , token_offs_pad , block_pid_map = sort_tokens (
281
+ expt_scal , expt_indx , n_expts_tot , bitmatrix )
282
+ token_offs_pad = _unpack_into_dict (token_offs_pad )
283
+ block_pid_map = _unpack_into_dict (block_pid_map )
284
+ expt_data = ExptData (hist , token_offs_raw , token_offs_pad , block_pid_map )
285
+
286
+ # pack the matmul data structure
287
+ gather_indx = GatherIndx (src_indx = topk_indx , dst_indx = gate_indx )
288
+ scatter_indx = ScatterIndx (src_indx = gate_indx , dst_indx = topk_indx )
289
+ return RoutingData (gate_scal , hist , n_expts_tot , n_expts_act , expt_data ), gather_indx , scatter_indx
290
+
291
+
279
292
def routing (logits , n_expts_act , sm_first = False , expt_indx = None , simulated_ep = 1 , n_rows = None ):
280
293
from .topk import topk
281
294
if sm_first :
@@ -286,16 +299,8 @@ def routing(logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1,
286
299
# mutate bitmatrix
287
300
if simulated_ep > 1 :
288
301
expt_scal , expt_indx , bitmatrix = prune_routing (expt_scal , expt_indx , bitmatrix , logits .shape [- 1 ], simulated_ep )
289
- hist , topk_indx , gate_indx , gate_scal , token_offs_raw , token_offs_pad , block_pid_map = sort_tokens (
290
- expt_scal , expt_indx , n_expts_tot , bitmatrix )
291
- token_offs_pad = _unpack_into_dict (token_offs_pad )
292
- block_pid_map = _unpack_into_dict (block_pid_map )
293
- expt_data = ExptData (hist , token_offs_raw , token_offs_pad , block_pid_map )
294
302
295
- # pack the matmul data structure
296
- gather_indx = GatherIndx (src_indx = topk_indx , dst_indx = gate_indx )
297
- scatter_indx = ScatterIndx (src_indx = gate_indx , dst_indx = topk_indx )
298
- return RoutingData (gate_scal , hist , n_expts_tot , n_expts_act , expt_data ), gather_indx , scatter_indx
303
+ return routing_from_bitmatrix (bitmatrix , expt_scal , expt_indx , n_expts_tot , n_expts_act )
299
304
300
305
301
306
# --------------------------
0 commit comments