@@ -1,9 +1,8 @@
 import torch
 import triton
 from dataclasses import dataclass, field
-from .routing_details._routing_compute import _routing_memset_indx
-from .routing_details._routing_compute import _routing_compute_indx_offs
-from .routing_details._routing_compute import _routing_compute_indx
+from .routing_details._routing_compute import _combined_routing_compute
+from .routing_details._routing_compute import _combined_routing_memset
 from .routing_details._routing_compute import _routing_clear_bitmatrix
 from .routing_details._expt_data import _expt_data_memset
 from .routing_details._expt_data import _expt_data_compute
@@ -115,32 +114,42 @@ def forward(ctx, expt_scal, expt_indx, bitmatrix):
         topk_indx = combined_indx[:n_gates_pad]
         gate_indx = combined_indx[n_gates_pad:]
         gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
-        _routing_memset_indx[(cdiv(n_gates_pad * 2, MEMSET_BLOCK) + 1, )](
+
+        token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1a, blocks2a, MEMSET_BLOCK_A, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
+            hist, n_expts_tot, n_gates_pad)
+
+        blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
+        blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M)
+
+        _combined_routing_memset[(blocks1a + blocks1b, )](
             combined_indx, n_gates_pad * 2, -1, MEMSET_BLOCK, hist,  #
-            expt_offs, hist.shape[0], BLOCK_N=512  #
-        )
-        _routing_compute_indx_offs[(n_expts_tot, )](
-            expt_offs, partial_hist,  # inputs
+            expt_offs, hist.shape[0], n_expts_tot, partial_hist,  # inputs
             partial_hist.shape[0], partial_hist.stride(0), partial_hist.stride(1),  # outputs
-            BLOCK_M=INDX_OFFS_BLOCK_M,  # tunable parameters
+            token_offs_combined, token_offs_combined.stride(0),  #
+            blocks1a, block_pid_map,  #
+            block_m_log2_start, SIZES=block_m_num, BLOCK_A=MEMSET_BLOCK_A,  # optimization parameters
+            BLOCK_N=512, BLOCK_M=INDX_OFFS_BLOCK_M,  # tunable parameters
         )
+
         indx_offs = partial_hist
-        _routing_compute_indx[(cdiv(n_tokens_pad, HIST_BLOCK_M), )](
+
+        _combined_routing_compute[(blocks2a + blocks2b, )](
             topk_indx, gate_indx, gate_scal,  # outputs
             expt_scal, expt_indx, indx_offs, indx_offs.stride(0), indx_offs.stride(1),  # inputs
-            n_tokens_pad, n_tokens_raw,  # input shape
-            BLOCK_M=HIST_BLOCK_M,  # tunable parameters
-            N_EXPTS_ACT=n_expts_act,  # constants
-            num_warps=1 if HIST_BLOCK_M * n_expts_act // 32 < 4 else 4  #
+            expt_offs, n_tokens_pad, n_tokens_raw,  # input shape
+            HIST_BLOCK_M, n_expts_act,  # constants
+            hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0),  # outputs
+            block_m_log2_start, block_m_num, HIST2_BLOCK_M, blocks2a,  # etc.
         )
+
         ctx.n_tokens_raw = n_tokens_raw
         ctx.n_tokens_pad = n_tokens_pad
         ctx.n_expts_act = n_expts_act
         ctx.save_for_backward(gate_indx)
-        return hist, topk_indx, gate_indx, gate_scal
+        return hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map

     @staticmethod
-    def backward(ctx, _0, _1, _2, dgate_scal):
+    def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5):
         (gate_indx, ) = ctx.saved_tensors
         dgate_scal = dgate_scal[gate_indx]
         dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act)
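The two launch grids above are sums of two block counts (`blocks1a + blocks1b` and `blocks2a + blocks2b`): each combined kernel covers two formerly separate jobs in a single launch and selects the job from its program id. The `_combined_routing_*` kernels themselves live in `_routing_compute.py` and are not shown in this hunk; the sketch below only illustrates that grid-fusion pattern with a made-up fused memset (kernel name, arguments, and task split are hypothetical).

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _fused_memset_example(PtrA, size_a, PtrB, size_b, BLOCKS_A: tl.constexpr, BLOCK: tl.constexpr):
    # Programs [0, BLOCKS_A) fill PtrA with -1; the remaining programs fill PtrB.
    # This mirrors the `blocks1a + blocks1b` style grids used in the launches above.
    pid = tl.program_id(0)
    if pid < BLOCKS_A:
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        tl.store(PtrA + offs, tl.full([BLOCK], -1, tl.int32), mask=offs < size_a)
    else:
        offs = (pid - BLOCKS_A) * BLOCK + tl.arange(0, BLOCK)
        tl.store(PtrB + offs, tl.full([BLOCK], -1, tl.int32), mask=offs < size_b)


# hypothetical usage: two independent buffers initialized by a single launch
a = torch.empty(1000, dtype=torch.int32, device="cuda")
b = torch.empty(300, dtype=torch.int32, device="cuda")
BLOCK = 128
blocks_a = triton.cdiv(a.numel(), BLOCK)
blocks_b = triton.cdiv(b.numel(), BLOCK)
_fused_memset_example[(blocks_a + blocks_b, )](a, a.numel(), b, b.numel(), BLOCKS_A=blocks_a, BLOCK=BLOCK)
```

Fusing the launches this way trades an extra branch on the program id for one fewer kernel launch per routing step.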
@@ -193,16 +202,17 @@ def log2_power_of_two(x):
     return x.bit_length() - 1


-def compute_expt_data(expt_hist, n_expts_tot, n_gates):
-    if expt_hist is None:
-        return ExptData(None, None, None, None)
-    MEMSET_BLOCK = 128
+block_m_log2_start = 4
+
+
+def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates):
+
+    MEMSET_BLOCK = 512
     HIST2_BLOCK_M = 512
     device = expt_hist.device
     n_expts_tot = n_expts_tot
     cdiv = triton.cdiv
     # block_ms are all powers-of-two between 16 and 128 (inclusive)
-    block_m_log2_start = 4
     block_m_log2_end = 9 if is_hip() else 8
     block_m_num = block_m_log2_end - block_m_log2_start
     if n_gates <= n_expts_tot:
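For concreteness, the block_m bookkeeping set up above enumerates the power-of-two tile sizes like this (a small worked example; the non-HIP constants come straight from the lines above):

```python
# Worked example of the block_m enumeration (non-HIP path):
block_m_log2_start, block_m_log2_end = 4, 8          # 9 on HIP, which adds one extra size (256)
block_ms = [2**j for j in range(block_m_log2_start, block_m_log2_end)]
assert block_ms == [16, 32, 64, 128]
block_m_num = block_m_log2_end - block_m_log2_start  # 4 rows, one per block_m size
```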
@@ -212,26 +222,53 @@ def compute_expt_data(expt_hist, n_expts_tot, n_gates):
     # allocate memory
     pad = lambda x: cdiv(x, MEMSET_BLOCK) * MEMSET_BLOCK
     dtype = torch.int32
-    token_offs_raw = torch.empty((n_expts_tot + 1, ), dtype=dtype, device=device)
-    token_offs_pad = torch.empty((block_m_num, pad(n_expts_tot + 1)), dtype=dtype, device=device)
+
+    token_offs_combined = torch.empty((block_m_num + 1, pad(n_expts_tot + 1)), dtype=dtype, device=device)
+
+    token_offs_raw = token_offs_combined[0][:n_expts_tot + 1]
+    token_offs_pad = token_offs_combined[1:]
+
     block_pid_map = torch.empty((block_m_num, pad(max_n_tiles)), dtype=dtype, device=device)
+    memset_grid = torch.numel(block_pid_map) // MEMSET_BLOCK  # exact division
     # compute outputs
     token_offs_pad = token_offs_pad[:, :n_expts_tot + 1]
     block_pid_map = block_pid_map[:, :max_n_tiles]
-    memset_grid = cdiv(block_pid_map.shape[1], MEMSET_BLOCK) + 1
-    _expt_data_memset[(memset_grid, block_m_num)](
-        expt_hist, n_expts_tot, token_offs_raw,  #
-        token_offs_pad, token_offs_pad.stride(0),  #
-        block_pid_map, block_pid_map.stride(0),  #
-        block_m_log2_start, BLOCK=MEMSET_BLOCK,  # optimization parameters
-        num_warps=1)
-    _expt_data_compute[(n_expts_tot, block_m_num)](
+
+    blocks1 = memset_grid + block_m_num + 1
+    blocks2 = n_expts_tot * block_m_num
+
+    return token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num
+
+
+def _unpack_into_dict(x):
+
+    block_m_log2_end = block_m_log2_start + x.shape[0]
+    x = {2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}
+    return x
+
+
+def compute_expt_data(expt_hist, n_expts_tot, n_gates):
+
+    if expt_hist is None:
+        return ExptData(None, None, None, None)
+
+    # this just computes the kernel arguments:
+    token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
+        expt_hist, n_expts_tot, n_gates)
+
+    _expt_data_memset[(blocks1, )](
+        expt_hist, n_expts_tot,  #
+        token_offs_combined, token_offs_combined.stride(0),  #
+        block_pid_map,  #
+        block_m_log2_start, SIZES=block_m_num, BLOCK=MEMSET_BLOCK,  # optimization parameters
+        num_warps=4)
+    _expt_data_compute[(blocks2, )](
         expt_hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0),  # outputs
-        block_m_log2_start, BLOCK=HIST2_BLOCK_M,  # optimization parameters
+        block_m_log2_start, SIZES=block_m_num, BLOCK=HIST2_BLOCK_M,  # optimization parameters
         num_warps=4)
-    # unpack into datastructure
-    token_offs_pad = {2**j: token_offs_pad[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}
-    block_pid_map = {2**j: block_pid_map[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}
+
+    token_offs_pad = _unpack_into_dict(token_offs_pad)
+    block_pid_map = _unpack_into_dict(block_pid_map)
     return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map)


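A minimal, self-contained illustration of the new `_unpack_into_dict` helper: each row of the stacked `(block_m_num, ...)` tensors becomes a dict entry keyed by its block_m size (the toy tensor below is made up):

```python
import torch

block_m_log2_start = 4  # module-level constant introduced in this change


def _unpack_into_dict(x):
    # each row of `x` corresponds to one power-of-two block_m, starting at 2**block_m_log2_start
    block_m_log2_end = block_m_log2_start + x.shape[0]
    return {2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}


# a (block_m_num, n) tensor with 4 rows unpacks into entries keyed 16, 32, 64, 128
x = torch.arange(4 * 3).reshape(4, 3)
assert sorted(_unpack_into_dict(x).keys()) == [16, 32, 64, 128]
```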
@@ -249,12 +286,18 @@ def routing(logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1,
     # mutate bitmatrix
     if simulated_ep > 1:
         expt_scal, expt_indx, bitmatrix = prune_routing(expt_scal, expt_indx, bitmatrix, simulated_ep)
-    hist, topk_indx, gate_indx, gate_scal = sort_tokens(expt_scal, expt_indx, bitmatrix)
+    hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map = sort_tokens(
+        expt_scal, expt_indx, bitmatrix)
+
+    token_offs_pad = _unpack_into_dict(token_offs_pad)
+    block_pid_map = _unpack_into_dict(block_pid_map)
+    expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
+
     # pack the matmul data structure
     n_expts_tot = logits.shape[-1] // simulated_ep
     gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
     scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
-    expt_data = compute_expt_data(hist, n_expts_tot, topk_indx.numel())
+
     return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx

