Skip to content

Commit 8a50650

Browse files
authored
[BugFix] Fix EP MoE prefill function (#4101)
1 parent 1aab1c8 commit 8a50650

File tree

1 file changed

+63
-129
lines changed

1 file changed

+63
-129
lines changed

custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu

Lines changed: 63 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -453,137 +453,71 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
453453
auto place = input.place();
454454
const int gridx = min(132 * 8, num_rows);
455455
if (moe_quant_type == "w4a8") {
456-
if (num_experts_per_rank == 8) {
457-
permute_x_kernel<data_t, int8_t, 8><<<gridx, 512, 0, stream>>>(
458-
input.data<data_t>(),
459-
topk_ids.data<int64_t>(),
460-
topk_weights.data<float>(),
461-
token_nums_per_expert.data<int>(),
462-
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
463-
moe_topk,
464-
num_rows,
465-
token_nums_this_rank,
466-
hidden_size,
467-
permute_input->data<int8_t>(),
468-
permute_indices_per_token->data<int>(),
469-
dst_weights->data<float>(),
470-
dst_indices->data<int>(),
471-
cumsum_idx_gpu->data<int>(),
472-
token_nums_per_expert_cumsum->data<int64_t>(),
473-
expert_idx_per_token->data<int64_t>(),
474-
127.0,
475-
-127.0
476-
);
477-
} else if (num_experts_per_rank == 16) {
478-
permute_x_kernel<data_t, int8_t, 16><<<gridx, 512, 0, stream>>>(
479-
input.data<data_t>(),
480-
topk_ids.data<int64_t>(),
481-
topk_weights.data<float>(),
482-
token_nums_per_expert.data<int>(),
483-
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
484-
moe_topk,
485-
num_rows,
486-
token_nums_this_rank,
487-
hidden_size,
488-
permute_input->data<int8_t>(),
489-
permute_indices_per_token->data<int>(),
490-
dst_weights->data<float>(),
491-
dst_indices->data<int>(),
492-
cumsum_idx_gpu->data<int>(),
493-
token_nums_per_expert_cumsum->data<int64_t>(),
494-
expert_idx_per_token->data<int64_t>(),
495-
127.0,
496-
-127.0
497-
);
498-
}
456+
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
457+
permute_x_kernel<data_t, int8_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
458+
input.data<data_t>(),
459+
topk_ids.data<int64_t>(),
460+
topk_weights.data<float>(),
461+
token_nums_per_expert.data<int>(),
462+
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
463+
moe_topk,
464+
num_rows,
465+
token_nums_this_rank,
466+
hidden_size,
467+
permute_input->data<int8_t>(),
468+
permute_indices_per_token->data<int>(),
469+
dst_weights->data<float>(),
470+
dst_indices->data<int>(),
471+
cumsum_idx_gpu->data<int>(),
472+
token_nums_per_expert_cumsum->data<int64_t>(),
473+
expert_idx_per_token->data<int64_t>(),
474+
127.0,
475+
-127.0
476+
);)
499477
} else if (moe_quant_type == "w4afp8") {
500-
if (num_experts_per_rank == 8) {
501-
permute_x_kernel<data_t, data_t_fp8, 8, 512><<<gridx, 512, 0, stream>>>(
502-
input.data<data_t>(),
503-
topk_ids.data<int64_t>(),
504-
topk_weights.data<float>(),
505-
token_nums_per_expert.data<int>(),
506-
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
507-
moe_topk,
508-
num_rows,
509-
token_nums_this_rank,
510-
hidden_size,
511-
permute_input->data<data_t_fp8>(),
512-
permute_indices_per_token->data<int>(),
513-
dst_weights->data<float>(),
514-
dst_indices->data<int>(),
515-
cumsum_idx_gpu->data<int>(),
516-
token_nums_per_expert_cumsum->data<int64_t>(),
517-
expert_idx_per_token->data<int64_t>(),
518-
448.0f,
519-
-448.0f
520-
);
521-
} else if (num_experts_per_rank == 16) {
522-
permute_x_kernel<data_t, data_t_fp8, 16, 512><<<gridx, 512, 0, stream>>>(
523-
input.data<data_t>(),
524-
topk_ids.data<int64_t>(),
525-
topk_weights.data<float>(),
526-
token_nums_per_expert.data<int>(),
527-
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
528-
moe_topk,
529-
num_rows,
530-
token_nums_this_rank,
531-
hidden_size,
532-
permute_input->data<data_t_fp8>(),
533-
permute_indices_per_token->data<int>(),
534-
dst_weights->data<float>(),
535-
dst_indices->data<int>(),
536-
cumsum_idx_gpu->data<int>(),
537-
token_nums_per_expert_cumsum->data<int64_t>(),
538-
expert_idx_per_token->data<int64_t>(),
539-
448.0f,
540-
-448.0f
541-
);
542-
}
478+
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
479+
permute_x_kernel<data_t, data_t_fp8, NUM_EXPERTS_PER_RANK, 512><<<gridx, 512, 0, stream>>>(
480+
input.data<data_t>(),
481+
topk_ids.data<int64_t>(),
482+
topk_weights.data<float>(),
483+
token_nums_per_expert.data<int>(),
484+
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
485+
moe_topk,
486+
num_rows,
487+
token_nums_this_rank,
488+
hidden_size,
489+
permute_input->data<data_t_fp8>(),
490+
permute_indices_per_token->data<int>(),
491+
dst_weights->data<float>(),
492+
dst_indices->data<int>(),
493+
cumsum_idx_gpu->data<int>(),
494+
token_nums_per_expert_cumsum->data<int64_t>(),
495+
expert_idx_per_token->data<int64_t>(),
496+
448.0f,
497+
-448.0f
498+
);)
543499
} else {
544-
if (num_experts_per_rank == 8) {
545-
permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
546-
input.data<data_t>(),
547-
topk_ids.data<int64_t>(),
548-
topk_weights.data<float>(),
549-
token_nums_per_expert.data<int>(),
550-
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
551-
moe_topk,
552-
num_rows,
553-
token_nums_this_rank,
554-
hidden_size,
555-
permute_input->data<data_t>(),
556-
permute_indices_per_token->data<int>(),
557-
dst_weights->data<float>(),
558-
dst_indices->data<int>(),
559-
cumsum_idx_gpu->data<int>(),
560-
token_nums_per_expert_cumsum->data<int64_t>(),
561-
expert_idx_per_token->data<int64_t>(),
562-
127.0,
563-
-127.0
564-
);
565-
} else if (num_experts_per_rank == 16) {
566-
permute_x_kernel<data_t, data_t, 16><<<gridx, 512, 0, stream>>>(
567-
input.data<data_t>(),
568-
topk_ids.data<int64_t>(),
569-
topk_weights.data<float>(),
570-
token_nums_per_expert.data<int>(),
571-
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
572-
moe_topk,
573-
num_rows,
574-
token_nums_this_rank,
575-
hidden_size,
576-
permute_input->data<data_t>(),
577-
permute_indices_per_token->data<int>(),
578-
dst_weights->data<float>(),
579-
dst_indices->data<int>(),
580-
cumsum_idx_gpu->data<int>(),
581-
token_nums_per_expert_cumsum->data<int64_t>(),
582-
expert_idx_per_token->data<int64_t>(),
583-
127.0,
584-
-127.0
585-
);
586-
}
500+
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
501+
permute_x_kernel<data_t, data_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
502+
input.data<data_t>(),
503+
topk_ids.data<int64_t>(),
504+
topk_weights.data<float>(),
505+
token_nums_per_expert.data<int>(),
506+
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
507+
moe_topk,
508+
num_rows,
509+
token_nums_this_rank,
510+
hidden_size,
511+
permute_input->data<data_t>(),
512+
permute_indices_per_token->data<int>(),
513+
dst_weights->data<float>(),
514+
dst_indices->data<int>(),
515+
cumsum_idx_gpu->data<int>(),
516+
token_nums_per_expert_cumsum->data<int64_t>(),
517+
expert_idx_per_token->data<int64_t>(),
518+
127.0,
519+
-127.0
520+
);)
587521
}
588522
}
589523

0 commit comments

Comments
 (0)