@@ -150,64 +150,6 @@ __launch_bounds__(TPB) __global__
150
150
}
151
151
}
152
152
153
- template <typename T, int TPB, typename IdxT = int >
154
- __launch_bounds__ (TPB) __global__ void moe_top_k (const T* inputs_after_softmax,
155
- T* output,
156
- IdxT* indices,
157
- int * source_rows,
158
- T* softmax_max_prob,
159
- const int64_t num_experts,
160
- const int64_t k,
161
- const int64_t num_rows) {
162
- using cub_kvp = cub::KeyValuePair<int , T>;
163
- using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
164
- __shared__ typename BlockReduce::TempStorage tmpStorage;
165
-
166
- cub_kvp thread_kvp;
167
- cub::ArgMax arg_max;
168
-
169
- const int block_row = blockIdx.x + blockIdx.y * gridDim.x ;
170
- if (block_row >= num_rows) {
171
- return ;
172
- }
173
-
174
- const bool should_process_row = true ;
175
- const int thread_read_offset = block_row * num_experts;
176
-
177
- for (int k_idx = 0 ; k_idx < k; ++k_idx) {
178
- thread_kvp.key = 0 ;
179
- thread_kvp.value = T (-1 .f ); // This is OK because inputs are probabilities
180
-
181
- cub_kvp inp_kvp;
182
- for (int expert = threadIdx.x ; expert < num_experts; expert += TPB) {
183
- const int idx = thread_read_offset + expert;
184
- inp_kvp.key = expert;
185
- inp_kvp.value = inputs_after_softmax[idx];
186
-
187
- for (int prior_k = 0 ; prior_k < k_idx; ++prior_k) {
188
- const IdxT prior_winning_expert = indices[k * block_row + prior_k];
189
-
190
- if (prior_winning_expert == expert) {
191
- inp_kvp = thread_kvp;
192
- }
193
- }
194
-
195
- thread_kvp = arg_max (inp_kvp, thread_kvp);
196
- }
197
-
198
- const cub_kvp result_kvp =
199
- BlockReduce (tmpStorage).Reduce (thread_kvp, arg_max);
200
- if (threadIdx.x == 0 ) {
201
- const int idx = k * block_row + k_idx;
202
- // restore normalized probes
203
- output[idx] = result_kvp.value / T (softmax_max_prob[idx]);
204
- indices[idx] = should_process_row ? result_kvp.key : num_experts;
205
- source_rows[idx] = k_idx * num_rows + block_row;
206
- }
207
- __syncthreads ();
208
- }
209
- }
210
-
211
153
template <typename T, int TPB>
212
154
__launch_bounds__ (TPB) __global__ void moe_softmax (const T* input,
213
155
T* output,
@@ -262,11 +204,11 @@ __launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
262
204
}
263
205
264
206
template <typename T, int TPB, typename IdxT = int >
265
- __launch_bounds__ (TPB) __global__ void moe_top_k (const T* inputs_after_softmax,
266
- const T* bias,
207
+ __launch_bounds__ (TPB) __global__ void group_moe_top_k (const T* inputs_after_softmax,
267
208
T* output,
268
209
IdxT* indices,
269
210
int * source_rows,
211
+ T* softmax_max_prob,
270
212
const int64_t num_experts,
271
213
const int64_t k,
272
214
const int64_t num_rows) {
@@ -293,7 +235,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
293
235
for (int expert = threadIdx.x ; expert < num_experts; expert += TPB) {
294
236
const int idx = thread_read_offset + expert;
295
237
inp_kvp.key = expert;
296
- inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
238
+ inp_kvp.value = inputs_after_softmax[idx];
297
239
298
240
for (int prior_k = 0 ; prior_k < k_idx; ++prior_k) {
299
241
const IdxT prior_winning_expert = indices[k * block_row + prior_k];
@@ -310,101 +252,17 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
310
252
BlockReduce (tmpStorage).Reduce (thread_kvp, arg_max);
311
253
if (threadIdx.x == 0 ) {
312
254
const int idx = k * block_row + k_idx;
313
- output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key ]: result_kvp.value ;
255
+ // restore normalized probes
256
+ output[idx] = result_kvp.value / T (softmax_max_prob[idx]);
314
257
indices[idx] = should_process_row ? result_kvp.key : num_experts;
315
258
source_rows[idx] = k_idx * num_rows + block_row;
316
259
}
317
260
__syncthreads ();
318
261
}
319
262
}
320
263
321
- template <typename T, int TPB, typename IdxT = int >
322
- __launch_bounds__ (TPB) __global__ void moe_softmax_top_k_fused (const T* input,
323
- const T* bias,
324
- T* output,
325
- IdxT* indices,
326
- int * source_rows,
327
- const int64_t num_experts,
328
- const int64_t k,
329
- const int64_t num_rows) {
330
- // softmax
331
- using BlockReduce = cub::BlockReduce<float , TPB>;
332
- __shared__ typename BlockReduce::TempStorage tmpStorage;
333
-
334
- __shared__ float normalizing_factor;
335
- __shared__ float float_max;
336
-
337
- int globalIdx = blockIdx.x + blockIdx.y * gridDim.x ;
338
- if (globalIdx >= num_rows) {
339
- return ;
340
- }
341
- const int64_t thread_row_offset = globalIdx * num_experts;
342
- const int64_t idx = thread_row_offset+threadIdx.x ;
343
-
344
- cub::Sum sum;
345
-
346
- float threadData = (threadIdx.x < num_experts) ? static_cast <float >(input[idx]) :(-FLT_MAX);
347
-
348
- const float maxElem = BlockReduce (tmpStorage).Reduce (threadData, cub::Max ());
349
- if (threadIdx.x == 0 ) {
350
- float_max = maxElem;
351
- }
352
- __syncthreads ();
353
-
354
- float threadDataSub = threadData - float_max;
355
- float threadDataExp = exp (threadDataSub);
356
-
357
- const auto Z = BlockReduce (tmpStorage).Reduce (threadDataExp, sum);
358
-
359
- if (threadIdx.x == 0 ) {
360
- normalizing_factor = 1 .f / Z;
361
- }
362
- __syncthreads ();
363
-
364
- T val = T (threadDataExp * normalizing_factor);
365
-
366
- // top_k
367
- using cub_kvp = cub::KeyValuePair<int , T>;
368
- using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
369
- __shared__ typename BlockReduceP::TempStorage tmpStorageP;
370
-
371
- cub_kvp thread_kvp;
372
- cub::ArgMax arg_max;
373
-
374
- for (int k_idx = 0 ; k_idx < k; ++k_idx) {
375
- thread_kvp.key = 0 ;
376
- thread_kvp.value = T (-1 .f ); // This is OK because inputs are probabilities
377
-
378
- if (threadIdx.x < num_experts) {
379
- cub_kvp inp_kvp;
380
- int expert = threadIdx.x ;
381
- inp_kvp.key = expert;
382
- inp_kvp.value = bias ? val + bias[expert] : val;
383
-
384
- for (int prior_k = 0 ; prior_k < k_idx; ++prior_k) {
385
- const IdxT prior_winning_expert = indices[k * globalIdx + prior_k];
386
-
387
- if (prior_winning_expert == expert) {
388
- inp_kvp = thread_kvp;
389
- }
390
- }
391
- thread_kvp = arg_max (inp_kvp, thread_kvp);
392
- }
393
-
394
- const cub_kvp result_kvp =
395
- BlockReduceP (tmpStorageP).Reduce (thread_kvp, arg_max);
396
- if (threadIdx.x == 0 ) {
397
- const int cur_idx = k * globalIdx + k_idx;
398
- output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key ]) : result_kvp.value ;
399
- indices[cur_idx] = result_kvp.key ;
400
- source_rows[cur_idx] = k_idx * num_rows + globalIdx;
401
- }
402
- __syncthreads ();
403
- }
404
- }
405
-
406
- template <typename T, int TPB, typename IdxT = int >
407
- __launch_bounds__ (TPB) __global__ void moe_top_k_normed (const T* inputs_after_softmax,
264
+ template <typename T, int TPB, bool NormWeights = false , typename IdxT = int >
265
+ __launch_bounds__ (TPB) __global__ void moe_top_k (const T* inputs_after_softmax,
408
266
const T* bias,
409
267
T* output,
410
268
IdxT* indices,
@@ -427,10 +285,12 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so
427
285
const bool should_process_row = true ;
428
286
const int thread_read_offset = block_row * num_experts;
429
287
T weight_sum = static_cast <T>(0 );
288
+ T* row_outputs = nullptr ;
430
289
431
- extern __shared__ char smem[];
432
-
433
- T* row_outputs = reinterpret_cast <T*>(smem);
290
+ if constexpr (NormWeights){
291
+ extern __shared__ char smem[];
292
+ row_outputs = reinterpret_cast <T*>(smem);
293
+ }
434
294
435
295
for (int k_idx = 0 ; k_idx < k; ++k_idx) {
436
296
thread_kvp.key = 0 ;
@@ -457,28 +317,32 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so
457
317
BlockReduce (tmpStorage).Reduce (thread_kvp, arg_max);
458
318
if (threadIdx.x == 0 ) {
459
319
const int idx = k * block_row + k_idx;
460
- // output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
461
320
indices[idx] = should_process_row ? result_kvp.key : num_experts;
462
321
source_rows[idx] = k_idx * num_rows + block_row;
463
322
464
- T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key ]: result_kvp.value ;
465
- row_outputs[k_idx] = row_out;
466
- weight_sum += row_out;
323
+ if constexpr (NormWeights){
324
+ T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key ]: result_kvp.value ;
325
+ row_outputs[k_idx] = row_out;
326
+ weight_sum += row_out;
327
+ }
328
+ else {
329
+ output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key ]: result_kvp.value ;
330
+ }
467
331
}
468
332
__syncthreads ();
469
333
}
470
- if (threadIdx.x < WARP_SIZE) {
471
- weight_sum = __shfl_sync (0xffffffff , weight_sum, 0 );
472
- }
473
-
474
- if (threadIdx.x < k) {
475
- output[k * block_row + threadIdx.x ] = row_outputs[threadIdx.x ] / weight_sum;
334
+ if constexpr (NormWeights){
335
+ if (threadIdx.x < WARP_SIZE) {
336
+ weight_sum = __shfl_sync (0xffffffff , weight_sum, 0 );
337
+ }
338
+ if (threadIdx.x < k) {
339
+ output[k * block_row + threadIdx.x ] = row_outputs[threadIdx.x ] / weight_sum;
340
+ }
476
341
}
477
342
}
478
343
479
-
480
- template <typename T, int TPB, typename IdxT = int >
481
- __launch_bounds__ (TPB) __global__ void moe_softmax_top_k_normed_fused (const T* input,
344
+ template <typename T, int TPB, bool NormWeights = false , typename IdxT = int >
345
+ __launch_bounds__ (TPB) __global__ void moe_softmax_top_k_fused (const T* input,
482
346
const T* bias,
483
347
T* output,
484
348
IdxT* indices,
@@ -532,8 +396,11 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
532
396
cub::ArgMax arg_max;
533
397
534
398
T weight_sum = static_cast <T>(0 );
535
- extern __shared__ char smem[];
536
- T* row_outputs = reinterpret_cast <T*>(smem);
399
+ T* row_outputs = nullptr ;
400
+ if constexpr (NormWeights){
401
+ extern __shared__ char smem[];
402
+ row_outputs = reinterpret_cast <T*>(smem);
403
+ }
537
404
538
405
for (int k_idx = 0 ; k_idx < k; ++k_idx) {
539
406
thread_kvp.key = 0 ;
@@ -560,22 +427,28 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
560
427
if (threadIdx.x == 0 ) {
561
428
const int cur_idx = k * globalIdx + k_idx;
562
429
563
- T row_out = bias ? (result_kvp.value - bias[result_kvp.key ]) : result_kvp.value ;
564
- row_outputs[k_idx] = row_out;
565
- weight_sum += row_out;
566
-
567
430
indices[cur_idx] = result_kvp.key ;
568
431
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
432
+
433
+ if constexpr (NormWeights) {
434
+ T row_out = bias ? (result_kvp.value - bias[result_kvp.key ]) : result_kvp.value ;
435
+ row_outputs[k_idx] = row_out;
436
+ weight_sum += row_out;
437
+ }
438
+ else {
439
+ output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key ]) : result_kvp.value ;
440
+ }
569
441
}
570
442
__syncthreads ();
571
443
}
444
+ if constexpr (NormWeights) {
445
+ if (threadIdx.x < WARP_SIZE) {
446
+ weight_sum = __shfl_sync (0xffffffff , weight_sum, 0 );
447
+ }
572
448
573
- if (threadIdx.x < WARP_SIZE) {
574
- weight_sum = __shfl_sync (0xffffffff , weight_sum, 0 );
575
- }
576
-
577
- if (threadIdx.x < k) {
578
- output[k * globalIdx + threadIdx.x ] = row_outputs[threadIdx.x ] / weight_sum;
449
+ if (threadIdx.x < k) {
450
+ output[k * globalIdx + threadIdx.x ] = row_outputs[threadIdx.x ] / weight_sum;
451
+ }
579
452
}
580
453
}
581
454
@@ -1015,7 +888,7 @@ static void run(const T* input,
1015
888
group_experts,
1016
889
softmax_num_rows);
1017
890
const auto config_topk = Get1DBlocksAnd2DGridsMoe (num_rows);
1018
- moe_top_k <T, TPB>
891
+ group_moe_top_k <T, TPB>
1019
892
<<<config_topk.block_per_grid , TPB, 0 , stream>>>(softmax,
1020
893
output,
1021
894
indices,
0 commit comments