Skip to content

Commit 8224b21

Browse files
authored
Refactor moe_topk_select op to use apply_norm_weight as a template parameter (#3345)
* Refactor moe_topk_select op to use apply_norm_weight as a template parameter * update test
1 parent eda83ca commit 8224b21

File tree

4 files changed

+145
-184
lines changed

4 files changed

+145
-184
lines changed

custom_ops/gpu_ops/moe/fused_moe_op.h

Lines changed: 51 additions & 178 deletions
Original file line number | Diff line number | Diff line change
@@ -150,64 +150,6 @@ __launch_bounds__(TPB) __global__
150150
}
151151
}
152152

153-
// Top-k expert selection over softmax scores, with per-slot renormalization by
// a precomputed per-(row, k) maximum probability (`softmax_max_prob`).
//
// Launch contract (as reconstructed from the call site in this diff):
//   - one block per token row, 2D grid flattened as blockIdx.x + blockIdx.y * gridDim.x;
//   - TPB threads per block; experts are strided across threads;
//   - no dynamic shared memory.
//
// Outputs, all indexed by [row * k + k_idx]:
//   output      — winning probability divided by softmax_max_prob[idx]
//                 ("restore normalized probes" in the original comment);
//   indices     — winning expert id;
//   source_rows — k_idx * num_rows + row (scatter index for later permutation).
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
                                                 T* output,
                                                 IdxT* indices,
                                                 int* source_rows,
                                                 T* softmax_max_prob,
                                                 const int64_t num_experts,
                                                 const int64_t k,
                                                 const int64_t num_rows) {
  using cub_kvp = cub::KeyValuePair<int, T>;
  using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  cub_kvp thread_kvp;
  cub::ArgMax arg_max;

  // Flatten the 2D grid into a row index; tail blocks past num_rows bail out.
  const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
  if (block_row >= num_rows) {
    return;
  }

  const bool should_process_row = true;
  const int thread_read_offset = block_row * num_experts;

  // Select the k winners one at a time: each pass argmaxes over all experts
  // while masking out experts that already won an earlier pass.
  for (int k_idx = 0; k_idx < k; ++k_idx) {
    thread_kvp.key = 0;
    thread_kvp.value = T(-1.f);  // This is OK because inputs are probabilities

    cub_kvp inp_kvp;
    for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
      const int idx = thread_read_offset + expert;
      inp_kvp.key = expert;
      inp_kvp.value = inputs_after_softmax[idx];

      // Exclude experts chosen in prior passes by replacing their candidate
      // with the current (losing) running value.
      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
        const IdxT prior_winning_expert = indices[k * block_row + prior_k];
        if (prior_winning_expert == expert) {
          inp_kvp = thread_kvp;
        }
      }

      thread_kvp = arg_max(inp_kvp, thread_kvp);
    }

    const cub_kvp result_kvp =
        BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
    if (threadIdx.x == 0) {
      const int idx = k * block_row + k_idx;
      // restore normalized probes
      output[idx] = result_kvp.value / T(softmax_max_prob[idx]);
      indices[idx] = should_process_row ? result_kvp.key : num_experts;
      source_rows[idx] = k_idx * num_rows + block_row;
    }
    // All threads must see the freshly written indices[] before the next pass
    // reads them in the prior-winner exclusion loop.
    __syncthreads();
  }
}
210-
211153
template <typename T, int TPB>
212154
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
213155
T* output,
@@ -262,11 +204,11 @@ __launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
262204
}
263205

264206
template <typename T, int TPB, typename IdxT = int>
265-
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
266-
const T* bias,
207+
__launch_bounds__(TPB) __global__ void group_moe_top_k(const T* inputs_after_softmax,
267208
T* output,
268209
IdxT* indices,
269210
int* source_rows,
211+
T* softmax_max_prob,
270212
const int64_t num_experts,
271213
const int64_t k,
272214
const int64_t num_rows) {
@@ -293,7 +235,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
293235
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
294236
const int idx = thread_read_offset + expert;
295237
inp_kvp.key = expert;
296-
inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ;
238+
inp_kvp.value = inputs_after_softmax[idx];
297239

298240
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
299241
const IdxT prior_winning_expert = indices[k * block_row + prior_k];
@@ -310,101 +252,17 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
310252
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
311253
if (threadIdx.x == 0) {
312254
const int idx = k * block_row + k_idx;
313-
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
255+
// restore normalized probes
256+
output[idx] = result_kvp.value / T(softmax_max_prob[idx]);
314257
indices[idx] = should_process_row ? result_kvp.key : num_experts;
315258
source_rows[idx] = k_idx * num_rows + block_row;
316259
}
317260
__syncthreads();
318261
}
319262
}
320263

