
Commit 6e6b373

CE bug fix (#10999)
* CE bug fix
* Update trainer_callback.py
* CE bug fix
* Update modeling.py

Co-authored-by: zhangbo9674 <[email protected]>
1 parent 7ce548e commit 6e6b373

File tree: 4 files changed, +27 -39 lines changed

paddlenlp/trainer/trainer_callback.py

Lines changed: 7 additions & 3 deletions

@@ -644,7 +644,7 @@ def on_step_begin(self, args, state, control, **kwargs):
         optimizer = kwargs["optimizer"]
         global skip_count

-        if not g_shard_bypass_dygraph_optimizer or skip_count == 0:
+        if (not g_shard_bypass_dygraph_optimizer or skip_count == 0) and hasattr(model, "fp8_quant_weight"):
             model.fp8_quant_weight(True, quant_transpose=False)
             optimizer.clear_param_storage("moe_expert")
             optimizer.clear_param_storage("rms_linear")

@@ -664,6 +664,10 @@ def on_step_begin(self, args, state, control, **kwargs):
         skip_count += 1

     def on_optimizer_begin(self, args, state, control, **kwargs):
+        model = kwargs["model"]
         optimizer = kwargs["optimizer"]
-        for name in self.moe_weights_name:
-            reload(optimizer._master_weights[name])
+        global skip_count
+
+        if (not g_shard_bypass_dygraph_optimizer) and hasattr(model, "fp8_quant_weight"):
+            for name in self.moe_weights_name:
+                reload(optimizer._master_weights[name])
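Both callbacks now apply the same defensive pattern: the FP8 weight-quantization path runs only when the model actually exposes it. A minimal standalone sketch of the guard, assuming nothing beyond the hasattr() check itself (the model classes are hypothetical stand-ins, not PaddleNLP code):

class FP8Model:
    def fp8_quant_weight(self, enable, quant_transpose=False):
        print(f"quantizing: enable={enable}, quant_transpose={quant_transpose}")

class PlainModel:
    pass  # exposes no fp8_quant_weight hook

def maybe_quant(model, bypass_optimizer, skip_count):
    # hasattr() keeps models without the FP8 hook from raising AttributeError
    if (not bypass_optimizer or skip_count == 0) and hasattr(model, "fp8_quant_weight"):
        model.fp8_quant_weight(True, quant_transpose=False)

maybe_quant(FP8Model(), bypass_optimizer=False, skip_count=0)    # runs quantization
maybe_quant(PlainModel(), bypass_optimizer=False, skip_count=0)  # silently skipped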

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 1 addition & 4 deletions

@@ -1538,10 +1538,7 @@ def backward(ctx, dout):
         else:
             assert False, f"invalid {FA_VERSION=}"

-        if FA_VERSION == 2:
-            assert not recompute_fa3
-            assert attn_out is not None and softmax_lse is not None
-        if FA_VERSION == 3 and not recompute_fa3:
+        if (FA_VERSION == 3 and not recompute_fa3) or FA_VERSION == 2:
             assert attn_out is not None and softmax_lse is not None

         q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps)
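The two separate assertions collapse into one condition; note the change also drops the standalone `assert not recompute_fa3` that the FA2 branch carried. A quick exhaustive check (standalone code, not part of PaddleNLP) confirms the merged guard fires for the `attn_out`/`softmax_lse` assertion in exactly the same cases as before:

from itertools import product

# Compare the old and new guard conditions over all combinations of
# FA_VERSION in {2, 3} and recompute_fa3 in {False, True}.
for fa_version, recompute_fa3 in product((2, 3), (False, True)):
    old = (fa_version == 2) or (fa_version == 3 and not recompute_fa3)
    new = (fa_version == 3 and not recompute_fa3) or fa_version == 2
    assert old == new, (fa_version, recompute_fa3)
print("merged guard covers the same cases")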

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 5 additions & 1 deletion

@@ -1118,7 +1118,11 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         combine_forward_event.current_stream_wait()
         final_out_event.current_stream_wait()

-        inputs = final_out + combine_fwd_out
+        if final_out.shape[-1] != combine_fwd_out.shape[-1]:
+            final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out  # broadcast and add directly
+        else:
+            final_out += combine_fwd_out
+        inputs = final_out

         final_out._record_stream()
         combine_fwd_out._record_stream()
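The fix handles the case where combine_fwd_out is narrower than final_out along the last axis: the add is applied in place to the leading slice and the trailing columns are left untouched. A NumPy sketch of the same shape handling, with the shapes invented purely for illustration:

import numpy as np

final_out = np.zeros((2, 4, 8), dtype=np.float32)       # wider hidden dim
combine_fwd_out = np.ones((2, 4, 6), dtype=np.float32)  # narrower combine output

if final_out.shape[-1] != combine_fwd_out.shape[-1]:
    # add into the leading slice of the last axis only
    final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out
else:
    final_out += combine_fwd_out

inputs = final_out
print(inputs[0, 0])  # [1. 1. 1. 1. 1. 1. 0. 0.]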

paddlenlp/transformers/moe_layer.py

Lines changed: 14 additions & 31 deletions

@@ -904,24 +904,21 @@ def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
         self.dispatched_indices = dispatched_indices.to(paddle.int32)

         total_zipped_tokens = extract_first_if_tuple(hs_2d_dispatched).shape[0]
-        if DSV3_USE_FP8_DISPATCH:
-            (
-                unzipped_tokens,
-                zipped_expertwise_rowmap,
-                unzipped_probs,
-                unzipped_tokens_scale,
-            ) = self.unzip_node.forward(
-                hs_2d_dispatched,
-                self.dispatched_indices,
-                dispatched_probs,
-                topk=self.router_topk,
-                num_experts=num_experts,
-                tokens_per_expert=self.tokens_per_expert,
-            )
-            record_stream_for_multi_input(hs_2d_dispatched)
-            dispatched_indices._record_stream()
-            dispatched_probs._record_stream()
+        (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, unzipped_tokens_scale,) = self.unzip_node.forward(
+            hs_2d_dispatched,
+            self.dispatched_indices,
+            dispatched_probs,
+            topk=self.router_topk,
+            num_experts=num_experts,
+            tokens_per_expert=self.tokens_per_expert,
+        )
+        record_stream_for_multi_input(hs_2d_dispatched)
+        dispatched_indices._record_stream()
+        dispatched_probs._record_stream()
+
+        self.unzipped_probs = unzipped_probs.unsqueeze(-1)

+        if DSV3_USE_FP8_DISPATCH:
             total_unzipped_tokens = extract_first_if_tuple(unzipped_tokens).shape[0]
             # If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance
             if self.recompute_fwd_gate_up == -1:

@@ -935,8 +932,6 @@ def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
                 # logger.debug(f"recompute_fwd_gate_up changed to False, Because the receives {unzipped_tokens.shape[0]} Tensors less then {self.seq_length*self.num_experts_per_tok*self.adaptive_remained_O1_recompute_ratio}.")
                 self.set_recompute_fwd_gate_up(False)

-        self.unzipped_probs = unzipped_probs.unsqueeze(-1)
-
         # if use_mlp_subbatch is enabled, then split the unzipped_tokens into subbatches
         if self.mlp_fwd_subbatch_rows != 0 and total_unzipped_tokens > self.mlp_fwd_subbatch_rows * 2:
             assert (

@@ -990,18 +985,6 @@ def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
                 (unzipped_tokens, unzipped_tokens_scale), unzipped_probs, padding_token_per_experts
             )
         else:
-            (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, _,) = self.unzip_node.forward(
-                hs_2d_dispatched,
-                self.dispatched_indices,
-                dispatched_probs,
-                topk=self.router_topk,
-                num_experts=num_experts,
-                tokens_per_expert=self.tokens_per_expert,
-            )
-            hs_2d_dispatched._record_stream()
-            dispatched_indices._record_stream()
-            dispatched_probs._record_stream()
-
             # If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance
             if self.recompute_fwd_gate_up == -1:
                 if (
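The moe_layer.py change is a deduplication: the FP8 and non-FP8 branches previously issued the same unzip_node.forward call (the non-FP8 branch simply discarded the scale as `_`) and the same _record_stream bookkeeping, so the commit hoists that shared prefix above the DSV3_USE_FP8_DISPATCH branch and moves the unzipped_probs.unsqueeze(-1) up with it. The refactor pattern in miniature, with hypothetical helpers standing in for the real calls:

def process(x):            # stands in for unzip_node.forward + stream bookkeeping
    return x * 2

def forward_before(x, use_fp8):
    if use_fp8:
        y = process(x)     # duplicated call...
        extra = y + 1      # ...plus fp8-only work
    else:
        y = process(x)     # ...same call repeated
        extra = None
    return y, extra

def forward_after(x, use_fp8):
    y = process(x)         # shared work hoisted out; runs exactly once
    extra = y + 1 if use_fp8 else None
    return y, extra

assert forward_before(3, True) == forward_after(3, True)
assert forward_before(3, False) == forward_after(3, False)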
