 import torch
 import triton
 from dataclasses import dataclass, field
-from .routing_details._routing_compute import _routing_memset_indx
-from .routing_details._routing_compute import _routing_compute_indx_offs
-from .routing_details._routing_compute import _routing_compute_indx
+from .routing_details._routing_compute import _combined_routing_compute
+from .routing_details._routing_compute import _combined_routing_memset
 from .routing_details._routing_compute import _routing_clear_bitmatrix
 from .routing_details._expt_data import _expt_data_memset
 from .routing_details._expt_data import _expt_data_compute
@@ -115,32 +114,42 @@ def forward(ctx, expt_scal, expt_indx, bitmatrix):
         topk_indx = combined_indx[:n_gates_pad]
         gate_indx = combined_indx[n_gates_pad:]
         gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
-        _routing_memset_indx[(cdiv(n_gates_pad * 2, MEMSET_BLOCK) + 1, )](
+
+        token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1a, blocks2a, MEMSET_BLOCK_A, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
+            hist, n_expts_tot, n_gates_pad)
+
+        blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
+        blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M)
+
+        _combined_routing_memset[(blocks1a + blocks1b, )](
             combined_indx, n_gates_pad * 2, -1, MEMSET_BLOCK, hist,  #
-            expt_offs, hist.shape[0], BLOCK_N=512  #
-        )
-        _routing_compute_indx_offs[(n_expts_tot, )](
-            expt_offs, partial_hist,  # inputs
+            expt_offs, hist.shape[0], n_expts_tot, partial_hist,  # inputs
             partial_hist.shape[0], partial_hist.stride(0), partial_hist.stride(1),  # outputs
-            BLOCK_M=INDX_OFFS_BLOCK_M,  # tunable parameters
+            token_offs_combined, token_offs_combined.stride(0),  #
+            blocks1a, block_pid_map,  #
+            block_m_log2_start, SIZES=block_m_num, BLOCK_A=MEMSET_BLOCK_A,  # optimization parameters
+            BLOCK_N=512, BLOCK_M=INDX_OFFS_BLOCK_M,  # tunable parameters
         )
+
         indx_offs = partial_hist
-        _routing_compute_indx[(cdiv(n_tokens_pad, HIST_BLOCK_M), )](
+
+        _combined_routing_compute[(blocks2a + blocks2b, )](
             topk_indx, gate_indx, gate_scal,  # outputs
             expt_scal, expt_indx, indx_offs, indx_offs.stride(0), indx_offs.stride(1),  # inputs
-            n_tokens_pad, n_tokens_raw,  # input shape
-            BLOCK_M=HIST_BLOCK_M,  # tunable parameters
-            N_EXPTS_ACT=n_expts_act,  # constants
-            num_warps=1 if HIST_BLOCK_M * n_expts_act // 32 < 4 else 4  #
+            expt_offs, n_tokens_pad, n_tokens_raw,  # input shape
+            HIST_BLOCK_M, n_expts_act,  # constants
+            hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0),  # outputs
+            block_m_log2_start, block_m_num, HIST2_BLOCK_M, blocks2a,  # etc.
         )
+
         ctx.n_tokens_raw = n_tokens_raw
         ctx.n_tokens_pad = n_tokens_pad
         ctx.n_expts_act = n_expts_act
         ctx.save_for_backward(gate_indx)
-        return hist, topk_indx, gate_indx, gate_scal
+        return hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map

     @staticmethod
-    def backward(ctx, _0, _1, _2, dgate_scal):
+    def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5):
         (gate_indx, ) = ctx.saved_tensors
         dgate_scal = dgate_scal[gate_indx]
         dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act)
@@ -193,16 +202,17 @@ def log2_power_of_two(x):
     return x.bit_length() - 1


-def compute_expt_data(expt_hist, n_expts_tot, n_gates):
-    if expt_hist is None:
-        return ExptData(None, None, None, None)
-    MEMSET_BLOCK = 128
+block_m_log2_start = 4
+
+
+def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates):
+
+    MEMSET_BLOCK = 512
     HIST2_BLOCK_M = 512
     device = expt_hist.device
     n_expts_tot = n_expts_tot
     cdiv = triton.cdiv
     # block_ms are all powers-of-two between 16 and 128 (inclusive)
-    block_m_log2_start = 4
     block_m_log2_end = 9 if is_hip() else 8
     block_m_num = block_m_log2_end - block_m_log2_start
     if n_gates <= n_expts_tot:
@@ -212,26 +222,53 @@ def compute_expt_data(expt_hist, n_expts_tot, n_gates):
     # allocate memory
     pad = lambda x: cdiv(x, MEMSET_BLOCK) * MEMSET_BLOCK
     dtype = torch.int32
-    token_offs_raw = torch.empty((n_expts_tot + 1, ), dtype=dtype, device=device)
-    token_offs_pad = torch.empty((block_m_num, pad(n_expts_tot + 1)), dtype=dtype, device=device)
+
+    token_offs_combined = torch.empty((block_m_num + 1, pad(n_expts_tot + 1)), dtype=dtype, device=device)
+
+    token_offs_raw = token_offs_combined[0][:n_expts_tot + 1]
+    token_offs_pad = token_offs_combined[1:]
+
     block_pid_map = torch.empty((block_m_num, pad(max_n_tiles)), dtype=dtype, device=device)
+    memset_grid = torch.numel(block_pid_map) // MEMSET_BLOCK  # exact division
     # compute outputs
     token_offs_pad = token_offs_pad[:, :n_expts_tot + 1]
     block_pid_map = block_pid_map[:, :max_n_tiles]
-    memset_grid = cdiv(block_pid_map.shape[1], MEMSET_BLOCK) + 1
-    _expt_data_memset[(memset_grid, block_m_num)](
-        expt_hist, n_expts_tot, token_offs_raw,  #
-        token_offs_pad, token_offs_pad.stride(0),  #
-        block_pid_map, block_pid_map.stride(0),  #
-        block_m_log2_start, BLOCK=MEMSET_BLOCK,  # optimization parameters
-        num_warps=1)
-    _expt_data_compute[(n_expts_tot, block_m_num)](
+
+    blocks1 = memset_grid + block_m_num + 1
+    blocks2 = n_expts_tot * block_m_num
+
+    return token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num
+
+
+def _unpack_into_dict(x):
+
+    block_m_log2_end = block_m_log2_start + x.shape[0]
+    x = {2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}
+    return x
+
+
+def compute_expt_data(expt_hist, n_expts_tot, n_gates):
+
+    if expt_hist is None:
+        return ExptData(None, None, None, None)
+
+    # this just computes the kernel arguments:
+    token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
+        expt_hist, n_expts_tot, n_gates)
+
+    _expt_data_memset[(blocks1, )](
+        expt_hist, n_expts_tot,  #
+        token_offs_combined, token_offs_combined.stride(0),  #
+        block_pid_map,  #
+        block_m_log2_start, SIZES=block_m_num, BLOCK=MEMSET_BLOCK,  # optimization parameters
+        num_warps=4)
+    _expt_data_compute[(blocks2, )](
         expt_hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0),  # outputs
-        block_m_log2_start, BLOCK=HIST2_BLOCK_M,  # optimization parameters
+        block_m_log2_start, SIZES=block_m_num, BLOCK=HIST2_BLOCK_M,  # optimization parameters
         num_warps=4)
-    # unpack into datastructure
-    token_offs_pad = {2**j: token_offs_pad[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}
-    block_pid_map = {2**j: block_pid_map[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}
+
+    token_offs_pad = _unpack_into_dict(token_offs_pad)
+    block_pid_map = _unpack_into_dict(block_pid_map)
     return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map)


@@ -249,12 +286,18 @@ def routing(logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1,
     # mutate bitmatrix
     if simulated_ep > 1:
         expt_scal, expt_indx, bitmatrix = prune_routing(expt_scal, expt_indx, bitmatrix, simulated_ep)
-    hist, topk_indx, gate_indx, gate_scal = sort_tokens(expt_scal, expt_indx, bitmatrix)
+    hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map = sort_tokens(
+        expt_scal, expt_indx, bitmatrix)
+
+    token_offs_pad = _unpack_into_dict(token_offs_pad)
+    block_pid_map = _unpack_into_dict(block_pid_map)
+    expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
+
     # pack the matmul data structure
     n_expts_tot = logits.shape[-1] // simulated_ep
     gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
     scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
-    expt_data = compute_expt_data(hist, n_expts_tot, topk_indx.numel())
+
    return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx


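For context on the `_unpack_into_dict` helper introduced above: it maps a tensor with one row per supported block_m to a dictionary keyed by the block_m value itself, i.e. row i becomes key 2**(block_m_log2_start + i). A minimal, self-contained sketch of that mapping, using a dummy tensor in place of the real `token_offs_pad` / `block_pid_map` kernel outputs:

import torch

block_m_log2_start = 4  # module-level constant introduced by this change

def _unpack_into_dict(x):
    # row i corresponds to block_m = 2 ** (block_m_log2_start + i)
    block_m_log2_end = block_m_log2_start + x.shape[0]
    return {2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}

# dummy stand-in: 4 block sizes (the non-HIP path) x (n_expts_tot + 1) columns
dummy = torch.arange(4 * 6).reshape(4, 6)
unpacked = _unpack_into_dict(dummy)
assert sorted(unpacked.keys()) == [16, 32, 64, 128]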