     "DeepseekV2ForCausalLMPipe",
 ]
 
+import queue
+
+global_inputs_embeds_mtp_queue = queue.Queue()
+
 
 DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
 DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true"
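The hunk above adds a module-level side channel: a plain `queue.Queue` that later hunks use to hand the multi-token-prediction (MTP) input embeddings from the embedding stage to the MTP layer without concatenating them into the activation tensor sent between pipeline stages. A minimal sketch of that producer/consumer pattern in plain Python, with placeholder payloads (the two stage functions are hypothetical, and the pattern assumes both stages run in the same process):

```python
import queue

# Module-level FIFO queue acting as an in-process side channel between two
# pipeline stages (mirrors global_inputs_embeds_mtp_queue above).
mtp_side_channel = queue.Queue()

def embedding_stage_forward(mtp_embeds):
    # Producer: stash the extra MTP embeddings instead of packing them into
    # the tensor handed to the next stage.
    mtp_side_channel.put(list(mtp_embeds))

def mtp_layer_forward():
    # Consumer: blocks until the matching put() has happened.
    return mtp_side_channel.get()

embedding_stage_forward(["embeds_depth_0", "embeds_depth_1"])
print(mtp_layer_forward())  # ['embeds_depth_0', 'embeds_depth_1']
```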
@@ -1019,7 +1023,7 @@ def forward(self, args):
         inputs_embeds = self.embed_tokens(input_ids)
 
         batch_size, seq_length = input_ids.shape
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             seq_length -= self.config.num_nextn_predict_layers
 
         if attention_mask is not None:
@@ -1042,7 +1046,7 @@ def forward(self, args):
             attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
             attention_mask.stop_gradient = True
 
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :]  # [B, S, D]
             inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :]
             inputs_embeds_ori = inputs_embeds
@@ -1054,6 +1058,7 @@ def forward(self, args):
                 # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
                 inputs_embeds = ScatterOp.apply(inputs_embeds)
             embeds_res = [inputs_embeds]
+            mtp_embeds = []
             for depth in range(self.config.num_nextn_predict_layers):
                 inputs_embeds_mtp = paddle.concat(
                     [
@@ -1065,12 +1070,19 @@ def forward(self, args):
                 if self.sequence_parallel:
                     inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]])
                     inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp)
-                embeds_res.append(inputs_embeds_mtp)
-            # if not self.sequence_parallel
-            # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size]
-            # else:
-            # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size]
-            inputs_embeds = paddle.concat(embeds_res, axis=-1)
+                mtp_embeds.append(inputs_embeds_mtp)
+
+            if self.config.send_mtp_embed:
+                embeds_res.extend(mtp_embeds)
+                # if not self.sequence_parallel
+                # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size]
+                # else:
+                # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size]
+                inputs_embeds = paddle.concat(embeds_res, axis=-1)
+            else:
+                global global_inputs_embeds_mtp_queue
+                cloned_mtp_embeds = [t.detach() for t in mtp_embeds]
+                global_inputs_embeds_mtp_queue.put(cloned_mtp_embeds)
             return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
         else:
             if self.sequence_parallel:
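In the hunk above the embedding stage now takes one of two paths. When `send_mtp_embed` is set, the per-depth MTP embeddings are concatenated onto the main embeddings along the hidden axis, so a single tensor roughly `(num_nextn_predict_layers + 1)` times wider travels through the pipeline send/recv; otherwise the embeddings are detached and pushed onto the global queue, keeping the P2P payload at its original size. A small sketch of the concat/split round trip used on the `send_mtp_embed` path, with made-up shapes (assumes Paddle is installed):

```python
import paddle

hidden = paddle.rand([2, 8, 16])   # main-model embeddings, [B, S, D]
mtp_0 = paddle.rand([2, 8, 16])    # MTP embeddings for depth 0
mtp_1 = paddle.rand([2, 8, 16])    # MTP embeddings for depth 1

# Producer side: pack everything along the hidden axis so it can be sent
# between pipeline stages as one activation tensor.
packed = paddle.concat([hidden, mtp_0, mtp_1], axis=-1)   # [B, S, 3 * D]

# Consumer side: recover the pieces with an even split, as the downstream pipe
# layers do with paddle.split(..., num_nextn_predict_layers + 1, axis=-1).
main, depth_0, depth_1 = paddle.split(packed, 3, axis=-1)
assert main.shape == [2, 8, 16]
```

Note that the queue path stores `t.detach()` copies, which appears to stop gradients from flowing back to the embedding table along that route; only the concatenated path keeps the MTP embeddings in the autograd graph.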
@@ -1359,9 +1371,15 @@ class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer):
     def forward(self, args):
         hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
 
-        hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
-        hidden_states_main_model = hidden_states_list[0]
-        inputs_embeds_cur_depth_list = hidden_states_list[1:]
+        if self.config.send_mtp_embed:
+            hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
+            hidden_states_main_model = hidden_states_list[0]
+            inputs_embeds_cur_depth_list = hidden_states_list[1:]
+        else:
+            hidden_states_main_model = hidden_states
+            global global_inputs_embeds_mtp_queue
+            inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get()
+
         has_gradient = not hidden_states_main_model.stop_gradient
 
         if attention_mask is not None and attention_mask.dtype == paddle.int32:
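This is the consumer side of the same switch: with `send_mtp_embed` the MTP layer splits the concatenated hidden states back apart, otherwise it pops the embeddings from the global queue. The queue path presumably relies on the embedding stage and the MTP layer processing micro-batches in the same order within one process, because `Queue.get()` simply returns the oldest `put()`. A tiny illustration of that FIFO pairing (plain Python, hypothetical payloads):

```python
import queue

q = queue.Queue()

# Producer: embedding-stage forward passes, one per micro-batch.
for mb in range(3):
    q.put(f"mtp_embeds_mb{mb}")

# Consumer: MTP-layer forward passes issued in the same order, so each get()
# pairs with the put() of the same micro-batch.
for mb in range(3):
    assert q.get() == f"mtp_embeds_mb{mb}"
```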
@@ -1426,7 +1444,7 @@ def __init__(self, config):
     def forward(self, args):
         hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
 
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
             hidden_states = hidden_states_list[0]
             hidden_states_mtp = hidden_states_list[-self.config.num_nextn_predict_layers :]
@@ -1451,7 +1469,7 @@ def embedding_weight(self):
         return get_attr(self, "weight")
 
     def forward(self, args: Union[Tuple, paddle.Tensor]):
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             logits = []
             for _hidden_states in args:
                 logits.append(super().forward(_hidden_states))
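With MTP enabled, the LM-head pipe receives a tuple of hidden states (the main model's plus one per predict depth) and applies the same head to each, returning a list of logits. A rough sketch of that shape flow with a stand-in linear head (hypothetical sizes; assumes Paddle is installed):

```python
import paddle

lm_head = paddle.nn.Linear(16, 100)  # stand-in for the tied LM head
# Main hidden states plus two MTP depths, each [B, S, D].
hidden_states = tuple(paddle.rand([2, 8, 16]) for _ in range(3))

logits = [lm_head(h) for h in hidden_states]  # one [B, S, vocab] tensor each
assert len(logits) == 3 and logits[0].shape == [2, 8, 100]
```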
@@ -1466,7 +1484,7 @@ def build_schedule_node(self):
 
 class DeepseekV2PretrainingCriterionPipe(DeepseekV2PretrainingCriterion):
     def forward(self, logits, labels):
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             mtp_logits = logits[1:]
             logits = logits[0]
             loss = super().forward(logits, labels, mtp_logits=mtp_logits)
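The criterion pipe then treats `logits[0]` as the main-model logits and forwards the remaining entries as `mtp_logits` to the parent criterion. The diff does not show how the MTP losses are weighted; a generic sketch of the usual pattern, a weighted sum of the main loss and the per-depth auxiliary losses with a made-up `alpha` (not the actual DeepseekV2PretrainingCriterion code), might look like this:

```python
import paddle
import paddle.nn.functional as F

def combined_loss(logits, labels, mtp_logits, alpha=0.3):
    # alpha is a hypothetical weighting; the real weighting lives inside
    # DeepseekV2PretrainingCriterion and is not part of this diff.
    loss = F.cross_entropy(logits, labels)
    if mtp_logits:
        mtp_loss = sum(F.cross_entropy(l, labels) for l in mtp_logits) / len(mtp_logits)
        loss = loss + alpha * mtp_loss
    return loss

labels = paddle.randint(0, 100, shape=[2, 8])
main = paddle.rand([2, 8, 100])
mtp = [paddle.rand([2, 8, 100])]
print(float(combined_loss(main, labels, mtp)))
```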
@@ -1669,6 +1687,19 @@ def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, reco
         # DON'T init PipelinePretrainedModel
         # PipelinePretrainedModel.__init__(self.super(), config=config)
 
+    def fp8_quant_weight(self, batch_mode=False):
+        """fp8_quant_weight"""
+        with paddle.no_grad():
+            for i, layer in self._sub_layers.items():
+                if isinstance(
+                    layer, paddle.distributed.fleet.meta_parallel.parallel_layers.pp_layers.PipelineLayerChunk
+                ):
+                    for i, sub_layer in layer.named_sublayers():
+                        if isinstance(sub_layer, DeepseekV2DecoderLayer) and hasattr(sub_layer, "fp8_quant_weight"):
+                            sub_layer.fp8_quant_weight(batch_mode)
+                if isinstance(layer, DeepseekV2DecoderLayer) and hasattr(layer, "fp8_quant_weight"):
+                    layer.fp8_quant_weight(batch_mode)
+
     def get_loss_fn(self, config):
         return DeepseekV2PretrainingCriterionPipe(config)
 
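The new `fp8_quant_weight` helper walks the pipeline model's registered sub-layers under `paddle.no_grad()`, descends one level into `PipelineLayerChunk` containers (present when a stage is split into virtual chunks), and calls `fp8_quant_weight` on every `DeepseekV2DecoderLayer` that defines it. The traversal pattern, with plain-Python stand-ins instead of the Paddle classes:

```python
class Chunk:
    """Stand-in for PipelineLayerChunk: a container of sub-layers."""
    def __init__(self, children):
        self._children = children
    def named_sublayers(self):
        return [(str(i), child) for i, child in enumerate(self._children)]

class DecoderLayer:
    """Stand-in for DeepseekV2DecoderLayer."""
    def fp8_quant_weight(self, batch_mode=False):
        print(f"quantized (batch_mode={batch_mode})")

def quant_all(sub_layers, batch_mode=False):
    # Walk top-level layers; descend one level into containers; only call the
    # hook on layers that actually provide it.
    for layer in sub_layers:
        if isinstance(layer, Chunk):
            for _, sub in layer.named_sublayers():
                if hasattr(sub, "fp8_quant_weight"):
                    sub.fp8_quant_weight(batch_mode)
        elif hasattr(layer, "fp8_quant_weight"):
            layer.fp8_quant_weight(batch_mode)

quant_all([Chunk([DecoderLayer(), object()]), DecoderLayer()], batch_mode=True)
```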