@@ -344,9 +344,11 @@ void FusedGateMoeKernel(
344344 const phi::DenseTensor& hidden_states,
345345 const phi::DenseTensor& gate_out,
346346 const paddle::optional<phi::DenseTensor>& gate_correction_bias,
347- const std::vector<phi::DenseTensor>& gate_up_weights,
348- const std::vector<phi::DenseTensor>& down_weights,
349- const paddle::optional<std::vector<phi::DenseTensor>>& scales,
347+ const phi::DenseTensor& gate_up_weights,
348+ const phi::DenseTensor& down_weights,
349+ const paddle::optional<phi::DenseTensor>& intermediate_hidden_states_scales,
350+ const paddle::optional<phi::DenseTensor>& gate_up_weights_scales,
351+ const paddle::optional<phi::DenseTensor>& down_weights_scales,
350352 phi::DenseTensor* final_hidden_states,
351353 const int top_k,
352354 const bool moe_use_gate_correction_bias,
@@ -358,14 +360,20 @@ void FusedGateMoeKernel(
358360 const bool measurement_mode,
359361 const bool dynamic_scale,
360362 const int block_size) {
363+ std::vector<int64_t > gate_up_weights_dims =
364+ phi::vectorize<int64_t >(gate_up_weights.dims ());
365+ std::vector<int64_t > down_weights_dims =
366+ phi::vectorize<int64_t >(down_weights.dims ());
361367 FusedGateMoeParams params;
362368 memset (reinterpret_cast <void *>(&params), 0x00 , sizeof (FusedGateMoeParams));
363369 params.topk = top_k;
364370 params.moe_use_gate_correction_bias = moe_use_gate_correction_bias;
365371 params.norm_topk_prob = norm_topk_prob;
366372 params.permuted_weights = permuted_weights;
367- params.fused_gemm = (gate_up_weights.size () == down_weights.size ());
368- params.num_experts = down_weights.size ();
373+ // TODO(yanfeich): add optional up_weights
374+ params.fused_gemm = (gate_up_weights_dims[2 ] == down_weights_dims[1 ] * 2 );
375+ params.measurement_mode = measurement_mode;
376+ params.num_experts = gate_up_weights_dims[0 ];
369377 params.experts_min = experts_min;
370378 params.experts_max = experts_max;
371379 params.dynamic_scale = dynamic_scale;
@@ -380,16 +388,18 @@ void FusedGateMoeKernel(
380388 if (moe_use_gate_correction_bias) {
381389 ct.Add (gate_correction_bias.get ());
382390 }
383- for (const auto & t : gate_up_weights) {
384- ct.Add (t);
391+
392+ ct.AddN (gate_up_weights);
393+ ct.AddN (down_weights);
394+
395+ if (intermediate_hidden_states_scales) {
396+ ct.AddN (intermediate_hidden_states_scales.get ());
385397 }
386- for ( const auto & t : down_weights ) {
387- ct.Add (t );
398+ if (gate_up_weights_scales ) {
399+ ct.AddN (gate_up_weights_scales. get () );
388400 }
389- if (scales) {
390- for (const auto & t : scales.get ()) {
391- ct.Add (t);
392- }
401+ if (down_weights_scales) {
402+ ct.AddN (down_weights_scales.get ());
393403 }
394404
395405 ct.Add (*final_hidden_states, false );
@@ -424,9 +434,11 @@ void CallFusedGateMoeKernel(
424434 const phi::DenseTensor& hidden_states,
425435 const phi::DenseTensor& gate_out,
426436 const paddle::optional<phi::DenseTensor>& gate_correction_bias,
427- const std::vector<phi::DenseTensor>& gate_up_weights,
428- const std::vector<phi::DenseTensor>& down_weights,
429- const paddle::optional<std::vector<phi::DenseTensor>>& scales,
437+ const phi::DenseTensor& gate_up_weights,
438+ const phi::DenseTensor& down_weights,
439+ const paddle::optional<phi::DenseTensor>& intermediate_hidden_states_scales,
440+ const paddle::optional<phi::DenseTensor>& gate_up_weights_scales,
441+ const paddle::optional<phi::DenseTensor>& down_weights_scales,
430442 phi::DenseTensor* final_hidden_states,
431443 const int top_k,
432444 const bool moe_use_gate_correction_bias,
@@ -450,7 +462,9 @@ void CallFusedGateMoeKernel(
450462 gate_correction_bias,
451463 gate_up_weights,
452464 down_weights,
453- scales,
465+ intermediate_hidden_states_scales,
466+ gate_up_weights_scales,
467+ down_weights_scales,
454468 final_hidden_states,
455469 top_k,
456470 moe_use_gate_correction_bias,
@@ -471,7 +485,9 @@ void CallFusedGateMoeKernel(
471485 gate_correction_bias,
472486 gate_up_weights,
473487 down_weights,
474- scales,
488+ intermediate_hidden_states_scales,
489+ gate_up_weights_scales,
490+ down_weights_scales,
475491 final_hidden_states,
476492 top_k,
477493 moe_use_gate_correction_bias,
@@ -493,8 +509,8 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
493509 const paddle::Tensor& hidden_states,
494510 const paddle::Tensor& gate_out,
495511 const paddle::optional<paddle::Tensor>& gate_correction_bias,
496- const std::vector< paddle::Tensor> & gate_up_weights,
497- const std::vector< paddle::Tensor> & down_weights,
512+ const paddle::Tensor& gate_up_weights,
513+ const paddle::Tensor& down_weights,
498514 const int top_k,
499515 const bool moe_use_gate_correction_bias,
500516 const bool norm_topk_prob,
@@ -519,16 +535,10 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
519535 paddle::optional<phi::DenseTensor>(*gate_correction_bias_dt);
520536 }
521537
522- std::vector<phi::DenseTensor> gate_up_weights_vec;
523- for (const auto & t : gate_up_weights) {
524- gate_up_weights_vec.push_back (
525- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
526- }
527- std::vector<phi::DenseTensor> down_weights_vec;
528- for (const auto & t : down_weights) {
529- down_weights_vec.push_back (
530- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
531- }
538+ auto gate_up_weights_tensor =
539+ static_cast <const phi::DenseTensor*>(gate_up_weights.impl ().get ());
540+ auto down_weights_tensor =
541+ static_cast <const phi::DenseTensor*>(down_weights.impl ().get ());
532542
533543 std::shared_ptr<phi::DenseTensor> final_hidden_states =
534544 std::make_shared<phi::DenseTensor>();
@@ -540,9 +550,11 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
540550 *hidden_states_tensor,
541551 *gate_out_tensor,
542552 gate_correction_tensor,
543- gate_up_weights_vec,
544- down_weights_vec,
545- paddle::optional<std::vector<phi::DenseTensor>>(), /* scales */
553+ *gate_up_weights_tensor,
554+ *down_weights_tensor,
555+ paddle::optional<phi::DenseTensor>(), /* int..hid..st.._scales_tensor */
556+ paddle::optional<phi::DenseTensor>(), /* gate_up_weights_scales_tensor */
557+ paddle::optional<phi::DenseTensor>(), /* down_weights_scales_tensor */
546558 final_hidden_states.get (),
547559 top_k,
548560 moe_use_gate_correction_bias,
@@ -563,12 +575,11 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
563575 const paddle::Tensor& hidden_states,
564576 const paddle::Tensor& gate_out,
565577 const paddle::optional<paddle::Tensor>& gate_correction_bias,
566- const std::vector<paddle::Tensor>& gate_up_weights,
567- const std::vector<paddle::Tensor>& down_weights,
568- const paddle::optional<std::vector<paddle::Tensor>>&
569- intermediate_hidden_states_scales,
570- const std::vector<paddle::Tensor>& gate_up_weights_scales,
571- const std::vector<paddle::Tensor>& down_weights_scales,
578+ const paddle::Tensor& gate_up_weights,
579+ const paddle::Tensor& down_weights,
580+ const paddle::optional<paddle::Tensor>& intermediate_hidden_states_scales,
581+ const paddle::Tensor& gate_up_weights_scales,
582+ const paddle::Tensor& down_weights_scales,
572583 const int top_k,
573584 const bool moe_use_gate_correction_bias,
574585 const bool norm_topk_prob,
@@ -593,33 +604,35 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
593604 paddle::optional<phi::DenseTensor>(*gate_correction_bias_dt);
594605 }
595606
596- std::vector<phi::DenseTensor> gate_up_weights_vec;
597- for (const auto & t : gate_up_weights) {
598- gate_up_weights_vec.push_back (
599- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
600- }
601- std::vector<phi::DenseTensor> down_weights_vec;
602- for (const auto & t : down_weights) {
603- down_weights_vec.push_back (
604- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
605- }
607+ auto gate_up_weights_tensor =
608+ static_cast <const phi::DenseTensor*>(gate_up_weights.impl ().get ());
609+ auto down_weights_tensor =
610+ static_cast <const phi::DenseTensor*>(down_weights.impl ().get ());
606611
607612 bool dynamic_scale = true ;
608- std::vector<phi::DenseTensor> scales_vec;
613+ auto intermediate_hidden_states_scales_tensor =
614+ paddle::optional<phi::DenseTensor>();
609615 if (intermediate_hidden_states_scales) {
610616 dynamic_scale = false ;
611- for (const auto & t : intermediate_hidden_states_scales.get ()) {
612- scales_vec.push_back (
613- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
614- }
615- }
616- for (const auto & t : gate_up_weights_scales) {
617- scales_vec.push_back (*static_cast <const phi::DenseTensor*>(t.impl ().get ()));
618- }
619- for (const auto & t : down_weights_scales) {
620- scales_vec.push_back (*static_cast <const phi::DenseTensor*>(t.impl ().get ()));
617+ auto intermediate_hidden_states_scales_dt = static_cast <phi::DenseTensor*>(
618+ intermediate_hidden_states_scales->impl ().get ());
619+ intermediate_hidden_states_scales_tensor =
620+ paddle::optional<phi::DenseTensor>(
621+ *intermediate_hidden_states_scales_dt);
621622 }
622623
624+ auto gate_up_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
625+ auto gate_up_weights_scales_dt =
626+ static_cast <const phi::DenseTensor*>(gate_up_weights_scales.impl ().get ());
627+ gate_up_weights_scales_tensor =
628+ paddle::optional<phi::DenseTensor>(*gate_up_weights_scales_dt);
629+
630+ auto down_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
631+ auto down_weights_scales_dt =
632+ static_cast <const phi::DenseTensor*>(down_weights_scales.impl ().get ());
633+ down_weights_scales_tensor =
634+ paddle::optional<phi::DenseTensor>(*down_weights_scales_dt);
635+
623636 std::shared_ptr<phi::DenseTensor> final_hidden_states =
624637 std::make_shared<phi::DenseTensor>();
625638 final_hidden_states->Resize (hidden_states.dims ());
@@ -630,9 +643,11 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
630643 *hidden_states_tensor,
631644 *gate_out_tensor,
632645 gate_correction_tensor,
633- gate_up_weights_vec,
634- down_weights_vec,
635- scales_vec,
646+ *gate_up_weights_tensor,
647+ *down_weights_tensor,
648+ intermediate_hidden_states_scales_tensor,
649+ gate_up_weights_scales_tensor,
650+ down_weights_scales_tensor,
636651 final_hidden_states.get (),
637652 top_k,
638653 moe_use_gate_correction_bias,
@@ -652,10 +667,10 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
652667 const paddle::Tensor& hidden_states,
653668 const paddle::Tensor& gate_out,
654669 const paddle::optional<paddle::Tensor>& gate_correction_bias,
655- const std::vector< paddle::Tensor> & gate_up_weights,
656- const std::vector< paddle::Tensor> & down_weights,
657- const std::vector< paddle::Tensor> & gate_up_weights_scales,
658- const std::vector< paddle::Tensor> & down_weights_scales,
670+ const paddle::Tensor& gate_up_weights,
671+ const paddle::Tensor& down_weights,
672+ const paddle::Tensor& gate_up_weights_scales,
673+ const paddle::Tensor& down_weights_scales,
659674 const int top_k,
660675 const bool moe_use_gate_correction_bias,
661676 const bool norm_topk_prob,
@@ -681,24 +696,22 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
681696 paddle::optional<phi::DenseTensor>(*gate_correction_bias_dt);
682697 }
683698
684- std::vector<phi::DenseTensor> gate_up_weights_vec;
685- for (const auto & t : gate_up_weights) {
686- gate_up_weights_vec.push_back (
687- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
688- }
689- std::vector<phi::DenseTensor> down_weights_vec;
690- for (const auto & t : down_weights) {
691- down_weights_vec.push_back (
692- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
693- }
699+ auto gate_up_weights_tensor =
700+ static_cast <const phi::DenseTensor*>(gate_up_weights.impl ().get ());
701+ auto down_weights_tensor =
702+ static_cast <const phi::DenseTensor*>(down_weights.impl ().get ());
694703
695- std::vector<phi::DenseTensor> scales_vec;
696- for (const auto & t : gate_up_weights_scales) {
697- scales_vec.push_back (*static_cast <const phi::DenseTensor*>(t.impl ().get ()));
698- }
699- for (const auto & t : down_weights_scales) {
700- scales_vec.push_back (*static_cast <const phi::DenseTensor*>(t.impl ().get ()));
701- }
704+ auto gate_up_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
705+ auto gate_up_weights_scales_dt =
706+ static_cast <const phi::DenseTensor*>(gate_up_weights_scales.impl ().get ());
707+ gate_up_weights_scales_tensor =
708+ paddle::optional<phi::DenseTensor>(*gate_up_weights_scales_dt);
709+
710+ auto down_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
711+ auto down_weights_scales_dt =
712+ static_cast <const phi::DenseTensor*>(down_weights_scales.impl ().get ());
713+ down_weights_scales_tensor =
714+ paddle::optional<phi::DenseTensor>(*down_weights_scales_dt);
702715
703716 std::shared_ptr<phi::DenseTensor> final_hidden_states =
704717 std::make_shared<phi::DenseTensor>();
@@ -710,9 +723,11 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
710723 *hidden_states_tensor,
711724 *gate_out_tensor,
712725 gate_correction_tensor,
713- gate_up_weights_vec,
714- down_weights_vec,
715- scales_vec,
726+ *gate_up_weights_tensor,
727+ *down_weights_tensor,
728+ paddle::optional<phi::DenseTensor>(), /* scales */
729+ gate_up_weights_scales_tensor,
730+ down_weights_scales_tensor,
716731 final_hidden_states.get (),
717732 top_k,
718733 moe_use_gate_correction_bias,
@@ -755,8 +770,8 @@ PD_BUILD_OP(fused_gate_moe)
755770 .Inputs({" hidden_states" ,
756771 " gate_out" ,
757772 paddle::Optional (" gate_correction_bias" ),
758- paddle::Vec ( " gate_up_weights" ) ,
759- paddle::Vec ( " down_weights" ) })
773+ " gate_up_weights" ,
774+ " down_weights" })
760775 .Outputs({" final_hidden_states" })
761776 .Attrs({" top_k: int" ,
762777 " moe_use_gate_correction_bias: bool" ,
@@ -780,11 +795,11 @@ PD_BUILD_OP(fused_gate_moe_fp8)
780795 .Inputs({" hidden_states" ,
781796 " gate_out" ,
782797 paddle::Optional (" gate_correction_bias" ),
783- paddle::Vec ( " gate_up_weights" ) ,
784- paddle::Vec ( " down_weights" ) ,
798+ " gate_up_weights" ,
799+ " down_weights" ,
785- paddle::Optional (paddle::Vec (" intermediate_hidden_states_scales" )),
800+ paddle::Optional (" intermediate_hidden_states_scales" ),
786- paddle::Vec ( " gate_up_weights_scales" ) ,
787- paddle::Vec ( " down_weights_scales" ) })
801+ " gate_up_weights_scales" ,
802+ " down_weights_scales" })
788803 .Outputs({" final_hidden_states" })
789804 .Attrs({" top_k: int" ,
790805 " moe_use_gate_correction_bias: bool" ,
@@ -807,10 +822,10 @@ PD_BUILD_OP(fused_gate_moe_blockwise_fp8)
807822 .Inputs({" hidden_states" ,
808823 " gate_out" ,
809824 paddle::Optional (" gate_correction_bias" ),
810- paddle::Vec ( " gate_up_weights" ) ,
811- paddle::Vec ( " down_weights" ) ,
812- paddle::Vec ( " gate_up_weights_scales" ) ,
813- paddle::Vec ( " down_weights_scales" ) })
825+ " gate_up_weights" ,
826+ " down_weights" ,
827+ " gate_up_weights_scales" ,
828+ " down_weights_scales" })
814829 .Outputs({" final_hidden_states" })
815830 .Attrs({" top_k: int" ,
816831 " moe_use_gate_correction_bias: bool" ,
0 commit comments