@@ -453,137 +453,71 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
453
453
auto place = input.place ();
454
454
const int gridx = min (132 * 8 , num_rows);
455
455
if (moe_quant_type == " w4a8" ) {
456
- if (num_experts_per_rank == 8 ) {
457
- permute_x_kernel<data_t , int8_t , 8 ><<<gridx, 512 , 0 , stream>>> (
458
- input.data <data_t >(),
459
- topk_ids.data <int64_t >(),
460
- topk_weights.data <float >(),
461
- token_nums_per_expert.data <int >(),
462
- up_gate_proj_in_scale ? up_gate_proj_in_scale.get ().data <float >() : nullptr ,
463
- moe_topk,
464
- num_rows,
465
- token_nums_this_rank,
466
- hidden_size,
467
- permute_input->data <int8_t >(),
468
- permute_indices_per_token->data <int >(),
469
- dst_weights->data <float >(),
470
- dst_indices->data <int >(),
471
- cumsum_idx_gpu->data <int >(),
472
- token_nums_per_expert_cumsum->data <int64_t >(),
473
- expert_idx_per_token->data <int64_t >(),
474
- 127.0 ,
475
- -127.0
476
- );
477
- } else if (num_experts_per_rank == 16 ) {
478
- permute_x_kernel<data_t , int8_t , 16 ><<<gridx, 512 , 0 , stream>>> (
479
- input.data <data_t >(),
480
- topk_ids.data <int64_t >(),
481
- topk_weights.data <float >(),
482
- token_nums_per_expert.data <int >(),
483
- up_gate_proj_in_scale ? up_gate_proj_in_scale.get ().data <float >() : nullptr ,
484
- moe_topk,
485
- num_rows,
486
- token_nums_this_rank,
487
- hidden_size,
488
- permute_input->data <int8_t >(),
489
- permute_indices_per_token->data <int >(),
490
- dst_weights->data <float >(),
491
- dst_indices->data <int >(),
492
- cumsum_idx_gpu->data <int >(),
493
- token_nums_per_expert_cumsum->data <int64_t >(),
494
- expert_idx_per_token->data <int64_t >(),
495
- 127.0 ,
496
- -127.0
497
- );
498
- }
456
+ DISPATCH_NUM_EXPERTS_PER_RANK (num_experts_per_rank, NUM_EXPERTS_PER_RANK,
457
+ permute_x_kernel<data_t , int8_t , NUM_EXPERTS_PER_RANK><<<gridx, 512 , 0 , stream>>> (
458
+ input.data <data_t >(),
459
+ topk_ids.data <int64_t >(),
460
+ topk_weights.data <float >(),
461
+ token_nums_per_expert.data <int >(),
462
+ up_gate_proj_in_scale ? up_gate_proj_in_scale.get ().data <float >() : nullptr ,
463
+ moe_topk,
464
+ num_rows,
465
+ token_nums_this_rank,
466
+ hidden_size,
467
+ permute_input->data <int8_t >(),
468
+ permute_indices_per_token->data <int >(),
469
+ dst_weights->data <float >(),
470
+ dst_indices->data <int >(),
471
+ cumsum_idx_gpu->data <int >(),
472
+ token_nums_per_expert_cumsum->data <int64_t >(),
473
+ expert_idx_per_token->data <int64_t >(),
474
+ 127.0 ,
475
+ -127.0
476
+ );)
499
477
} else if (moe_quant_type == " w4afp8" ) {
500
- if (num_experts_per_rank == 8 ) {
501
- permute_x_kernel<data_t , data_t_fp8, 8 , 512 ><<<gridx, 512 , 0 , stream>>> (
502
- input.data <data_t >(),
503
- topk_ids.data <int64_t >(),
504
- topk_weights.data <float >(),
505
- token_nums_per_expert.data <int >(),
506
- up_gate_proj_in_scale ? up_gate_proj_in_scale.get ().data <float >() : nullptr ,
507
- moe_topk,
508
- num_rows,
509
- token_nums_this_rank,
510
- hidden_size,
511
- permute_input->data <data_t_fp8>(),
512
- permute_indices_per_token->data <int >(),
513
- dst_weights->data <float >(),
514
- dst_indices->data <int >(),
515
- cumsum_idx_gpu->data <int >(),
516
- token_nums_per_expert_cumsum->data <int64_t >(),
517
- expert_idx_per_token->data <int64_t >(),
518
- 448 .0f ,
519
- -448 .0f
520
- );
521
- } else if (num_experts_per_rank == 16 ) {
522
- permute_x_kernel<data_t , data_t_fp8, 16 , 512 ><<<gridx, 512 , 0 , stream>>> (
523
- input.data <data_t >(),
524
- topk_ids.data <int64_t >(),
525
- topk_weights.data <float >(),
526
- token_nums_per_expert.data <int >(),
527
- up_gate_proj_in_scale ? up_gate_proj_in_scale.get ().data <float >() : nullptr ,
528
- moe_topk,
529
- num_rows,
530
- token_nums_this_rank,
531
- hidden_size,
532
- permute_input->data <data_t_fp8>(),
533
- permute_indices_per_token->data <int >(),
534
- dst_weights->data <float >(),
535
- dst_indices->data <int >(),
536
- cumsum_idx_gpu->data <int >(),
537
- token_nums_per_expert_cumsum->data <int64_t >(),
538
- expert_idx_per_token->data <int64_t >(),
539
- 448 .0f ,
540
- -448 .0f
541
- );
542
- }
478
+ DISPATCH_NUM_EXPERTS_PER_RANK (num_experts_per_rank, NUM_EXPERTS_PER_RANK,
479
+ permute_x_kernel<data_t , data_t_fp8, NUM_EXPERTS_PER_RANK, 512 ><<<gridx, 512 , 0 , stream>>> (
480
+ input.data <data_t >(),
481
+ topk_ids.data <int64_t >(),
482
+ topk_weights.data <float >(),
483
+ token_nums_per_expert.data <int >(),
484
+ up_gate_proj_in_scale ? up_gate_proj_in_scale.get ().data <float >() : nullptr ,
485
+ moe_topk,
486
+ num_rows,
487
+ token_nums_this_rank,
488
+ hidden_size,
489
+ permute_input->data <data_t_fp8>(),
490
+ permute_indices_per_token->data <int >(),
491
+ dst_weights->data <float >(),
492
+ dst_indices->data <int >(),
493
+ cumsum_idx_gpu->data <int >(),
494
+ token_nums_per_expert_cumsum->data <int64_t >(),
495
+ expert_idx_per_token->data <int64_t >(),
496
+ 448 .0f ,
497
+ -448 .0f
498
+ );)
543
499
} else {
544
- if (num_experts_per_rank == 8 ) {
545
- permute_x_kernel<data_t , data_t , 8 ><<<gridx, 512 , 0 , stream>>> (
546
- input.data <data_t >(),
547
- topk_ids.data <int64_t >(),
548
- topk_weights.data <float >(),
549
- token_nums_per_expert.data <int >(),
550
- up_gate_proj_in_scale ? up_gate_proj_in_scale.get ().data <float >() : nullptr ,
551
- moe_topk,
552
- num_rows,
553
- token_nums_this_rank,
554
- hidden_size,
555
- permute_input->data <data_t >(),
556
- permute_indices_per_token->data <int >(),
557
- dst_weights->data <float >(),
558
- dst_indices->data <int >(),
559
- cumsum_idx_gpu->data <int >(),
560
- token_nums_per_expert_cumsum->data <int64_t >(),
561
- expert_idx_per_token->data <int64_t >(),
562
- 127.0 ,
563
- -127.0
564
- );
565
- } else if (num_experts_per_rank == 16 ) {
566
- permute_x_kernel<data_t , data_t , 16 ><<<gridx, 512 , 0 , stream>>> (
567
- input.data <data_t >(),
568
- topk_ids.data <int64_t >(),
569
- topk_weights.data <float >(),
570
- token_nums_per_expert.data <int >(),
571
- up_gate_proj_in_scale ? up_gate_proj_in_scale.get ().data <float >() : nullptr ,
572
- moe_topk,
573
- num_rows,
574
- token_nums_this_rank,
575
- hidden_size,
576
- permute_input->data <data_t >(),
577
- permute_indices_per_token->data <int >(),
578
- dst_weights->data <float >(),
579
- dst_indices->data <int >(),
580
- cumsum_idx_gpu->data <int >(),
581
- token_nums_per_expert_cumsum->data <int64_t >(),
582
- expert_idx_per_token->data <int64_t >(),
583
- 127.0 ,
584
- -127.0
585
- );
586
- }
500
+ DISPATCH_NUM_EXPERTS_PER_RANK (num_experts_per_rank, NUM_EXPERTS_PER_RANK,
501
+ permute_x_kernel<data_t , data_t , NUM_EXPERTS_PER_RANK><<<gridx, 512 , 0 , stream>>> (
502
+ input.data <data_t >(),
503
+ topk_ids.data <int64_t >(),
504
+ topk_weights.data <float >(),
505
+ token_nums_per_expert.data <int >(),
506
+ up_gate_proj_in_scale ? up_gate_proj_in_scale.get ().data <float >() : nullptr ,
507
+ moe_topk,
508
+ num_rows,
509
+ token_nums_this_rank,
510
+ hidden_size,
511
+ permute_input->data <data_t >(),
512
+ permute_indices_per_token->data <int >(),
513
+ dst_weights->data <float >(),
514
+ dst_indices->data <int >(),
515
+ cumsum_idx_gpu->data <int >(),
516
+ token_nums_per_expert_cumsum->data <int64_t >(),
517
+ expert_idx_per_token->data <int64_t >(),
518
+ 127.0 ,
519
+ -127.0
520
+ );)
587
521
}
588
522
}
589
523
0 commit comments