@@ -904,24 +904,21 @@ def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
904
904
self .dispatched_indices = dispatched_indices .to (paddle .int32 )
905
905
906
906
total_zipped_tokens = extract_first_if_tuple (hs_2d_dispatched ).shape [0 ]
907
- if DSV3_USE_FP8_DISPATCH :
908
- (
909
- unzipped_tokens ,
910
- zipped_expertwise_rowmap ,
911
- unzipped_probs ,
912
- unzipped_tokens_scale ,
913
- ) = self .unzip_node .forward (
914
- hs_2d_dispatched ,
915
- self .dispatched_indices ,
916
- dispatched_probs ,
917
- topk = self .router_topk ,
918
- num_experts = num_experts ,
919
- tokens_per_expert = self .tokens_per_expert ,
920
- )
921
- record_stream_for_multi_input (hs_2d_dispatched )
922
- dispatched_indices ._record_stream ()
923
- dispatched_probs ._record_stream ()
907
+ (unzipped_tokens , zipped_expertwise_rowmap , unzipped_probs , unzipped_tokens_scale ,) = self .unzip_node .forward (
908
+ hs_2d_dispatched ,
909
+ self .dispatched_indices ,
910
+ dispatched_probs ,
911
+ topk = self .router_topk ,
912
+ num_experts = num_experts ,
913
+ tokens_per_expert = self .tokens_per_expert ,
914
+ )
915
+ record_stream_for_multi_input (hs_2d_dispatched )
916
+ dispatched_indices ._record_stream ()
917
+ dispatched_probs ._record_stream ()
918
+
919
+ self .unzipped_probs = unzipped_probs .unsqueeze (- 1 )
924
920
921
+ if DSV3_USE_FP8_DISPATCH :
925
922
total_unzipped_tokens = extract_first_if_tuple (unzipped_tokens ).shape [0 ]
926
923
# If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance
927
924
if self .recompute_fwd_gate_up == - 1 :
@@ -935,8 +932,6 @@ def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
935
932
# logger.debug(f"recompute_fwd_gate_up changed to False, Because the receives {unzipped_tokens.shape[0]} Tensors less than {self.seq_length*self.num_experts_per_tok*self.adaptive_remained_O1_recompute_ratio}.")
936
933
self .set_recompute_fwd_gate_up (False )
937
934
938
- self .unzipped_probs = unzipped_probs .unsqueeze (- 1 )
939
-
940
935
# if use_mlp_subbatch is enabled, then split the unzipped_tokens into subbatches
941
936
if self .mlp_fwd_subbatch_rows != 0 and total_unzipped_tokens > self .mlp_fwd_subbatch_rows * 2 :
942
937
assert (
@@ -990,18 +985,6 @@ def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
990
985
(unzipped_tokens , unzipped_tokens_scale ), unzipped_probs , padding_token_per_experts
991
986
)
992
987
else :
993
- (unzipped_tokens , zipped_expertwise_rowmap , unzipped_probs , _ ,) = self .unzip_node .forward (
994
- hs_2d_dispatched ,
995
- self .dispatched_indices ,
996
- dispatched_probs ,
997
- topk = self .router_topk ,
998
- num_experts = num_experts ,
999
- tokens_per_expert = self .tokens_per_expert ,
1000
- )
1001
- hs_2d_dispatched ._record_stream ()
1002
- dispatched_indices ._record_stream ()
1003
- dispatched_probs ._record_stream ()
1004
-
1005
988
# If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance
1006
989
if self .recompute_fwd_gate_up == - 1 :
1007
990
if (
0 commit comments