
Commit 45b9caa

opti lm head backward perf (#10975)

1 parent fb22a9f

File tree: 1 file changed (+32 −4)

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 32 additions & 4 deletions
@@ -212,6 +212,34 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
     return assignment_list
 
 
+class LMHeadFunction(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(ctx, x, weight, transpose_y):
+        out = paddle.matmul(x, weight, transpose_y=transpose_y)
+
+        ctx.save_for_backward(x, weight, transpose_y)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        if dout.dtype == paddle.float32:
+            dout = dout.cast(paddle.bfloat16)
+
+        x, weight, transpose_y = ctx.saved_tensor()
+
+        dx = paddle.matmul(dout, weight, transpose_y=not transpose_y)
+        if transpose_y:
+            with paddle.amp.auto_cast(False):
+                paddle._C_ops.fused_linear_param_grad_add(
+                    dout.reshape([-1, dout.shape[-1]]), x.reshape([-1, x.shape[-1]]), weight.main_grad, None, True, False
+                )
+        else:
+            with paddle.amp.auto_cast(False):
+                paddle._C_ops.fused_linear_param_grad_add(
+                    x.reshape([-1, x.shape[-1]]), dout.reshape([-1, dout.shape[-1]]), weight.main_grad, None, True, False
+                )
+        return dx, None
+
 def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
     is_fleet_init = True
     tensor_parallel_degree = 1
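
Editor's note: the point of the custom PyLayer is the backward pass. Instead of materializing the weight gradient with one matmul kernel and then adding it into the fp32 master gradient with a second kernel, paddle._C_ops.fused_linear_param_grad_add accumulates dW directly into weight.main_grad in a single fused kernel; the dout cast at the top of backward keeps that kernel in bf16 even when the upstream gradient arrives in fp32. A minimal sketch of the unfused two-kernel baseline this replaces (illustrative only, with hypothetical shapes; not the commit's code):

    import paddle

    # Unfused baseline that LMHeadFunction.backward avoids. In the real layer
    # x and dout are bfloat16 and main_grad is the fp32 master gradient.
    x = paddle.randn([8, 128, 1024])       # activations, [batch, seq, in]
    dout = paddle.randn([8, 128, 4096])    # upstream grad, [batch, seq, out]
    main_grad = paddle.zeros([1024, 4096], dtype="float32")

    # Kernel 1: dW = x^T @ dout, flattened over all tokens.
    dw = paddle.matmul(
        x.reshape([-1, x.shape[-1]]),        # [tokens, in]
        dout.reshape([-1, dout.shape[-1]]),  # [tokens, out]
        transpose_x=True,                    # -> [in, out]
    )
    # Kernel 2: accumulate into the master gradient. The fused op folds
    # kernels 1 and 2 into one launch and skips the intermediate dw tensor.
    main_grad += dw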
@@ -238,10 +266,9 @@ def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
         return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)
 
     else:
-        logits = paddle.matmul(x, y, transpose_y=transpose_y)
+        logits = LMHeadFunction.apply(x, y, transpose_y=transpose_y)
         return logits
 
-
 def scaled_dot_product_attention(
     query_states,
     config,
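
Editor's note: this swap leaves the forward result of the non-tensor-parallel branch unchanged and only reroutes the backward through the fused path. It assumes a contract the patch never spells out: backward returns (dx, None), so no autograd-tracked weight gradient is ever produced, and the weight must already carry a main_grad buffer (Paddle's master-grad mixed-precision utilities attach one to each parameter). A hedged usage sketch, attaching main_grad by hand and assuming the LMHeadFunction class from this patch is in scope (GPU assumed, since fused_linear_param_grad_add is a fused GPU kernel):

    import paddle

    # Hypothetical names; main_grad is normally attached by Paddle's
    # master-grad mixed-precision setup rather than by hand.
    weight = paddle.create_parameter([1024, 4096], dtype="bfloat16")
    weight.main_grad = paddle.zeros([1024, 4096], dtype="float32")

    hidden = paddle.randn([2, 16, 1024]).cast("bfloat16")
    logits = LMHeadFunction.apply(hidden, weight, False)
    logits.mean().backward()  # dx flows normally; dW lands in weight.main_grad, not weight.grad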
@@ -2469,8 +2496,9 @@ def forward(
     ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
         hidden_states = self.hnorm(hidden_states)
         nextn_hidden_state = self.enorm(nextn_hidden_state)
-
-        hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1))
+
+        concat_h = paddle.concat([hidden_states, nextn_hidden_state], axis=-1)
+        hidden_states = LMHeadFunction.apply(concat_h, self.eh_proj.weight, False)
 
         layer_outputs = super(DeepseekV2MTPLayer, self).forward(
             hidden_states,
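
Editor's note: the MTP layer reuses the same fused path for its eh_proj projection by calling the PyLayer on the module's weight directly. This relies on eh_proj being a plain bias-free linear: paddle.nn.Linear stores its weight as [in_features, out_features], so transpose_y=False reproduces the module call exactly. A small equivalence check under that assumption (hypothetical sizes, not the commit's code):

    import paddle

    # Assumes eh_proj is a bias-free Linear, as the patch implies.
    proj = paddle.nn.Linear(2048, 1024, bias_attr=False)
    h = paddle.randn([4, 16, 2048])

    ref = proj(h)                                           # module call
    out = paddle.matmul(h, proj.weight, transpose_y=False)  # what the PyLayer's forward computes
    assert paddle.allclose(ref, out).item()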
