Skip to content

Commit bdab888

Browse files
authored
support dispatch both bf16 & fp8 (#10817)
1 parent b5b94e5 commit bdab888

File tree

5 files changed

+316
-181
lines changed

5 files changed

+316
-181
lines changed

paddlenlp/transformers/deepseek_v2/configuration.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,6 @@ def __init__(
182182
use_dualpipev=False,
183183
send_mtp_embed=False,
184184
recompute_fwd_gate_up=False,
185-
dequant_input=False,
186185
is_split_group_gemm=False,
187186
**kwargs,
188187
):
@@ -235,10 +234,8 @@ def __init__(
235234
self.use_dualpipev = use_dualpipev
236235
self.send_mtp_embed = send_mtp_embed
237236
self.recompute_fwd_gate_up = recompute_fwd_gate_up
238-
self.dequant_input = dequant_input
239237
self.is_split_group_gemm = is_split_group_gemm
240238

241-
242239
super().__init__(
243240
pad_token_id=pad_token_id,
244241
bos_token_id=bos_token_id,

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565

6666

6767
DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
68+
DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true"
6869

6970

7071
def parse_args(args):
@@ -510,7 +511,7 @@ def dispatch_forward(self, inputs, previous_event=None, async_finish=False, allo
510511
token_probs,
511512
) = inputs
512513

513-
(hs_fp16_dispatched, dispatched_indices, dispatched_probs,) = self.fp8_fusion_moe_node.dispatch_node.forward(
514+
(hs_dispatched, dispatched_indices, dispatched_probs,) = self.fp8_fusion_moe_node.dispatch_node.forward(
514515
hs_2d,
515516
token_indices,
516517
token_probs,
@@ -524,7 +525,7 @@ def dispatch_forward(self, inputs, previous_event=None, async_finish=False, allo
524525
hidden_states,
525526
residual,
526527
l_aux,
527-
hs_fp16_dispatched,
528+
hs_dispatched,
528529
dispatched_indices,
529530
dispatched_probs,
530531
)
@@ -533,7 +534,7 @@ def dispatch_forward(self, inputs, previous_event=None, async_finish=False, allo
533534
hidden_states,
534535
residual,
535536
l_aux,
536-
hs_fp16_dispatched,
537+
hs_dispatched,
537538
dispatched_indices,
538539
dispatched_probs,
539540
)
@@ -545,7 +546,7 @@ def mlp_forward(self, inputs):
545546
hidden_states,
546547
residual,
547548
l_aux,
548-
hs_fp16_dispatched,
549+
hs_dispatched,
549550
dispatched_indices,
550551
dispatched_probs,
551552
) = inputs
@@ -554,12 +555,12 @@ def mlp_forward(self, inputs):
554555
hidden_states,
555556
residual,
556557
l_aux,
557-
hs_fp16_dispatched,
558+
hs_dispatched,
558559
dispatched_indices,
559560
dispatched_probs,
560561
) = inputs
561562
hidden_states_out = self.fp8_fusion_moe_node.mlp_node.forward(
562-
hs_fp16_dispatched, dispatched_indices, dispatched_probs
563+
hs_dispatched, dispatched_indices, dispatched_probs
563564
)
564565

565566
if self.send_mtp_embed:
@@ -573,18 +574,18 @@ def combine_forward(self, inputs, async_finish=False):
573574
else:
574575
(hidden_states, residual, l_aux, hidden_states_out) = inputs
575576

576-
output_combie = self.fp8_fusion_moe_node.combine_node.forward(hidden_states_out, async_finish=async_finish)
577+
output_combine = self.fp8_fusion_moe_node.combine_node.forward(hidden_states_out, async_finish=async_finish)
577578
if self.send_mtp_embed:
578-
return (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combie)
579+
return (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine)
579580
else:
580-
return (hidden_states, residual, l_aux, output_combie)
581+
return (hidden_states, residual, l_aux, output_combine)
581582

582583
def post_process_forward(self, inputs):
583584
if self.send_mtp_embed:
584-
(inputs_embeds_mtp, hidden_states, residual, l_aux, output_combie) = inputs
585+
(inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine) = inputs
585586
else:
586-
(hidden_states, residual, l_aux, output_combie) = inputs
587-
final_hidden_states = self.fp8_fusion_moe_node.combine_quant_node.forward(output_combie)
587+
(hidden_states, residual, l_aux, output_combine) = inputs
588+
final_hidden_states = self.fp8_fusion_moe_node.combine_quant_node.forward(output_combine)
588589
if self.send_mtp_embed:
589590
inputs = (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states)
590591
else:
@@ -609,21 +610,21 @@ def post_process_backward(self, output_grad):
609610
l_aux_grad,
610611
final_hidden_states_grad,
611612
) = self.post_process_node.backward(output_grad)
612-
output_combie_grad = self.fp8_fusion_moe_node.combine_quant_node.backward(final_hidden_states_grad)
613+
output_combine_grad = self.fp8_fusion_moe_node.combine_quant_node.backward(final_hidden_states_grad)
613614
if self.send_mtp_embed:
614615
return (
615616
inputs_embeds_mtp_grad,
616617
hidden_states_grad,
617618
residual_grad,
618619
l_aux_grad,
619-
output_combie_grad,
620+
output_combine_grad,
620621
)
621622
else:
622623
return (
623624
hidden_states_grad,
624625
residual_grad,
625626
l_aux_grad,
626-
output_combie_grad,
627+
output_combine_grad,
627628
)
628629

629630
def combine_backward(self, output_grad, async_finish=False):
@@ -633,18 +634,18 @@ def combine_backward(self, output_grad, async_finish=False):
633634
hidden_states_grad,
634635
residual_grad,
635636
l_aux_grad,
636-
output_combie_grad_bf16,
637+
output_combine_grad,
637638
) = output_grad
638639
else:
639640
(
640641
hidden_states_grad,
641642
residual_grad,
642643
l_aux_grad,
643-
output_combie_grad_bf16,
644+
output_combine_grad,
644645
) = output_grad
645646

646-
hidden_states_out_grad_bf16 = self.fp8_fusion_moe_node.combine_node.backward(
647-
output_combie_grad_bf16,
647+
hidden_states_out_grad = self.fp8_fusion_moe_node.combine_node.backward(
648+
output_combine_grad,
648649
async_finish=async_finish,
649650
)
650651

@@ -654,14 +655,14 @@ def combine_backward(self, output_grad, async_finish=False):
654655
hidden_states_grad,
655656
residual_grad,
656657
l_aux_grad,
657-
hidden_states_out_grad_bf16,
658+
hidden_states_out_grad,
658659
)
659660
else:
660661
return (
661662
hidden_states_grad,
662663
residual_grad,
663664
l_aux_grad,
664-
hidden_states_out_grad_bf16,
665+
hidden_states_out_grad,
665666
)
666667

667668
def mlp_backward(self, output_grad):
@@ -680,25 +681,23 @@ def mlp_backward(self, output_grad):
680681
l_aux_grad,
681682
hidden_states_out_grad,
682683
) = output_grad
683-
hs_fp16_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(
684-
hidden_states_out_grad
685-
)
684+
hs_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(hidden_states_out_grad)
686685

