@@ -43,7 +43,11 @@ struct expert_infos {
43
43
}
44
44
};
45
45
46
- template <typename X_T, typename routemap_T, typename probs_T, bool has_scale>
46
+ template <typename X_T,
47
+ typename routemap_T,
48
+ typename probs_T,
49
+ bool has_scale,
50
+ bool do_gather>
47
51
__global__ __launch_bounds__ (512 ) void tokens_unzip_stable_kernel(
48
52
const X_T *__restrict__ X,
49
53
const routemap_T *__restrict__ routemap_topk,
@@ -130,17 +134,19 @@ __global__ __launch_bounds__(512) void tokens_unzip_stable_kernel(
130
134
if (proposed_row_idx == -1 ) continue ; // no memcpy
131
135
if (threadIdx .x == 0 )
132
136
probs_unzipped[proposed_row_idx] = this_expert_token_info.expert_probs ;
133
- // vec copy
134
- if constexpr (has_scale) {
137
+ if constexpr (do_gather) {
138
+ // vec copy
139
+ if constexpr (has_scale) {
140
+ vectorized_memcpy (&XScale[(int64_t )row * (int64_t )scale_length],
141
+ &XScale_unzipped[(int64_t )proposed_row_idx *
142
+ (int64_t )scale_length],
143
+ scale_length);
144
+ }
135
145
vectorized_memcpy (
136
- &XScale [(int64_t )row * (int64_t )scale_length ],
137
- &XScale_unzipped [(int64_t )proposed_row_idx * (int64_t )scale_length ],
138
- scale_length );
146
+ &X [(int64_t )row * (int64_t )token_length ],
147
+ &X_unzipped [(int64_t )proposed_row_idx * (int64_t )token_length ],
148
+ token_length );
139
149
}
140
- vectorized_memcpy (
141
- &X[(int64_t )row * (int64_t )token_length],
142
- &X_unzipped[(int64_t )proposed_row_idx * (int64_t )token_length],
143
- token_length);
144
150
}
145
151
}
146
152
}
@@ -160,7 +166,8 @@ void dispatch_tokens_unzip_stable(const Context &dev_ctx,
160
166
const int token_length,
161
167
const int topk, // deprecated
162
168
const int num_experts,
163
- const int scale_length) {
169
+ const int scale_length,
170
+ const bool do_gather) {
164
171
dim3 grid, block;
165
172
grid.x =
166
173
(total_zipped_tokens_num + CUMSUM_BLOCK_SIZE - 1 ) / CUMSUM_BLOCK_SIZE;
@@ -169,33 +176,41 @@ void dispatch_tokens_unzip_stable(const Context &dev_ctx,
169
176
// Shorthand helpers used by the dispatch macros defined after them.
//
// DTYPE_CASE expands to an equality comparison between the expression `dtype`
// and the phi::DataType enumerator whose name is `type`.
//
// NOTE(review): as rendered here there is a space between each macro name and
// its '('. In real preprocessor input that would define an object-like macro
// rather than a function-like one. This is presumably an artifact of how this
// text was extracted -- confirm against the actual file.
#define DTYPE_CASE (dtype, type ) dtype == phi::DataType::type
170
177
// GET_DATA calls the data<type>() member on `tensor` using '.' access.
#define GET_DATA (tensor, type ) tensor.data<type>()
171
178
// GET_PTR_DATA calls the data<type>() member on `tensor` using '->' access,
// i.e. for tensors that are held through a pointer(-like) handle.
#define GET_PTR_DATA (tensor, type ) tensor->data<type>()
172
- #define DISPATCH_CASE (TOKEN_T, PROB_T, INT_T, HAS_SCALE ) \
173
- auto kernel = tokens_unzip_stable_kernel<TOKEN_T, INT_T, PROB_T, HAS_SCALE>; \
174
- kernel<<<grid, block, 0 , dev_ctx.stream()>>> ( \
175
- GET_DATA (X, TOKEN_T), \
176
- GET_DATA (expert_routemap_topk, INT_T), \
177
- GET_DATA (expert_prob_topk, PROB_T), \
178
- XScale ? XScale.get_ptr ()->data <float >() : nullptr , \
179
- GET_DATA (expert_offsets, int ), \
180
- GET_PTR_DATA (X_unzipped, TOKEN_T), \
181
- GET_PTR_DATA (zipped_expertwise_rowmap, INT_T), \
182
- GET_PTR_DATA (token_prob_unzipped, PROB_T), \
183
- XScale_unzipped->data <float >(), \
184
- global_expertwise_block_cumsum->data <int >(), \
185
- total_zipped_tokens_num, \
186
- token_length, \
187
- scale_length, \
188
- num_experts, \
179
// DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, DO_GATHER)
//
// Instantiates tokens_unzip_stable_kernel with the template arguments in the
// order <TOKEN_T, INT_T, PROB_T, HAS_SCALE, DO_GATHER> (note that INT_T comes
// before PROB_T in the instantiation although PROB_T comes first in the macro
// parameter list) and launches it with the surrounding `grid` / `block`
// dimensions, 0 bytes of dynamic shared memory, on dev_ctx.stream().
//
// The XScale input pointer is passed as nullptr when `XScale` tests false.
// XScale_unzipped->data<float>() is evaluated unconditionally, whatever
// HAS_SCALE is -- assumes XScale_unzipped is always a valid tensor pointer
// here; TODO confirm against the caller.
//
// The expansion declares a local named `kernel`, so each use has to live in
// its own brace scope.
//
// NOTE(review): no launch-error check (e.g. cudaGetLastError) is visible after
// the <<<...>>> launch inside this macro -- confirm whether one exists at the
// call site.
+ #define DISPATCH_CASE (TOKEN_T, PROB_T, INT_T, HAS_SCALE, DO_GATHER ) \
180
+ auto kernel = tokens_unzip_stable_kernel<TOKEN_T, \
181
+ INT_T, \
182
+ PROB_T, \
183
+ HAS_SCALE, \
184
+ DO_GATHER>; \
185
+ kernel<<<grid, block, 0 , dev_ctx.stream()>>> ( \
186
+ GET_DATA (X, TOKEN_T), \
187
+ GET_DATA (expert_routemap_topk, INT_T), \
188
+ GET_DATA (expert_prob_topk, PROB_T), \
189
+ XScale ? XScale.get_ptr ()->data <float >() : nullptr , \
190
+ GET_DATA (expert_offsets, int ), \
191
+ GET_PTR_DATA (X_unzipped, TOKEN_T), \
192
+ GET_PTR_DATA (zipped_expertwise_rowmap, INT_T), \
193
+ GET_PTR_DATA (token_prob_unzipped, PROB_T), \
194
+ XScale_unzipped->data <float >(), \
195
+ global_expertwise_block_cumsum->data <int >(), \
196
+ total_zipped_tokens_num, \
197
+ token_length, \
198
+ scale_length, \
199
+ num_experts, \
189
200
topk);
190
201
191
- #define HANDLE_EXPERT_CASE (TOKEN_T, PROB_T, INT_T, HAS_SCALE ) \
192
- DISPATCH_CASE (TOKEN_T, PROB_T, INT_T, HAS_SCALE)
202
// HANDLE_GATHER_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE)
//
// Turns the runtime flag `do_gather` (read from the enclosing scope) into the
// compile-time DO_GATHER argument of DISPATCH_CASE by expanding DISPATCH_CASE
// once with `true` and once with `false`.
// Each expansion sits inside its own { } so the local that DISPATCH_CASE
// declares does not collide between the two branches.
+ #define HANDLE_GATHER_CASE (TOKEN_T, PROB_T, INT_T, HAS_SCALE ) \
203
+ if (do_gather) { \
204
+ DISPATCH_CASE (TOKEN_T, PROB_T, INT_T, HAS_SCALE, true ) \
205
+ } else { \
206
+ DISPATCH_CASE (TOKEN_T, PROB_T, INT_T, HAS_SCALE, false ) \
207
+ }
193
208
194
209
// HANDLE_TOKEN_TYPE(PROB_T, INT_T)
//
// Picks the token element type from X.dtype():
//   BFLOAT16       -> phi::bfloat16,      HAS_SCALE = false
//   FLOAT8_E4M3FN  -> phi::float8_e4m3fn, HAS_SCALE = true
// There is no trailing else branch, so for any other dtype this macro expands
// to code that launches nothing.
//
// Lines marked '-' are the previous revision (forwarding to
// HANDLE_EXPERT_CASE); lines marked '+' are the new revision, which forwards
// to HANDLE_GATHER_CASE.
#define HANDLE_TOKEN_TYPE (PROB_T, INT_T ) \
195
210
if (DTYPE_CASE (X.dtype (), BFLOAT16)) { \
196
- HANDLE_EXPERT_CASE (phi::bfloat16, PROB_T, INT_T, false ) \
211
+ HANDLE_GATHER_CASE (phi::bfloat16, PROB_T, INT_T, false ) \
197
212
} else if (DTYPE_CASE (X.dtype (), FLOAT8_E4M3FN)) { \
198
- HANDLE_EXPERT_CASE (phi::float8_e4m3fn, PROB_T, INT_T, true ) \
213
+ HANDLE_GATHER_CASE (phi::float8_e4m3fn, PROB_T, INT_T, true ) \
199
214
}
200
215
201
216
#define HANDLE_PROB_TYPE (INT_T ) \
@@ -226,6 +241,7 @@ void MoePermuteKernel(const Context &dev_ctx,
226
241
const int num_experts,
227
242
const std::vector<int > &tokens_per_expert,
228
243
const int padding_multiplex,
244
+ const bool do_gather,
229
245
DenseTensor *X_unzipped,
230
246
DenseTensor *zipped_expertwise_rowmap,
231
247
DenseTensor *token_prob_unzipped,
@@ -341,7 +357,8 @@ void MoePermuteKernel(const Context &dev_ctx,
341
357
cols,
342
358
topk_calculated,
343
359
num_experts,
344
- quanted_cols);
360
+ quanted_cols,
361
+ do_gather);
345
362
}
346
363
#undef CUMSUM_BLOCK_SIZE
347
364
#undef CUMSUM_INVALID_TAG
0 commit comments