65
65
66
66
67
67
DSV3_USE_FP8_GEMM = os .getenv ("DSV3_USE_FP8_GEMM" , "False" ).lower () == "true"
68
+ DSV3_USE_FP8_DISPATCH = os .getenv ("DSV3_USE_FP8_DISPATCH" , "False" ).lower () == "true"
68
69
69
70
70
71
def parse_args (args ):
@@ -510,7 +511,7 @@ def dispatch_forward(self, inputs, previous_event=None, async_finish=False, allo
510
511
token_probs ,
511
512
) = inputs
512
513
513
- (hs_fp16_dispatched , dispatched_indices , dispatched_probs ,) = self .fp8_fusion_moe_node .dispatch_node .forward (
514
+ (hs_dispatched , dispatched_indices , dispatched_probs ,) = self .fp8_fusion_moe_node .dispatch_node .forward (
514
515
hs_2d ,
515
516
token_indices ,
516
517
token_probs ,
@@ -524,7 +525,7 @@ def dispatch_forward(self, inputs, previous_event=None, async_finish=False, allo
524
525
hidden_states ,
525
526
residual ,
526
527
l_aux ,
527
- hs_fp16_dispatched ,
528
+ hs_dispatched ,
528
529
dispatched_indices ,
529
530
dispatched_probs ,
530
531
)
@@ -533,7 +534,7 @@ def dispatch_forward(self, inputs, previous_event=None, async_finish=False, allo
533
534
hidden_states ,
534
535
residual ,
535
536
l_aux ,
536
- hs_fp16_dispatched ,
537
+ hs_dispatched ,
537
538
dispatched_indices ,
538
539
dispatched_probs ,
539
540
)
@@ -545,7 +546,7 @@ def mlp_forward(self, inputs):
545
546
hidden_states ,
546
547
residual ,
547
548
l_aux ,
548
- hs_fp16_dispatched ,
549
+ hs_dispatched ,
549
550
dispatched_indices ,
550
551
dispatched_probs ,
551
552
) = inputs
@@ -554,12 +555,12 @@ def mlp_forward(self, inputs):
554
555
hidden_states ,
555
556
residual ,
556
557
l_aux ,
557
- hs_fp16_dispatched ,
558
+ hs_dispatched ,
558
559
dispatched_indices ,
559
560
dispatched_probs ,
560
561
) = inputs
561
562
hidden_states_out = self .fp8_fusion_moe_node .mlp_node .forward (
562
- hs_fp16_dispatched , dispatched_indices , dispatched_probs
563
+ hs_dispatched , dispatched_indices , dispatched_probs
563
564
)
564
565
565
566
if self .send_mtp_embed :
@@ -573,18 +574,18 @@ def combine_forward(self, inputs, async_finish=False):
573
574
else :
574
575
(hidden_states , residual , l_aux , hidden_states_out ) = inputs
575
576
576
- output_combie = self .fp8_fusion_moe_node .combine_node .forward (hidden_states_out , async_finish = async_finish )
577
+ output_combine = self .fp8_fusion_moe_node .combine_node .forward (hidden_states_out , async_finish = async_finish )
577
578
if self .send_mtp_embed :
578
- return (inputs_embeds_mtp , hidden_states , residual , l_aux , output_combie )
579
+ return (inputs_embeds_mtp , hidden_states , residual , l_aux , output_combine )
579
580
else :
580
- return (hidden_states , residual , l_aux , output_combie )
581
+ return (hidden_states , residual , l_aux , output_combine )
581
582
582
583
def post_process_forward (self , inputs ):
583
584
if self .send_mtp_embed :
584
- (inputs_embeds_mtp , hidden_states , residual , l_aux , output_combie ) = inputs
585
+ (inputs_embeds_mtp , hidden_states , residual , l_aux , output_combine ) = inputs
585
586
else :
586
- (hidden_states , residual , l_aux , output_combie ) = inputs
587
- final_hidden_states = self .fp8_fusion_moe_node .combine_quant_node .forward (output_combie )
587
+ (hidden_states , residual , l_aux , output_combine ) = inputs
588
+ final_hidden_states = self .fp8_fusion_moe_node .combine_quant_node .forward (output_combine )
588
589
if self .send_mtp_embed :
589
590
inputs = (inputs_embeds_mtp , hidden_states , residual , l_aux , final_hidden_states )
590
591
else :
@@ -609,21 +610,21 @@ def post_process_backward(self, output_grad):
609
610
l_aux_grad ,
610
611
final_hidden_states_grad ,
611
612
) = self .post_process_node .backward (output_grad )
612
- output_combie_grad = self .fp8_fusion_moe_node .combine_quant_node .backward (final_hidden_states_grad )
613
+ output_combine_grad = self .fp8_fusion_moe_node .combine_quant_node .backward (final_hidden_states_grad )
613
614
if self .send_mtp_embed :
614
615
return (
615
616
inputs_embeds_mtp_grad ,
616
617
hidden_states_grad ,
617
618
residual_grad ,
618
619
l_aux_grad ,
619
- output_combie_grad ,
620
+ output_combine_grad ,
620
621
)
621
622
else :
622
623
return (
623
624
hidden_states_grad ,
624
625
residual_grad ,
625
626
l_aux_grad ,
626
- output_combie_grad ,
627
+ output_combine_grad ,
627
628
)
628
629
629
630
def combine_backward (self , output_grad , async_finish = False ):
@@ -633,18 +634,18 @@ def combine_backward(self, output_grad, async_finish=False):
633
634
hidden_states_grad ,
634
635
residual_grad ,
635
636
l_aux_grad ,
636
- output_combie_grad_bf16 ,
637
+ output_combine_grad ,
637
638
) = output_grad
638
639
else :
639
640
(
640
641
hidden_states_grad ,
641
642
residual_grad ,
642
643
l_aux_grad ,
643
- output_combie_grad_bf16 ,
644
+ output_combine_grad ,
644
645
) = output_grad
645
646
646
- hidden_states_out_grad_bf16 = self .fp8_fusion_moe_node .combine_node .backward (
647
- output_combie_grad_bf16 ,
647
+ hidden_states_out_grad = self .fp8_fusion_moe_node .combine_node .backward (
648
+ output_combine_grad ,
648
649
async_finish = async_finish ,
649
650
)
650
651
@@ -654,14 +655,14 @@ def combine_backward(self, output_grad, async_finish=False):
654
655
hidden_states_grad ,
655
656
residual_grad ,
656
657
l_aux_grad ,
657
- hidden_states_out_grad_bf16 ,
658
+ hidden_states_out_grad ,
658
659
)
659
660
else :
660
661
return (
661
662
hidden_states_grad ,
662
663
residual_grad ,
663
664
l_aux_grad ,
664
- hidden_states_out_grad_bf16 ,
665
+ hidden_states_out_grad ,
665
666
)
666
667
667
668
def mlp_backward (self , output_grad ):
@@ -680,25 +681,23 @@ def mlp_backward(self, output_grad):
680
681
l_aux_grad ,
681
682
hidden_states_out_grad ,
682
683
) = output_grad
683
- hs_fp16_dispatched_grad , dispatched_probs_grad = self .fp8_fusion_moe_node .mlp_node .backward (
684
- hidden_states_out_grad
685
- )
684
+ hs_dispatched_grad , dispatched_probs_grad = self .fp8_fusion_moe_node .mlp_node .backward (hidden_states_out_grad )
686
685
687
686
if self .send_mtp_embed :
688
687
return (
689
688
inputs_embeds_mtp_grad ,
690
689
hidden_states_grad ,
691
690
residual_grad ,
692
691
l_aux_grad ,
693
- hs_fp16_dispatched_grad ,
692
+ hs_dispatched_grad ,
694
693
dispatched_probs_grad ,
695
694
)
696
695
else :
697
696
return (
698
697
hidden_states_grad ,
699
698
residual_grad ,
700
699
l_aux_grad ,
701
- hs_fp16_dispatched_grad ,
700
+ hs_dispatched_grad ,
702
701
dispatched_probs_grad ,
703
702
)
704
703
@@ -709,19 +708,19 @@ def dispatch_backward(self, output_grad, async_finish=False):
709
708
hidden_states_grad ,
710
709
residual_grad ,
711
710
l_aux_grad ,
712
- hs_bf16_dispatched_grad ,
711
+ hs_dispatched_grad ,
713
712
dispatched_probs_grad ,
714
713
) = output_grad
715
714
else :
716
715
(
717
716
hidden_states_grad ,
718
717
residual_grad ,
719
718
l_aux_grad ,
720
- hs_bf16_dispatched_grad ,
719
+ hs_dispatched_grad ,
721
720
dispatched_probs_grad ,
722
721
) = output_grad
723
- hs_bf16_grad , token_probs_grad = self .fp8_fusion_moe_node .dispatch_node .backward (
724
- hs_bf16_dispatched_grad , dispatched_probs_grad , async_finish = async_finish
722
+ hs_grad , token_probs_grad = self .fp8_fusion_moe_node .dispatch_node .backward (
723
+ hs_dispatched_grad , dispatched_probs_grad , async_finish = async_finish
725
724
)
726
725
727
726
if self .send_mtp_embed :
@@ -730,11 +729,11 @@ def dispatch_backward(self, output_grad, async_finish=False):
730
729
hidden_states_grad ,
731
730
residual_grad ,
732
731
l_aux_grad ,
733
- hs_bf16_grad ,
732
+ hs_grad ,
734
733
token_probs_grad ,
735
734
)
736
735
else :
737
- return (hidden_states_grad , residual_grad , l_aux_grad , hs_bf16_grad , token_probs_grad )
736
+ return (hidden_states_grad , residual_grad , l_aux_grad , hs_grad , token_probs_grad )
738
737
739
738
def attn_backward (self , output_grad ):
740
739
if self .send_mtp_embed :
@@ -743,19 +742,19 @@ def attn_backward(self, output_grad):
743
742
hidden_states_grad ,
744
743
residual_grad ,
745
744
l_aux_grad ,
746
- hs_bf16_grad ,
745
+ hs_grad ,
747
746
token_probs_grad ,
748
747
) = output_grad
749
748
else :
750
749
(
751
750
hidden_states_grad ,
752
751
residual_grad ,
753
752
l_aux_grad ,
754
- hs_bf16_grad ,
753
+ hs_grad ,
755
754
token_probs_grad ,
756
755
) = output_grad
757
756
hidden_states_grad_ , probs_grad , routing_map_grad = self .fp8_fusion_moe_node .dispatch_quant_node .backward (
758
- hs_bf16_grad , token_probs_grad
757
+ hs_grad , token_probs_grad
759
758
)
760
759
761
760
if self .send_mtp_embed :
@@ -1234,7 +1233,6 @@ def build_schedule_node(self):
1234
1233
fp8_fusion_moe_node = FusionMoeNode (
1235
1234
self .mlp ,
1236
1235
recompute_fwd_gate_up = self .config .recompute_fwd_gate_up ,
1237
- dequant_input = self .config .dequant_input ,
1238
1236
is_split_group_gemm = self .config .is_split_group_gemm ,
1239
1237
name = "fp8_fusion_moe_node" ,
1240
1238
)
0 commit comments