Commit ef47c11

Support magic send for mtp (#10916)
* support fp8 weight quant cache
* fix bug
* fix conflict
* support magic send
1 parent 87fa744 commit ef47c11

File tree: 5 files changed (+159 / -24 lines)


llm/run_pretrain.py

Lines changed: 2 additions & 1 deletion
@@ -27,6 +27,7 @@
     print_rank_0,
 )
 from paddlenlp.trainer import (
+    FP8QuantWeightCallback,
     PdArgumentParser,
     StepFlexToken,
     Trainer,
@@ -568,7 +569,7 @@ def main():
         * data_args.max_seq_length
     )

-    callbacks = [StepFlexToken()]
+    callbacks = [StepFlexToken(), FP8QuantWeightCallback()]

     trainer = PretrainingTrainer(
         model=model,

paddlenlp/trainer/trainer_callback.py

Lines changed: 33 additions & 0 deletions
@@ -20,6 +20,7 @@
 """
 import dataclasses
 import json
+import os
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union

@@ -41,6 +42,7 @@
     "PrinterCallback",
     "EarlyStoppingCallback",
     "StepFlexToken",
+    "FP8QuantWeightCallback",
 ]


@@ -615,3 +617,34 @@ def on_step_begin(
         model = kwargs.pop("model")
         if hasattr(model, "step_flex_token"):
             model.step_flex_token(state.global_step)
+
+
+g_shard_bypass_dygraph_optimizer = int(os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0))
+
+
+def enable_in_dict_config(config, key):
+    """enable_in_dict_config"""
+    return key in config and config[key]
+
+
+skip_count = 0
+
+
+class FP8QuantWeightCallback(TrainerCallback):
+    """
+    FP8QuantWeightCallback
+    """
+
+    def on_step_begin(self, args, state, control, **kwargs):
+        """
+        Quantize the expert parameters to FP8 before each step begins.
+        """
+        model = kwargs["model"]
+        optimizer = kwargs["optimizer"]
+        global skip_count
+
+        if not g_shard_bypass_dygraph_optimizer or skip_count == 0:
+            model.fp8_quant_weight(True)
+            optimizer.clear_param_storage("moe_expert")
+
+        skip_count += 1
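
The callback above is the producer side of the FP8 weight cache: at the top of each training step it re-quantizes the expert weights via model.fp8_quant_weight(True) and clears the optimizer's parameter storage for the "moe_expert" group; when FLAGS_shard_bypass_dygraph_optimizer is set, this happens only on the very first step. A minimal, framework-free sketch of that gating logic follows; DummyModel and DummyOptimizer are hypothetical stand-ins for the objects the Trainer passes in through **kwargs.

import os

class DummyModel:
    """Hypothetical stand-in for the model handed to the callback."""
    def fp8_quant_weight(self, batch_mode):
        print(f"quantizing expert weights to FP8, batch_mode={batch_mode}")

class DummyOptimizer:
    """Hypothetical stand-in for the optimizer handed to the callback."""
    def clear_param_storage(self, group):
        print(f"clearing cached parameter storage for group '{group}'")

bypass = int(os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0))
skip_count = 0

def on_step_begin(model, optimizer):
    # Same gating as FP8QuantWeightCallback.on_step_begin: re-quantize every
    # step, or only on the first step when the bypass flag is enabled.
    global skip_count
    if not bypass or skip_count == 0:
        model.fp8_quant_weight(True)
        optimizer.clear_param_storage("moe_expert")
    skip_count += 1

model, optimizer = DummyModel(), DummyOptimizer()
for _ in range(3):
    on_step_begin(model, optimizer)  # prints three times, or once with the flag set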

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 49 additions & 0 deletions
@@ -981,6 +981,49 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
             config=config, intermediate_size=intermediate_size, is_moe=False
         )

+    def fp8_quant_weight(self, batch_mode=False):
+        """Quantize weights in FP8 format.
+
+        Args:
+            batch_mode: If True, quantize all weights in batch mode using the first expert's weights.
+                If False, quantize each expert's weights individually.
+        """
+
+        def quantize_weights(weight_list, weight_obj=None):
+            """Helper function to quantize a list of weights."""
+            if weight_obj is None:
+                weight_obj = weight_list[0]
+
+            # Quantize without transpose
+            fp8_weight, fp8_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(
+                weight_list, transpose=False
+            )
+            setattr(weight_obj, "fp8_weight_stacked", fp8_weight)
+            setattr(weight_obj, "fp8_scale_stacked", fp8_scale)
+
+            # Quantize with transpose
+            fp8_weight_t, fp8_scale_t = paddle.incubate.nn.functional.fused_stack_transpose_quant(
+                weight_list, transpose=True
+            )
+            setattr(weight_obj, "fp8_weight_stacked_transpose", fp8_weight_t)
+            setattr(weight_obj, "fp8_scale_stacked_transpose", fp8_scale_t)
+
+        if batch_mode:
+            # Batch mode: process all experts' weights together
+            expert_w1_list = [expert.w1 for expert in self.experts if expert is not None]
+            expert_w2_list = [expert.w2 for expert in self.experts if expert is not None]
+
+            if expert_w1_list:
+                quantize_weights(expert_w1_list, expert_w1_list[0])
+            if expert_w2_list:
+                quantize_weights(expert_w2_list, expert_w2_list[0])
+        else:
+            # Individual mode: process each expert's weights separately
+            for expert in self.experts:
+                if expert is not None:
+                    quantize_weights([expert.w1])
+                    quantize_weights([expert.w2])
+
     def forward(self, hidden_states):
         if self.using_post_norm_recompute:
             super().update_flex_token()
@@ -1928,6 +1971,12 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute
         else:
             self.mlp = DeepseekV2MLPClass(config)

+    def fp8_quant_weight(self, batch_mode=False):
+        """fp8_quant_weight"""
+        if isinstance(self.mlp, DeepseekV2MoE):
+            logger.info(f"fp8 quant weight for mlp {type(self.mlp)}")
+            self.mlp.fp8_quant_weight(batch_mode)
+
     def forward(
         self,
         hidden_states: paddle.Tensor,
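
To make the two quantization modes above concrete, here is a small, framework-free sketch (not the PaddleNLP implementation): toy_stack_quant stands in for paddle.incubate.nn.functional.fused_stack_transpose_quant, and the cached attribute names mirror the ones set by fp8_quant_weight. In batch mode the stacked FP8 data for all experts is cached on the first expert's weight; in individual mode each weight carries its own single-entry cache.

class Weight:
    """Stand-in for one expert weight parameter (w1 or w2)."""
    def __init__(self, name):
        self.name = name

def toy_stack_quant(weight_list, transpose=False):
    """Toy stand-in for the fused stack(+transpose)+quant kernel."""
    tag = "T" if transpose else ""
    return [f"fp8({w.name}){tag}" for w in weight_list], [1.0] * len(weight_list)

def quantize_weights(weight_list, weight_obj=None):
    """Cache stacked FP8 data and scales on weight_obj (defaults to the first weight)."""
    weight_obj = weight_obj or weight_list[0]
    weight_obj.fp8_weight_stacked, weight_obj.fp8_scale_stacked = toy_stack_quant(weight_list)
    weight_obj.fp8_weight_stacked_transpose, weight_obj.fp8_scale_stacked_transpose = toy_stack_quant(
        weight_list, transpose=True
    )

experts_w1 = [Weight("e0.w1"), Weight("e1.w1")]

# Batch mode: one cache for the whole expert group, attached to the first weight.
quantize_weights(experts_w1, experts_w1[0])
print(experts_w1[0].fp8_weight_stacked)   # ['fp8(e0.w1)', 'fp8(e1.w1)']

# Individual mode: each expert's weight gets its own single-entry cache.
for w in experts_w1:
    quantize_weights([w])
print(experts_w1[1].fp8_weight_stacked)   # ['fp8(e1.w1)']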

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 45 additions & 14 deletions
@@ -65,6 +65,10 @@
     "DeepseekV2ForCausalLMPipe",
 ]

+import queue
+
+global_inputs_embeds_mtp_queue = queue.Queue()
+

 DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
 DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true"
@@ -1019,7 +1023,7 @@ def forward(self, args):
             inputs_embeds = self.embed_tokens(input_ids)

             batch_size, seq_length = input_ids.shape
-            if self.config.send_mtp_embed:
+            if self.config.num_nextn_predict_layers > 0:
                 seq_length -= self.config.num_nextn_predict_layers

             if attention_mask is not None:
@@ -1042,7 +1046,7 @@ def forward(self, args):
                 attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
                 attention_mask.stop_gradient = True

-            if self.config.send_mtp_embed:
+            if self.config.num_nextn_predict_layers > 0:
                 inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :]  # [B, S, D]
                 inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :]
                 inputs_embeds_ori = inputs_embeds
@@ -1054,6 +1058,7 @@ def forward(self, args):
                     # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
                     inputs_embeds = ScatterOp.apply(inputs_embeds)
                 embeds_res = [inputs_embeds]
+                mtp_embeds = []
                 for depth in range(self.config.num_nextn_predict_layers):
                     inputs_embeds_mtp = paddle.concat(
                         [
@@ -1065,12 +1070,19 @@ def forward(self, args):
                     if self.sequence_parallel:
                         inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]])
                         inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp)
-                    embeds_res.append(inputs_embeds_mtp)
-                # if not self.sequence_parallel
-                #     mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size]
-                # else:
-                #     mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size]
-                inputs_embeds = paddle.concat(embeds_res, axis=-1)
+                    mtp_embeds.append(inputs_embeds_mtp)
+
+                if self.config.send_mtp_embed:
+                    embeds_res.extend(mtp_embeds)
+                    # if not self.sequence_parallel
+                    #     mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size]
+                    # else:
+                    #     mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size]
+                    inputs_embeds = paddle.concat(embeds_res, axis=-1)
+                else:
+                    global global_inputs_embeds_mtp_queue
+                    cloned_mtp_embeds = [t.detach() for t in mtp_embeds]
+                    global_inputs_embeds_mtp_queue.put(cloned_mtp_embeds)
                 return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
             else:
                 if self.sequence_parallel:
@@ -1359,9 +1371,15 @@ class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer):
     def forward(self, args):
         hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)

-        hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
-        hidden_states_main_model = hidden_states_list[0]
-        inputs_embeds_cur_depth_list = hidden_states_list[1:]
+        if self.config.send_mtp_embed:
+            hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
+            hidden_states_main_model = hidden_states_list[0]
+            inputs_embeds_cur_depth_list = hidden_states_list[1:]
+        else:
+            hidden_states_main_model = hidden_states
+            global global_inputs_embeds_mtp_queue
+            inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get()
+
         has_gradient = not hidden_states_main_model.stop_gradient

         if attention_mask is not None and attention_mask.dtype == paddle.int32:
@@ -1426,7 +1444,7 @@ def __init__(self, config):
     def forward(self, args):
         hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)

-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
             hidden_states = hidden_states_list[0]
             hidden_states_mtp = hidden_states_list[-self.config.num_nextn_predict_layers :]
@@ -1451,7 +1469,7 @@ def embedding_weight(self):
         return get_attr(self, "weight")

     def forward(self, args: Union[Tuple, paddle.Tensor]):
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             logits = []
             for _hidden_states in args:
                 logits.append(super().forward(_hidden_states))
@@ -1466,7 +1484,7 @@ def build_schedule_node(self):

 class DeepseekV2PretrainingCriterionPipe(DeepseekV2PretrainingCriterion):
     def forward(self, logits, labels):
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             mtp_logits = logits[1:]
             logits = logits[0]
             loss = super().forward(logits, labels, mtp_logits=mtp_logits)
@@ -1669,6 +1687,19 @@ def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, reco
         # DON'T init PipelinePretrainedModel
         # PipelinePretrainedModel.__init__(self.super(), config=config)

+    def fp8_quant_weight(self, batch_mode=False):
+        """fp8_quant_weight"""
+        with paddle.no_grad():
+            for i, layer in self._sub_layers.items():
+                if isinstance(
+                    layer, paddle.distributed.fleet.meta_parallel.parallel_layers.pp_layers.PipelineLayerChunk
+                ):
+                    for i, sub_layer in layer.named_sublayers():
+                        if isinstance(sub_layer, DeepseekV2DecoderLayer) and hasattr(sub_layer, "fp8_quant_weight"):
+                            sub_layer.fp8_quant_weight(batch_mode)
+                if isinstance(layer, DeepseekV2DecoderLayer) and hasattr(layer, "fp8_quant_weight"):
+                    layer.fp8_quant_weight(batch_mode)
+
     def get_loss_fn(self, config):
         return DeepseekV2PretrainingCriterionPipe(config)
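
The hunks above are the heart of the "magic send": when send_mtp_embed is disabled, the embedding stage no longer concatenates the multi-token-prediction (MTP) embeddings into the tensor that is shipped through the pipeline; it puts detached copies into the module-level global_inputs_embeds_mtp_queue, and DeepseekV2MTPLayerPipe pops them back out. Because queue.Queue lives in process memory, this shortcut only works when the producing and consuming layers run in the same process. A minimal sketch of the handoff, using plain lists in place of tensors (embedding_stage and mtp_stage are illustrative names, not the real classes):

import queue

# Module-level channel shared by producer and consumer, mirroring
# global_inputs_embeds_mtp_queue in modeling_pp.py.
mtp_queue = queue.Queue()

def embedding_stage(hidden, mtp_embeds, send_mtp_embed):
    """Producer: either pack MTP embeddings into the pipeline tensor
    or stash them in the in-process queue ("magic send")."""
    if send_mtp_embed:
        return hidden + mtp_embeds   # stand-in for paddle.concat(..., axis=-1)
    mtp_queue.put(list(mtp_embeds))  # detached copies in the real code
    return hidden                    # only the main-model activations are sent

def mtp_stage(pipeline_tensor, send_mtp_embed, num_mtp):
    """Consumer: recover the per-depth embeddings from the tensor or the queue."""
    if send_mtp_embed:
        main, mtp = pipeline_tensor[:1], pipeline_tensor[1:]
    else:
        main, mtp = pipeline_tensor, mtp_queue.get()
    assert len(mtp) == num_mtp
    return main, mtp

sent = embedding_stage(["h0"], ["mtp0"], send_mtp_embed=False)
print(mtp_stage(sent, send_mtp_embed=False, num_mtp=1))  # (['h0'], ['mtp0'])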

paddlenlp/transformers/fp8_utils.py

Lines changed: 30 additions & 9 deletions
@@ -46,6 +46,29 @@ def swiglu(x, y=None):
 ]


+def _get_fp8_weight_and_scale(weight, stacked=False, transpose=False):
+    """_get_fp8_weight_and_scale"""
+    if stacked:
+        if transpose:
+            fp8_weight, fp8_scale = weight.fp8_weight_stacked_transpose, weight.fp8_scale_stacked_transpose
+        else:
+            fp8_weight, fp8_scale = weight.fp8_weight_stacked, weight.fp8_scale_stacked
+    else:
+        if transpose:
+            fp8_weight, fp8_scale = weight.fp8_weight_transpose, weight.fp8_scale_transpose
+        else:
+            fp8_weight, fp8_scale = weight.fp8_weight, weight.fp8_scale
+    return fp8_weight, fp8_scale
+
+
+def fused_stack_quant(expert_weight_list, transpose=False):
+    if hasattr(expert_weight_list[0], "fp8_weight_stacked"):
+        w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=transpose)
+    else:
+        w, scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_weight_list, transpose=transpose)
+    return w, scale
+
+
 class FP8LinearFunctionBase:
     @staticmethod
     def dequantize_fp8_to_fp32(fp8_tensor, scale):
@@ -524,7 +547,9 @@ def backward(ctx, do3):
         )

         # ===== call func common_fp8_mlp_bwd =====
-        d_norm_output, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2)
+        d_norm_output, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(
+            do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2
+        )

         # ===== reshape to origin shape =====
         if len(x_orig_shape) > 2:
@@ -723,7 +748,7 @@ def fwd_gate_up(self, x, expert_w1, num_expert, tokens_per_expert):
         if not self.is_split_group_gemm:
             self.m_indices = self.gen_m_indices(tokens_per_expert)
         # concat w1, shape is [num_groups, n, k]
-        w1_t_quant, w1_t_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_w1, transpose=True)
+        w1_t_quant, w1_t_scale = fused_stack_quant(expert_w1, transpose=True)
         w1_t_quant = w1_t_quant.reshape([num_expert, -1, w1_t_quant.shape[-1]])
         w1_t_scale = w1_t_scale.reshape([num_expert, -1, w1_t_scale.shape[-1]])

@@ -765,7 +790,7 @@ def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert, o3=None, clear_o1=
         [m_sum, k] = [m_sum, n] * [num_groups, n, k]
         """
         # concat and transpose w2
-        w2_quant, w2_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_w2, transpose=True)
+        w2_quant, w2_scale = fused_stack_quant(expert_w2, transpose=True)
         w2_quant = w2_quant.reshape([num_expert, -1, w2_quant.shape[-1]])
         w2_scale = w2_scale.reshape([num_expert, -1, w2_scale.shape[-1]])

@@ -801,9 +826,7 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, inplace_swiglu_prob=False
         [m_sum, n] = [m_sum, k] * [num_groups, k, n]
         """
         # recompute concated_w2_2d
-        bw_w2_quant, bw_w2_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(
-            expert_w2, transpose=False
-        )
+        bw_w2_quant, bw_w2_scale = fused_stack_quant(expert_w2, transpose=False)
         bw_w2_quant = bw_w2_quant.reshape([len(expert_w2), -1, bw_w2_quant.shape[-1]])
         bw_w2_scale = bw_w2_scale.reshape([len(expert_w2), -1, bw_w2_scale.shape[-1]])

@@ -849,9 +872,7 @@ def bwd_gate_up_input(self, do1, expert_w1, dx=None):
         [m_sum, k] = [m_sum, n] * [num_groups, n, k]
         """
         # recompute concated_w1_t
-        bw_w1_quant, bw_w1_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(
-            expert_w1, transpose=False
-        )
+        bw_w1_quant, bw_w1_scale = fused_stack_quant(expert_w1, transpose=False)
        bw_w1_quant = bw_w1_quant.reshape([len(expert_w1), -1, bw_w1_quant.shape[-1]])
        bw_w1_scale = bw_w1_scale.reshape([len(expert_w1), -1, bw_w1_scale.shape[-1]])
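
fused_stack_quant is the consumer side of the per-step cache: if FP8QuantWeightCallback has already attached stacked FP8 tensors to the first expert weight, they are reused; otherwise the weights are quantized on the fly with the fused kernel. A small, framework-free sketch of that fallback, where slow_quant is a stand-in for paddle.incubate.nn.functional.fused_stack_transpose_quant and the attribute names match the ones cached by fp8_quant_weight (only the non-transposed pair is shown):

class ExpertWeight:
    """Stand-in for a single expert weight parameter."""
    def __init__(self, data):
        self.data = data

def slow_quant(weight_list, transpose=False):
    """Pretend fused stack(+transpose)+quant kernel; the expensive path."""
    print("quantizing on the fly")
    data = [list(reversed(w.data)) if transpose else list(w.data) for w in weight_list]
    return data, [1.0] * len(weight_list)

def fused_stack_quant(expert_weight_list, transpose=False):
    """Prefer the FP8 tensors cached by the callback, else quantize now."""
    first = expert_weight_list[0]
    if hasattr(first, "fp8_weight_stacked"):
        if transpose:
            return first.fp8_weight_stacked_transpose, first.fp8_scale_stacked_transpose
        return first.fp8_weight_stacked, first.fp8_scale_stacked
    return slow_quant(expert_weight_list, transpose=transpose)

weights = [ExpertWeight([1.0, 2.0]), ExpertWeight([3.0, 4.0])]
fused_stack_quant(weights)                                 # no cache yet -> "quantizing on the fly"
weights[0].fp8_weight_stacked = [[1.0, 2.0], [3.0, 4.0]]   # what fp8_quant_weight would attach
weights[0].fp8_scale_stacked = [1.0, 1.0]
print(fused_stack_quant(weights))                          # served from the cache, no recompute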