321-
template <typename T, int TPB, typename IdxT = int>
322-
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
323-
const T* bias,
324-
T* output,
325-
IdxT* indices,
326-
int* source_rows,
327-
const int64_t num_experts,
328-
const int64_t k,
329-
const int64_t num_rows) {
330-
// softmax
331-
using BlockReduce = cub::BlockReduce<float, TPB>;
332-
__shared__ typename BlockReduce::TempStorage tmpStorage;
333-
334-
__shared__ float normalizing_factor;
335-
__shared__ float float_max;
336-
337-
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
338-
if (globalIdx >= num_rows) {
339-
return;
340-
}
341-
const int64_t thread_row_offset = globalIdx * num_experts;
342-
const int64_t idx = thread_row_offset+threadIdx.x;
343-
344-
cub::Sum sum;
345-
346-
float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) :(-FLT_MAX);
347-
348-
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
349-
if (threadIdx.x == 0) {
350-
float_max = maxElem;
351-
}
352-
__syncthreads();
353-
354-
float threadDataSub = threadData - float_max;
355-
float threadDataExp = exp(threadDataSub);
356-
357-
const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum);
358-
359-
if (threadIdx.x == 0) {
360-
normalizing_factor = 1.f / Z;
361-
}
362-
__syncthreads();
363-
364-
T val = T(threadDataExp * normalizing_factor);
365-
366-
// top_k
367-
using cub_kvp = cub::KeyValuePair<int, T>;
368-
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
369-
__shared__ typename BlockReduceP::TempStorage tmpStorageP;
370-
371-
cub_kvp thread_kvp;
372-
cub::ArgMax arg_max;
373-
374-
for (int k_idx = 0; k_idx < k; ++k_idx) {
375-
thread_kvp.key = 0;
376-
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
377-
378-
if (threadIdx.x < num_experts) {
379-
cub_kvp inp_kvp;
380-
int expert = threadIdx.x;
381-
inp_kvp.key = expert;
382-
inp_kvp.value = bias ? val + bias[expert] : val;
383-
384-
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
385-
const IdxT prior_winning_expert = indices[k * globalIdx + prior_k];
386-
387-
if (prior_winning_expert == expert) {
388-
inp_kvp = thread_kvp;
389-
}
390-
}
391-
thread_kvp = arg_max(inp_kvp, thread_kvp);
392-
}
393-
394-
const cub_kvp result_kvp =
395-
BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
396-
if (threadIdx.x == 0) {
397-
const int cur_idx = k * globalIdx + k_idx;
398-
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
399-
indices[cur_idx] = result_kvp.key;
400-
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
401-
}
402-
__syncthreads();
403-
}
404-
}
405-
406-
template <typename T, int TPB, typename IdxT = int>
407-
__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax,
264+
template <typename T, int TPB, bool NormWeights = false, typename IdxT = int>
265+
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
408266
const T* bias,
409267
T* output,
410268
IdxT* indices,
@@ -427,10 +285,12 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so
427285
const bool should_process_row = true;
428286
const int thread_read_offset = block_row * num_experts;
429287
T weight_sum = static_cast<T>(0);
288+
T* row_outputs = nullptr;
430289

431-
extern __shared__ char smem[];
432-
433-
T* row_outputs = reinterpret_cast<T*>(smem);
290+
if constexpr (NormWeights){
291+
extern __shared__ char smem[];
292+
row_outputs = reinterpret_cast<T*>(smem);
293+
}
434294

435295
for (int k_idx = 0; k_idx < k; ++k_idx) {
436296
thread_kvp.key = 0;
@@ -457,28 +317,32 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so
457317
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
458318
if (threadIdx.x == 0) {
459319
const int idx = k * block_row + k_idx;
460-
// output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
461320
indices[idx] = should_process_row ? result_kvp.key : num_experts;
462321
source_rows[idx] = k_idx * num_rows + block_row;
463322

464-
T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
465-
row_outputs[k_idx] = row_out;
466-
weight_sum += row_out;
323+
if constexpr (NormWeights){
324+
T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
325+
row_outputs[k_idx] = row_out;
326+
weight_sum += row_out;
327+
}
328+
else{
329+
output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value;
330+
}
467331
}
468332
__syncthreads();
469333
}
470-
if (threadIdx.x < WARP_SIZE) {
471-
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
472-
}
473-
474-
if (threadIdx.x < k) {
475-
output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
334+
if constexpr (NormWeights){
335+
if (threadIdx.x < WARP_SIZE) {
336+
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
337+
}
338+
if (threadIdx.x < k) {
339+
output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
340+
}
476341
}
477342
}
478343

