Commit 69e9f96

optimizer embedding grad speed (#11011)

1 parent fe9e9f8 commit 69e9f96
File tree: 1 file changed (+27, -1)

paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -1307,6 +1307,32 @@ def build_overlapped_nodes(forward_chunk, backward_chunk):
     overlap_node = OverlapedScheduleChunk(forward_overlap_layers, backward_overlap_layers, use_fuion=DSV3_USE_FP8_GEMM)
     return forward_pre_node, backward_pre_node, overlap_node, forward_post_node, backward_post_node
 
+
+class EmbeddingFunction(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(ctx, x, weight):
+        # Plain embedding lookup on the forward pass.
+        out = paddle.nn.functional.embedding(
+            x,
+            weight=weight,
+            padding_idx=None,
+            max_norm=None,
+            norm_type=2.0,
+            sparse=False,
+            scale_grad_by_freq=False,
+        )
+        ctx.save_for_backward(x, weight)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, weight = ctx.saved_tensor()
+        # Fused scatter-add of the output gradient directly into the
+        # weight's gradient accumulator (fp32 main_grad when present).
+        if hasattr(weight, "main_grad"):
+            paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.main_grad, dout)
+        else:
+            paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.grad, dout)
+        # Gradients were accumulated in place, so nothing flows back here.
+        return None, None
 
 
 class DeepseekV2EmbeddingPipe(nn.Layer):
     def __init__(self, config: DeepseekV2Config):
@@ -1337,7 +1363,7 @@ def forward(self, args):
             _type_: _description_
         """
         input_ids, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = EmbeddingFunction.apply(input_ids, self.embed_tokens.weight)
 
         batch_size, seq_length = input_ids.shape
         if self.config.num_nextn_predict_layers > 0:
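
The speedup presumably comes from skipping the dense intermediate: a standard non-sparse embedding backward materializes a full [vocab_size, hidden] gradient tensor and then adds it into the accumulator, whereas the fused embedding_grad_add_to_ op scatter-adds the incoming gradient rows straight into weight.main_grad (the fp32 master-gradient buffer under mixed precision) or weight.grad. Below is a minimal sketch of the assumed semantics using stock Paddle ops; embedding_grad_add_to_reference is a hypothetical name for illustration, and only the incubate op in the diff above is from this commit:

import paddle

def embedding_grad_add_to_reference(x, grad_buffer, dout):
    # Assumed reference semantics of the fused kernel: for every token
    # position p, grad_buffer[x[p]] += dout[p], i.e. a scatter-add of
    # output-gradient rows onto embedding rows, written directly into
    # the accumulator with no dense dweight tensor in between.
    flat_ids = x.reshape([-1])                      # [num_tokens]
    flat_dout = dout.reshape([-1, dout.shape[-1]])  # [num_tokens, hidden]
    # main_grad is typically fp32 while dout may be bf16/fp16.
    flat_dout = flat_dout.astype(grad_buffer.dtype)
    paddle.index_add_(grad_buffer, flat_ids, axis=0, value=flat_dout)

Returning None, None from backward fits this design: the gradient has already been accumulated in place, so no dense gradient tensor needs to flow back through autograd for either input.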
