@@ -212,6 +212,34 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
    return assignment_list


+class LMHeadFunction(paddle.autograd.PyLayer):
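+    """Matmul with a custom backward: dx is computed as usual, while the weight
+    gradient is accumulated directly into ``weight.main_grad`` by a fused kernel."""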
+    @staticmethod
+    def forward(ctx, x, weight, transpose_y):
+        out = paddle.matmul(x, weight, transpose_y=transpose_y)
+
+        ctx.save_for_backward(x, weight, transpose_y)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
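+        # Downcast fp32 incoming gradients to bf16 before the fused backward.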
+        if dout.dtype == paddle.float32:
+            dout = dout.cast(paddle.bfloat16)
+
+        x, weight, transpose_y = ctx.saved_tensor()
+
+        dx = paddle.matmul(dout, weight, transpose_y=not transpose_y)
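+        # Accumulate the weight gradient into weight.main_grad with the fused
+        # kernel; the operand order depends on whether the weight is stored transposed.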
+        if transpose_y:
+            with paddle.amp.auto_cast(False):
+                paddle._C_ops.fused_linear_param_grad_add(
+                    dout.reshape([-1, dout.shape[-1]]), x.reshape([-1, x.shape[-1]]), weight.main_grad, None, True, False
+                )
+        else:
+            with paddle.amp.auto_cast(False):
+                paddle._C_ops.fused_linear_param_grad_add(
+                    x.reshape([-1, x.shape[-1]]), dout.reshape([-1, dout.shape[-1]]), weight.main_grad, None, True, False
+                )
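+        # The weight gradient is already accumulated into main_grad, so return None for it.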
+        return dx, None
+
def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
    is_fleet_init = True
    tensor_parallel_degree = 1
@@ -238,10 +266,9 @@ def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_out
        return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)

    else:
-        logits = paddle.matmul(x, y, transpose_y=transpose_y)
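+        # Non-tensor-parallel path: route through LMHeadFunction so the weight grad goes to main_grad.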
+        logits = LMHeadFunction.apply(x, y, transpose_y=transpose_y)
        return logits

-
def scaled_dot_product_attention(
    query_states,
    config,
@@ -2469,8 +2496,9 @@ def forward(
    ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
        hidden_states = self.hnorm(hidden_states)
        nextn_hidden_state = self.enorm(nextn_hidden_state)
-
-        hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1))
+
+        concat_h = paddle.concat([hidden_states, nextn_hidden_state], axis=-1)
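+        # Project the concatenated states with eh_proj via LMHeadFunction (weight passed directly, not transposed).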
+        hidden_states = LMHeadFunction.apply(concat_h, self.eh_proj.weight, False)

        layer_outputs = super(DeepseekV2MTPLayer, self).forward(
            hidden_states,