@@ -1307,6 +1307,32 @@ def build_overlapped_nodes(forward_chunk, backward_chunk):
1307
1307
overlap_node = OverlapedScheduleChunk (forward_overlap_layers , backward_overlap_layers , use_fuion = DSV3_USE_FP8_GEMM )
1308
1308
return forward_pre_node , backward_pre_node , overlap_node , forward_post_node , backward_post_node
1309
1309
1310
class EmbeddingFunction(paddle.autograd.PyLayer):
    """Embedding lookup with a hand-written backward pass.

    The forward pass delegates to ``paddle.nn.functional.embedding``. The
    backward pass does not return a gradient for either input; instead it
    hands ``dout`` to ``embedding_grad_add_to_`` together with the weight's
    ``main_grad`` (when that attribute exists) or its ``grad``.
    """

    @staticmethod
    def forward(ctx, x, weight):
        """Look up rows of ``weight`` selected by ``x``.

        Args:
            ctx: PyLayer context; ``x`` and ``weight`` are saved on it for
                use in ``backward``.
            x: Index tensor passed to the embedding lookup.
            weight: Embedding table passed as ``weight=``.

        Returns:
            The result of ``paddle.nn.functional.embedding``.
        """
        # Both inputs are needed again in backward: the indices, and the
        # parameter object whose gradient buffer is read there.
        ctx.save_for_backward(x, weight)

        return paddle.nn.functional.embedding(
            x,
            weight=weight,
            padding_idx=None,
            max_norm=None,
            norm_type=2.0,
            sparse=False,
            scale_grad_by_freq=False,
        )

    @staticmethod
    def backward(ctx, dout):
        """Pass ``dout`` to ``embedding_grad_add_to_`` and return no gradients.

        Args:
            ctx: PyLayer context holding the tensors saved in ``forward``.
            dout: Gradient with respect to the forward output.

        Returns:
            ``(None, None)`` - no gradient is returned for ``x`` or ``weight``.
        """
        token_ids, table = ctx.saved_tensor()

        # Prefer ``main_grad`` whenever the parameter carries that attribute;
        # otherwise fall back to the regular ``grad``.
        # NOTE(review): presumably ``embedding_grad_add_to_`` accumulates in
        # place into the buffer it is given (trailing underscore) - confirm
        # against the Paddle implementation.
        # NOTE(review): ``table.grad`` is used as-is in the fallback case;
        # verify it cannot be None on this path.
        grad_buffer = table.main_grad if hasattr(table, "main_grad") else table.grad
        paddle.incubate.nn.functional.embedding_grad_add_to_(token_ids, grad_buffer, dout)

        return None, None
1310
1336
1311
1337
class DeepseekV2EmbeddingPipe (nn .Layer ):
1312
1338
def __init__ (self , config : DeepseekV2Config ):
@@ -1337,7 +1363,7 @@ def forward(self, args):
1337
1363
_type_: _description_
1338
1364
"""
1339
1365
input_ids , attention_mask , attn_mask_startend_row_indices , position_ids = parse_args (args )
1340
- inputs_embeds = self . embed_tokens ( input_ids )
1366
+ inputs_embeds = EmbeddingFunction . apply ( input_ids , self . embed_tokens . weight )
1341
1367
1342
1368
batch_size , seq_length = input_ids .shape
1343
1369
if self .config .num_nextn_predict_layers > 0 :
0 commit comments