Skip to content

Commit e63b1e7

Browse files
authored
fix (#10977)
1 parent 740f471 commit e63b1e7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

paddlenlp/transformers/moe_layer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,8 @@ def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
900 900   self.padding_token_per_experts = padding_token_per_experts
901 901   # 1 unzip
902 902   self.dispatched_indices = dispatched_indices.to(paddle.int32)
    903 +
    904 + total_zipped_tokens = extract_first_if_tuple(hs_2d_dispatched).shape[0]
903 905   if DSV3_USE_FP8_DISPATCH:
904 906   (
905 907   unzipped_tokens,
@@ -919,8 +921,6 @@ def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
919 921   dispatched_probs._record_stream()
920 922
921 923   total_unzipped_tokens = extract_first_if_tuple(unzipped_tokens).shape[0]
922     - total_zipped_tokens = extract_first_if_tuple(hs_2d_dispatched).shape[0]
923     -
924 924   # If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance
925 925   if self.recompute_fwd_gate_up == -1:
926 926   if (

0 commit comments

Comments
 (0)