@@ -58,6 +58,7 @@ __global__ void tokens_unzip_stable_kernel(
58
58
local_expert_offsets[i] = expert_base_offset.data [i];
59
59
local_cumsum[i] = 0 ;
60
60
}
61
+ const int base_row_idx = blockIdx .x * CUMSUM_BLOCK_SIZE;
61
62
__shared__ int shared_expert_rowmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
62
63
__shared__ probs_T shared_expert_probmap[CUMSUM_BLOCK_SIZE][MAX_NUM_EXPERTS];
63
64
@@ -238,13 +239,7 @@ void MoePermuteKernel(const Context &dev_ctx,
238
239
" value." ,
239
240
MAX_NUM_EXPERTS,
240
241
num_experts));
241
- if (X.numel () == 0 ) {
242
- dev_ctx.template Alloc <float >(XScale_unzipped);
243
- dev_ctx.template Alloc <int >(zipped_expertwise_rowmap);
244
- dev_ctx.template Alloc <T>(X_unzipped);
245
- dev_ctx.template Alloc <float >(token_prob_unzipped);
246
- return ;
247
- }
242
+
248
243
const int quanted_cols = (XScale) ? XScale.get_ptr ()->dims ()[1 ] : 0 ;
249
244
expert_base_offset expert_offset;
250
245
int tokens_cumulated = 0 ;
@@ -270,46 +265,26 @@ void MoePermuteKernel(const Context &dev_ctx,
270
265
dev_ctx.template Alloc <float >(XScale_unzipped);
271
266
dev_ctx.template Alloc <int >(zipped_expertwise_rowmap);
272
267
dev_ctx.template Alloc <T>(X_unzipped);
268
+ dev_ctx.template Alloc <float >(token_prob_unzipped);
273
269
auto X_unzipped_ptr = reinterpret_cast <void *>(X_unzipped->data <T>());
274
- for (int i = 0 ; i < num_experts; i++) {
275
- int next_expert_offset =
276
- i < num_experts - 1 ? expert_offset.data [i + 1 ] : output_rows;
277
- int invalid_rows =
278
- next_expert_offset - expert_offset.data [i] - tokens_per_expert[i];
279
- cudaMemsetAsync (X_unzipped_ptr + tokens_per_expert[i] * sizeof (T),
280
- 0 ,
281
- sizeof (T) * invalid_rows * cols,
282
- dev_ctx.stream ());
283
- }
270
+ cudaMemsetAsync (
271
+ X_unzipped_ptr, 0 , sizeof (T) * output_rows * cols, dev_ctx.stream ());
284
272
if (XScale) {
285
273
auto XScale_unzipped_ptr =
286
274
reinterpret_cast <void *>(XScale_unzipped->data <float >());
287
- for (int i = 0 ; i < num_experts; i++) {
288
- int next_expert_offset =
289
- i < num_experts - 1 ? expert_offset.data [i + 1 ] : output_rows;
290
- int invalid_rows =
291
- next_expert_offset - expert_offset.data [i] - tokens_per_expert[i];
292
- cudaMemsetAsync (
293
- XScale_unzipped_ptr + tokens_per_expert[i] * sizeof (float ),
294
- 0 ,
295
- sizeof (float ) * invalid_rows * quanted_cols,
296
- dev_ctx.stream ());
297
- }
275
+ cudaMemsetAsync (XScale_unzipped_ptr,
276
+ 0 ,
277
+ sizeof (float ) * output_rows * quanted_cols,
278
+ dev_ctx.stream ());
298
279
}
299
- dev_ctx. template Alloc < float >(token_prob_unzipped);
280
+
300
281
auto token_prob_unzipped_ptr =
301
282
reinterpret_cast <void *>(token_prob_unzipped->data <float >());
302
- for (int i = 0 ; i < num_experts; i++) {
303
- int next_expert_offset =
304
- i < num_experts - 1 ? expert_offset.data [i + 1 ] : output_rows;
305
- int invalid_rows =
306
- next_expert_offset - expert_offset.data [i] - tokens_per_expert[i];
307
- cudaMemsetAsync (
308
- token_prob_unzipped_ptr + tokens_per_expert[i] * sizeof (float ),
309
- 0 ,
310
- sizeof (float ) * invalid_rows,
311
- dev_ctx.stream ());
312
- }
283
+ cudaMemsetAsync (token_prob_unzipped_ptr,
284
+ 0 ,
285
+ sizeof (float ) * output_rows,
286
+ dev_ctx.stream ());
287
+ if (X.numel () == 0 ) return ;
313
288
const int cumsum_blocknum =
314
289
(rows + CUMSUM_BLOCK_SIZE - 1 ) / CUMSUM_BLOCK_SIZE;
315
290
DenseTensor global_expertwise_block_cumsum =
0 commit comments