
Commit 0905f8c

doc(comment): fix code comment
1 parent 1b6c1f4 commit 0905f8c

5 files changed, +78 -77 lines changed


paddleformers/trainer/trainer_callback.py

Lines changed: 5 additions & 2 deletions

@@ -641,12 +641,12 @@ def enable_in_dict_config(config, key):

class FP8QuantWeightCallback(TrainerCallback):
    """
-    FP8QuantWeightCallback
+    Callback for FP8 weight quantization during training
    """

    def on_step_begin(self, args, state, control, **kwargs):
        """
-        每个step开始前把专家参数quant成fp8q
+        Quantize expert weights to FP8 before each training step
        """
        model = kwargs["model"]
        optimizer = kwargs["optimizer"]
@@ -672,6 +672,9 @@ def on_step_begin(self, args, state, control, **kwargs):
            skip_count += 1

    def on_optimizer_begin(self, args, state, control, **kwargs):
+        """
+        Reload weights before optimizer step
+        """
        model = kwargs["model"]
        optimizer = kwargs["optimizer"]
        global skip_count
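For context, the callback pattern this file follows is the standard Trainer hook mechanism: quantize expert weights before each step, then restore full-precision weights before the optimizer runs. A minimal sketch under these assumptions (ToyFP8QuantCallback and its two helper functions are illustrative placeholders, not the repository's actual implementation):

from paddleformers.trainer.trainer_callback import TrainerCallback


def quantize_experts_to_fp8(model):
    # Placeholder: the real callback calls into the model's own FP8
    # quantization entry point; guarded here in case it is absent.
    if hasattr(model, "fp8_quant_weight"):
        model.fp8_quant_weight()


def restore_expert_weights(model, optimizer):
    # Placeholder: reload full-precision expert weights so the optimizer
    # updates master parameters, not the FP8 copies.
    pass


class ToyFP8QuantCallback(TrainerCallback):
    """Illustrative only: FP8-quantize experts per step, reload before the update."""

    def on_step_begin(self, args, state, control, **kwargs):
        quantize_experts_to_fp8(kwargs["model"])

    def on_optimizer_begin(self, args, state, control, **kwargs):
        restore_expert_weights(kwargs["model"], kwargs["optimizer"])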

paddleformers/transformers/deepseek_v2/modeling_fast.py

Lines changed: 1 addition & 4 deletions

@@ -297,9 +297,7 @@ def forward(self, hidden_states):

        # Compute all possible return values
        if self.using_flex_token:
-            scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(
-                scores
-            )  # (scores, routing_map, exp_counts, l_aux, l_zloss)
+            scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(scores)
            ret = (scores, routing_map, l_aux, l_zloss)
        else:
            ret = self.topkgating(scores)  # (capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss)
@@ -1183,7 +1181,6 @@ def get_tensor_parallel_split_mappings(num_layers):
        base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(fn, is_column=True)

        # if we have enough num_key_value_heads to split, then split it.
-        # ???
        if config.num_key_value_heads % config.tensor_parallel_degree == 0:
            base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(fn, is_column=True)
        if config.use_fp8:
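The topkgating_nodrop call above selects, per token, the top-k experts without dropping any tokens to a capacity limit. A rough, self-contained sketch of that routing idea (not the repository's gate; score normalization and the auxiliary losses l_aux/l_zloss are omitted, and the function name is made up):

import paddle
import paddle.nn.functional as F


def toy_topk_route(logits, k):
    # logits: [num_tokens, num_experts]
    scores = F.softmax(logits, axis=-1)
    topk_vals, topk_idx = paddle.topk(scores, k=k, axis=-1)
    # routing_map[t, e] is True when token t is sent to expert e.
    routing_map = paddle.zeros_like(scores)
    routing_map = paddle.put_along_axis(routing_map, topk_idx, paddle.ones_like(topk_vals), axis=-1)
    routing_map = routing_map.astype("bool")
    # Tokens routed to each expert; no capacity limit, hence "nodrop".
    exp_counts = routing_map.astype("int32").sum(axis=0)
    return topk_vals, routing_map, exp_counts

With 4 experts and k=2, for example, every row of routing_map has exactly two True entries.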

paddleformers/transformers/deepseek_v2/modeling_pp.py

Lines changed: 10 additions & 8 deletions

@@ -227,7 +227,9 @@ def forward_without_residual(self, inputs):
        if self.send_mtp_embed:
            assert not self.output_mtp_embed_first, "forward_without_residual doesn't support output_mtp_embed_first"
            hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
-            self.mtp_embed_shape = inputs_embeds_mtp.shape  # 保存mtp_embed的shape用于反向传播
+            self.mtp_embed_shape = (
+                inputs_embeds_mtp.shape
+            )  # Save the shape of mtp_embed, used for backward propagation

        return return_args(hidden_states)

@@ -270,7 +272,9 @@ def forward(self, inputs):
                hidden_states = paddle.concat([inputs_embeds_mtp, hidden_states], axis=-1)
            else:
                hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
-            self.mtp_embed_shape = inputs_embeds_mtp.shape  # 保存mtp_embed的shape用于反向传播
+            self.mtp_embed_shape = (
+                inputs_embeds_mtp.shape
+            )  # Save the shape of mtp_embed, used for backward propagation

        return return_args(hidden_states)

@@ -279,7 +283,7 @@ def backward(self, output_grad):
        (do3,) = output_grad

        if self.send_mtp_embed:
-            # 分割梯度:do3的前部分对应hidden_states,后部分对应inputs_embeds_mtp
+            # Split gradient: first part of do3 corresponds to hidden_states, second part corresponds to inputs_embeds_mtp
            hidden_size = do3.shape[-1] - self.mtp_embed_shape[-1]
            if self.output_mtp_embed_first:
                hidden_states_grad = do3[..., hidden_size:]
@@ -545,7 +549,6 @@ def __init__(self, forward_nodes, backward_nodes, use_fuion=True):
        self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}"))

    def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
-        # print(" fwd pp stream", pp_stream)
        event_to_wait = combine_bw_event_to_wait
        for i, n in enumerate(self.nodes):
            pp_stream_t = pp_stream
@@ -1146,7 +1149,6 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):

        paddle.base.core.nvprof_nvtx_push("combine_backward")
        if combine_bw_event_to_wait is not None:
-            # print(" event", combine_bw_event_to_wait)
            output_grad = self.backward_node.combine_backward(
                output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True
            )
@@ -1223,7 +1225,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):

        # TODO: check correct
        # if final_out.shape[-1] != combine_fwd_out.shape[-1]:
-        #     final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out  # 直接广播并相加
+        #     final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out  # Directly broadcast and add
        # else:
        #     final_out += combine_fwd_out
        inputs = final_out + combine_fwd_out
@@ -1257,7 +1259,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):

        final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
        if final_out.shape[-1] != combine_fwd_out.shape[-1]:
-            final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out  # 直接广播并相加
+            final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out
        else:
            final_out += combine_fwd_out
        inputs = final_out
@@ -1813,7 +1815,7 @@ def build_schedule_node(self):
        if DSV3_USE_FP8_GEMM:
            attn_and_gate_node = ScheduleNode(self.attn_compute_for_fusion, name="attn_and_gate_node")

-            # recompute_fwd_gate_up_ may be 1, 0 or -1, 1 means recompute, 0 means disable recompute, -1 means adaptive recompute.
+            # recompute_fwd_gate_up_ may be 1, 0 or -1. 1 means recompute, 0 means disable recompute, -1 means adaptive recompute.
            recompute_fwd_gate_up_ = 1 if self.layer_idx in self.config.recompute_fwd_gate_up_list else 0
            if recompute_fwd_gate_up_ == 0 and self.config.adaptive_remained_O1_recompute_ratio:
                recompute_fwd_gate_up_ = -1
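The pattern the translated comments describe is: concatenate the MTP embeddings onto the hidden states along the last axis in forward, remember the MTP width, and split the incoming gradient at that width in backward. A standalone sketch of just that slicing logic for the default (hidden-states-first) layout; shapes are illustrative:

import paddle

hidden_states = paddle.randn([2, 8, 1024])     # [batch, seq, hidden]
inputs_embeds_mtp = paddle.randn([2, 8, 256])  # [batch, seq, mtp_width]

# forward: pack both tensors into one pipeline-stage output
packed = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
mtp_embed_shape = inputs_embeds_mtp.shape      # saved for the backward split

# backward: split the gradient of `packed` back into the two parts
do3 = paddle.randn(packed.shape)               # stand-in for the upstream gradient
hidden_size = do3.shape[-1] - mtp_embed_shape[-1]
hidden_states_grad = do3[..., :hidden_size]
inputs_embeds_mtp_grad = do3[..., hidden_size:]

When output_mtp_embed_first is set, the two slices simply swap, which is the branch shown in the backward hunk above.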

paddleformers/transformers/fp8_utils.py

Lines changed: 49 additions & 48 deletions

@@ -217,22 +217,23 @@ def compute_fp8_linear(
        input, weight, weight_transpose=False, return_transpose_only=False, return_mode="output_only", *, out=None
    ):
        """
-        FP8 Linear 计算函数,支持多种返回模式,支持量化/未量化输入。
+        FP8 Linear computation function supporting multiple return modes and quantized/unquantized inputs.

        Args:
-            input: 输入张量(原始或已经量化的(input_fp8, input_scale) 元组)。
-            weight: 权重张量。
-            weight_transpose (bool): 是否转置权重。
-            return_transpose_only (bool): 是否仅返回转置后的权重。
-            return_mode (str): 返回模式,可选:
-                - "output_only": 仅返回输出张量。
-                - "with_input_quant": 返回输出 + 输入量化结果 (input_fp8, input_scale)。
-                - "with_input_transpose_quant": 返回输出(out) + 输入量化转置结果 (input_t_fp8, input_t_scale).
+            input: Input tensor (raw tensor or quantized as (input_fp8, input_scale) tuple)
+            weight: Weight tensor
+            weight_transpose (bool): Whether to transpose weight
+            return_transpose_only (bool): Whether to return only transposed weight
+            return_mode (str): Return mode options:
+                - "output_only": Returns only output tensor
+                - "with_input_quant": Returns output + input quant results (input_fp8, input_scale)
+                - "with_input_transpose_quant": Returns output + transposed quant results (input_t_fp8, input_t_scale)
+
        Returns:
-            根据 return_mode 返回不同组合的张量。
+            Different combinations of tensors based on return_mode

        Raises:
-            RuntimeError: 如果 return_mode 不支持。
+            RuntimeError: If return_mode is not supported
        """
        # check input
        is_input_quantized = isinstance(input, (tuple, list)) and len(input) == 2
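A hedged usage sketch of the three return modes documented above, mirroring how the calls appear later in this file. The shapes, dtypes, and the assumption that the FP8 kernels accept them (alignment, layout, GPU build) are not verified here:

import paddle
from paddleformers.transformers.fp8_utils import FP8LinearFunctionBase

x = paddle.randn([256, 1024]).astype("bfloat16")   # activations, [tokens, hidden]
w = paddle.randn([2048, 1024]).astype("bfloat16")  # weight, [out_features, in_features]

# "output_only" (default): just the linear output, as in fp8_mlp_fwd's o3 gemm.
out = FP8LinearFunctionBase.compute_fp8_linear(x, w, weight_transpose=True, return_transpose_only=True)

# "with_input_quant": also hand back the quantized input so it can be reused.
out, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear(
    x, w, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant"
)

# "with_input_transpose_quant": also hand back the transposed quantized input,
# which the backward passes feed into the dw gemms.
out, x_t_fp8, x_t_scale = FP8LinearFunctionBase.compute_fp8_linear(
    x, w, return_mode="with_input_transpose_quant"
)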
@@ -294,7 +295,7 @@ def compute_expert_w_grad(
        rtn_dtype=paddle.bfloat16,
    ):
        """
-        统一处理 expert_w 的梯度计算(支持 main_grad 和普通 grad)
+        Unified gradient computation for expert_w weights (supports both main_grad and regular grad).
        """

        if input_t is None or numpy.prod(input_t.shape) == 0:
@@ -352,22 +353,22 @@ def common_fp8_mlp_bwd(
        if x_fp8 is None or x_scale is None:
            raise ValueError("When o1 is None, both x_fp8 and x_scale must be provided.")

-        # # ===== [recompute] o1 = deep_gemm(x_fp8, w1_t_fp8) =====
+        # [recompute] o1 = deep_gemm(x_fp8, w1_t_fp8)

        # Recompute o1 using deep_gemm(x_fp8, w1_t_fp8)
        w1_fp8, w1_scale = weight_quant(w1, True)
        o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=get_sm_num())

-        # ===== [recompute] o2 = swiglu(o1) =====
+        # [recompute] o2 = swiglu(o1)
        o2 = swiglu(o1)

-        # ===== do2 = deep_gemm(do3_fp8, w2_fp8)
+        # do2 = deep_gemm(do3_fp8, w2_fp8)
        do2, do3_t_fp8, do3_t_scale = FP8LinearFunctionBase.compute_fp8_linear(
            do3, w2, return_mode="with_input_transpose_quant"
        )

-        # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
+        # dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
        o2 = FP8LinearFunctionBase.padding(o2, 0)
        o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
            o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
@@ -397,15 +398,15 @@ def common_fp8_mlp_bwd(
            o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, rtn_dtype=paddle.float32
        )

-        # ===== do1 = swiglu_grad(o1, None, do2) =====
+        # do1 = swiglu_grad(o1, None, do2)
        do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2)

-        # ===== dx = deep_gemm(do1_fp8, w1_fp8) =====
+        # dx = deep_gemm(do1_fp8, w1_fp8)
        dx, do1_t_fp8, do1_t_scale = FP8LinearFunctionBase.compute_fp8_linear(
            do1, w1, return_mode="with_input_transpose_quant"
        )

-        # ===== dw1 = deep_gemm(x_t_fp8, do1_t_fp8) =====
+        # dw1 = deep_gemm(x_t_fp8, do1_t_fp8)
        if apply_backward_hook:
            if WeightGradStore.enabled:
                WeightGradStore.put(
@@ -442,15 +443,15 @@ def fp8_mlp_fwd(x, w1, w2):
        x_orig_shape = x.shape
        x = x.reshape([-1, x_orig_shape[-1]])

-        # ===== o1 = deep_gemm(x_fp8, w1_t_fp8) =====
+        # o1 = deep_gemm(x_fp8, w1_t_fp8)
        o1, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear(
            x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant"
        )

-        # ===== o2 = swiglu(o1) =====
+        # o2 = swiglu(o1)
        o2 = swiglu(o1)

-        # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
+        # o3 = deep_gemm(o2_fp8, w2_t_fp8)
        o3 = FP8LinearFunctionBase.compute_fp8_linear(o2, w2, weight_transpose=True, return_transpose_only=True)

        if len(x_orig_shape) > 2:
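For orientation, the math these comments describe is an ordinary SwiGLU MLP. A plain bf16 reference with no FP8 quantization and no deep_gemm might look like the sketch below; the split-silu-multiply convention for swiglu is an assumption about the fused kernel's layout, not taken from this file:

import paddle
import paddle.nn.functional as F


def ref_swiglu(o1):
    # Split the fused gate/up projection and combine with SiLU.
    gate, up = paddle.chunk(o1, 2, axis=-1)
    return F.silu(gate) * up


def ref_mlp_fwd(x, w1, w2):
    # o1 = x @ w1.T, o2 = swiglu(o1), o3 = o2 @ w2.T
    o1 = paddle.matmul(x, w1, transpose_y=True)
    o2 = ref_swiglu(o1)
    o3 = paddle.matmul(o2, w2, transpose_y=True)
    return o3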
@@ -460,9 +461,9 @@ def fp8_mlp_fwd(x, w1, w2):

    @staticmethod
    def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2):
-        # ===== compute norm_output =====
+        # compute norm_output
        norm_output, _ = fused_ln.fused_rms_norm(x, norm_w, norm_eps)
-        # ===== compute fp8_mlp_fwd =====
+        # compute fp8_mlp_fwd
        _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2)
        return o3

@@ -510,10 +511,10 @@ def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False):

    @staticmethod
    def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2):
-        # ===== recompute norm_output =====
+        # recompute norm_output
        norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps)

-        # ===== compute fp8_mlp_fwd =====
+        # compute fp8_mlp_fwd
        d_norm_output = FP8LinearFunctionBase.fp8_mlp_bwd(do3, norm_output, w1, w2, True)

        if hasattr(norm_w, "_apply_backward_hook"):
@@ -567,7 +568,7 @@ def backward(ctx, dout):
            x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
        )

-        # ===== dx = deep_gemm(dout_fp8, w_fp8)
+        # dx = deep_gemm(dout_fp8, w_fp8)
        dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear(
            dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant"
        )
@@ -576,15 +577,15 @@ def backward(ctx, dout):
        else:
            x_t_fp8, x_t_scale = x

-            # ===== dx = deep_gemm(dout_fp8, w_fp8)
+            # dx = deep_gemm(dout_fp8, w_fp8)
            dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear(
                dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant"
            )
            dx_orig_shape = dout.shape[:-1]
            dx_orig_shape.append(ctx.x_t_shape[0])
            dx = dx.reshape(dx_orig_shape)

-        # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
+        # dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
        FP8LinearFunctionBase.compute_expert_w_grad(
            x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight, paddle.float32
        )
@@ -668,20 +669,20 @@ def forward(self, x):
class FusedNormFP8MLPFunction(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, norm_w, w1, w2, norm_eps):
-        # ===== compute norm_output =====
+        # compute norm_output
        norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps)
-        # ===== reshape for deep_gemm, since deep_gemm only support 2D =====
+        # reshape for deep_gemm, since deep_gemm only supports 2D
        x_orig_shape = norm_output.shape
        norm_output = norm_output.reshape([-1, x_orig_shape[-1]])

-        # ===== call func fp8_mlp_fwd =====
+        # call func fp8_mlp_fwd
        _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2)

-        # ===== reshape to origin shape =====
+        # reshape to origin shape
        if len(x_orig_shape) > 2:
            o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])

-        # ===== save for backward =====
+        # save for backward
        ctx.save_for_backward(
            norm_output,
            invar,
@@ -696,27 +697,27 @@ def forward(ctx, x, norm_w, w1, w2, norm_eps):

    @staticmethod
    def backward(ctx, do3):
-        # ===== reshape for deep_gemm, since deep_gemm only support 2D =====
+        # reshape for deep_gemm, since deep_gemm only supports 2D
        do3_orig_shape = do3.shape
        do3 = do3.reshape([-1, do3_orig_shape[-1]])

-        # ===== recive saved tensors =====
+        # receive saved tensors
        norm_output, invar, x, norm_w, w1, w2, norm_eps, x_orig_shape = ctx.saved_tensor()

        x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
            norm_output, output_scale_transpose=True, quant_method="1x128", input_transpose=True
        )

-        # ===== call func common_fp8_mlp_bwd =====
+        # call func common_fp8_mlp_bwd
        d_norm_output, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(
            do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale
        )

-        # ===== reshape to origin shape =====
+        # reshape to origin shape
        if len(x_orig_shape) > 2:
            d_norm_output = d_norm_output.reshape([x_orig_shape[0], -1, d_norm_output.shape[-1]])

-        # ===== compute norm grad =====
+        # compute norm grad
        dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps)

        return dx, d_rms_norm_weight, dw1, dw2
@@ -725,17 +726,17 @@ def backward(ctx, do3):
class FP8MlpFunction(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, w1, w2, recompute_fwd_gate_up):
-        # ===== reshape for deep_gemm, since deep_gemm only support 2D =====
+        # reshape for deep_gemm, since deep_gemm only supports 2D
        x_orig_shape = x.shape
        x = x.reshape([-1, x_orig_shape[-1]])

-        # ===== call func fp8_mlp_fwd =====
+        # call func fp8_mlp_fwd
        o1, x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2)
-        # ===== reshape to origin shape =====
+        # reshape to origin shape
        if len(x_orig_shape) > 2:
            o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])

-        # ===== save for backward =====
+        # save for backward
        o1 = None if recompute_fwd_gate_up else o1
        ctx.save_for_backward(
            o1,
@@ -749,14 +750,14 @@ def forward(ctx, x, w1, w2, recompute_fwd_gate_up):

    @staticmethod
    def backward(ctx, do3):
-        # ===== reshape for deep_gemm, since deep_gemm only support 2D =====
+        # reshape for deep_gemm, since deep_gemm only supports 2D
        do3_orig_shape = do3.shape
        do3 = do3.reshape([-1, do3_orig_shape[-1]])

-        # ===== recive saved tensors =====
+        # receive saved tensors
        o1, x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor()

-        # ===== compute x_t_fp8, x_t_scale for dw1 =====
+        # compute x_t_fp8, x_t_scale for dw1
        x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous())
        x_dequant_fp16 = FP8LinearFunctionBase.padding(x_dequant_fp16, 0)

@@ -768,7 +769,7 @@ def backward(ctx, do3):
            return_transpose_only=True,
        )

-        # ===== call func common_fp8_mlp_bwd =====
+        # call func common_fp8_mlp_bwd
        if o1 is None:
            dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
                do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale, apply_backward_hook=True
@@ -777,7 +778,7 @@ def backward(ctx, do3):
            dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
                do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True
            )
-        # ===== reshape to origin shape =====
+        # reshape to origin shape
        if len(x_orig_shape) > 2:
            dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])

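The two PyLayer classes above follow the same skeleton: flatten to 2D in forward, stash what backward needs via ctx.save_for_backward, and undo the reshape on the way out. A stripped-down sketch of that skeleton with a plain matmul standing in for the FP8 kernels (illustrative only, not the file's implementation):

import paddle


class ToyMatmulFunction(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, w):
        # reshape for a 2D-only kernel
        x_orig_shape = x.shape
        x2d = x.reshape([-1, x_orig_shape[-1]])
        out = paddle.matmul(x2d, w, transpose_y=True)
        # save for backward
        ctx.save_for_backward(x2d, w)
        ctx.x_orig_shape = x_orig_shape
        # reshape to origin shape
        if len(x_orig_shape) > 2:
            out = out.reshape([x_orig_shape[0], -1, out.shape[-1]])
        return out

    @staticmethod
    def backward(ctx, dout):
        x2d, w = ctx.saved_tensor()
        dout2d = dout.reshape([-1, dout.shape[-1]])
        dx = paddle.matmul(dout2d, w)                     # gradient w.r.t. x
        dw = paddle.matmul(dout2d, x2d, transpose_x=True) # gradient w.r.t. w
        x_orig_shape = ctx.x_orig_shape
        if len(x_orig_shape) > 2:
            dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
        return dx, dw


# usage: out = ToyMatmulFunction.apply(x, w)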