479-
480-
template <typename T, int TPB, typename IdxT = int>
481-
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input,
344+
template <typename T, int TPB, bool NormWeights = false, typename IdxT = int>
345+
__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
482346
const T* bias,
483347
T* output,
484348
IdxT* indices,
@@ -532,8 +396,11 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
532396
cub::ArgMax arg_max;
533397

534398
T weight_sum = static_cast<T>(0);
535-
extern __shared__ char smem[];
536-
T* row_outputs = reinterpret_cast<T*>(smem);
399+
T* row_outputs = nullptr;
400+
if constexpr (NormWeights){
401+
extern __shared__ char smem[];
402+
row_outputs = reinterpret_cast<T*>(smem);
403+
}
537404

538405
for (int k_idx = 0; k_idx < k; ++k_idx) {
539406
thread_kvp.key = 0;
@@ -560,22 +427,28 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
560427
if (threadIdx.x == 0) {
561428
const int cur_idx = k * globalIdx + k_idx;
562429

563-
T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
564-
row_outputs[k_idx] = row_out;
565-
weight_sum += row_out;
566-
567430
indices[cur_idx] = result_kvp.key;
568431
source_rows[cur_idx] = k_idx * num_rows + globalIdx;
432+
433+
if constexpr (NormWeights) {
434+
T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
435+
row_outputs[k_idx] = row_out;
436+
weight_sum += row_out;
437+
}
438+
else {
439+
output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
440+
}
569441
}
570442
__syncthreads();
571443
}
444+
if constexpr (NormWeights) {
445+
if (threadIdx.x < WARP_SIZE) {
446+
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
447+
}
572448

573-
if (threadIdx.x < WARP_SIZE) {
574-
weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
575-
}
576-
577-
if (threadIdx.x < k) {
578-
output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
449+
if (threadIdx.x < k) {
450+
output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
451+
}
579452
}
580453
}
581454

@@ -1015,7 +888,7 @@ static void run(const T* input,
1015888
group_experts,
1016889
softmax_num_rows);
1017890
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
1018-
moe_top_k<T, TPB>
891+
group_moe_top_k<T, TPB>
1019892
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
1020893
output,
1021894
indices,

custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ void moe_redundant_topk_select_kernel(const T* input,
102102
else {
103103
assert(k<=TPB);
104104
if (apply_norm_weight) {
105-
moe_softmax_top_k_normed_fused<T, TPB>
105+
moe_softmax_top_k_fused<T, TPB, true>
106106
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(input,
107107
bias,
108108
output,
@@ -112,7 +112,7 @@ void moe_redundant_topk_select_kernel(const T* input,
112112
k,
113113
num_rows);
114114
} else {
115-
moe_softmax_top_k_fused<T, TPB>
115+
moe_softmax_top_k_fused<T, TPB, false>
116116
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
117117
bias,
118118
output,

custom_ops/gpu_ops/moe/moe_topk_select.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ void moe_topk_select_kernel(const T* input,
6868
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
6969
input, softmax, num_experts, num_rows);
7070
if (apply_norm_weight) {
71-
moe_top_k_normed<T, TPB>
71+
moe_top_k<T, TPB, true>
7272
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(softmax,
7373
bias,
7474
output,
@@ -78,7 +78,7 @@ void moe_topk_select_kernel(const T* input,
7878
k,
7979
num_rows);
8080
} else {
81-
moe_top_k<T, TPB>
81+
moe_top_k<T, TPB, false>
8282
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
8383
bias,
8484
output,
@@ -93,7 +93,7 @@ void moe_topk_select_kernel(const T* input,
9393
else {
9494
assert(k<=TPB);
9595
if (apply_norm_weight) {
96-
moe_softmax_top_k_normed_fused<T, TPB>
96+
moe_softmax_top_k_fused<T, TPB, true>
9797
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(input,
9898
bias,
9999
output,
@@ -103,7 +103,7 @@ void moe_topk_select_kernel(const T* input,
103103
k,
104104
num_rows);
105105
} else {
106-
moe_softmax_top_k_fused<T, TPB>
106+
moe_softmax_top_k_fused<T, TPB, false>
107107
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
108108
bias,
109109
output,

0 commit comments

Comments (0)