687686
if self.send_mtp_embed:
688687
return (
689688
inputs_embeds_mtp_grad,
690689
hidden_states_grad,
691690
residual_grad,
692691
l_aux_grad,
693-
hs_fp16_dispatched_grad,
692+
hs_dispatched_grad,
694693
dispatched_probs_grad,
695694
)
696695
else:
697696
return (
698697
hidden_states_grad,
699698
residual_grad,
700699
l_aux_grad,
701-
hs_fp16_dispatched_grad,
700+
hs_dispatched_grad,
702701
dispatched_probs_grad,
703702
)
704703

@@ -709,19 +708,19 @@ def dispatch_backward(self, output_grad, async_finish=False):
709708
hidden_states_grad,
710709
residual_grad,
711710
l_aux_grad,
712-
hs_bf16_dispatched_grad,
711+
hs_dispatched_grad,
713712
dispatched_probs_grad,
714713
) = output_grad
715714
else:
716715
(
717716
hidden_states_grad,
718717
residual_grad,
719718
l_aux_grad,
720-
hs_bf16_dispatched_grad,
719+
hs_dispatched_grad,
721720
dispatched_probs_grad,
722721
) = output_grad
723-
hs_bf16_grad, token_probs_grad = self.fp8_fusion_moe_node.dispatch_node.backward(
724-
hs_bf16_dispatched_grad, dispatched_probs_grad, async_finish=async_finish
722+
hs_grad, token_probs_grad = self.fp8_fusion_moe_node.dispatch_node.backward(
723+
hs_dispatched_grad, dispatched_probs_grad, async_finish=async_finish
725724
)
726725

727726
if self.send_mtp_embed:
@@ -730,11 +729,11 @@ def dispatch_backward(self, output_grad, async_finish=False):
730729
hidden_states_grad,
731730
residual_grad,
732731
l_aux_grad,
733-
hs_bf16_grad,
732+
hs_grad,
734733
token_probs_grad,
735734
)
736735
else:
737-
return (hidden_states_grad, residual_grad, l_aux_grad, hs_bf16_grad, token_probs_grad)
736+
return (hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad)
738737

739738
def attn_backward(self, output_grad):
740739
if self.send_mtp_embed:
@@ -743,19 +742,19 @@ def attn_backward(self, output_grad):
743742
hidden_states_grad,
744743
residual_grad,
745744
l_aux_grad,
746-
hs_bf16_grad,
745+
hs_grad,
747746
token_probs_grad,
748747
) = output_grad
749748
else:
750749
(
751750
hidden_states_grad,
752751
residual_grad,
753752
l_aux_grad,
754-
hs_bf16_grad,
753+
hs_grad,
755754
token_probs_grad,
756755
) = output_grad
757756
hidden_states_grad_, probs_grad, routing_map_grad = self.fp8_fusion_moe_node.dispatch_quant_node.backward(
758-
hs_bf16_grad, token_probs_grad
757+
hs_grad, token_probs_grad
759758
)
760759

761760
if self.send_mtp_embed:
@@ -1234,7 +1233,6 @@ def build_schedule_node(self):
12341233
fp8_fusion_moe_node = FusionMoeNode(
12351234
self.mlp,
12361235
recompute_fwd_gate_up=self.config.recompute_fwd_gate_up,
1237-
dequant_input=self.config.dequant_input,
12381236
is_split_group_gemm=self.config.is_split_group_gemm,
12391237
name="fp8_fusion_moe_node",
12401238
)

0 commit comments

Comments
 (0)