@@ -215,31 +215,42 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
class LMHeadFunction(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, weight, transpose_y):
-        out = paddle.matmul(x, weight, transpose_y=transpose_y)
+        out = paddle.matmul(x, weight, transpose_y=transpose_y)

-        ctx.save_for_backward(x, weight, transpose_y)
+        ctx.save_for_backward(x, weight, transpose_y)
        return out

    @staticmethod
    def backward(ctx, dout):
        if dout.dtype == paddle.float32:
-            dout = dout.cast( paddle.bfloat16)
+            dout = dout.cast(paddle.bfloat16)

        x, weight, transpose_y = ctx.saved_tensor()

-        dx = paddle.matmul( dout, weight, transpose_y=not transpose_y)
+        dx = paddle.matmul(dout, weight, transpose_y=not transpose_y)
        if transpose_y:
            with paddle.amp.auto_cast(False):
                paddle._C_ops.fused_linear_param_grad_add(
-                    dout.reshape( [-1, dout.shape[-1]]), x.reshape( [-1, x.shape[-1]]), weight.main_grad, None, True, False
-                )
+                    dout.reshape([-1, dout.shape[-1]]),
+                    x.reshape([-1, x.shape[-1]]),
+                    weight.main_grad,
+                    None,
+                    True,
+                    False,
+                )
        else:
            with paddle.amp.auto_cast(False):
                paddle._C_ops.fused_linear_param_grad_add(
-                    x.reshape([-1, x.shape[-1]]), dout.reshape([-1, dout.shape[-1]]), weight.main_grad, None, True, False
-                )
+                    x.reshape([-1, x.shape[-1]]),
+                    dout.reshape([-1, dout.shape[-1]]),
+                    weight.main_grad,
+                    None,
+                    True,
+                    False,
+                )
        return dx, None

+
def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
    is_fleet_init = True
    tensor_parallel_degree = 1
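Note on the hunk above: the LM-head GEMM is wrapped in a custom paddle.autograd.PyLayer whose forward saves the activations and weight, and whose backward recomputes dx with the transposed weight while accumulating the weight gradient directly into weight.main_grad through fused_linear_param_grad_add (hence the returned None for the weight). A minimal standalone sketch of that PyLayer pattern, not the PR's code, which returns the weight gradient explicitly instead of using the fused accumulation op:

import paddle


class MatmulSketch(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, weight):
        # Save inputs for the backward pass, return the plain GEMM result.
        ctx.save_for_backward(x, weight)
        return paddle.matmul(x, weight)

    @staticmethod
    def backward(ctx, dout):
        x, weight = ctx.saved_tensor()
        dx = paddle.matmul(dout, weight, transpose_y=True)  # dL/dx = dout @ W^T
        dweight = paddle.matmul(x, dout, transpose_x=True)  # dL/dW = x^T @ dout
        return dx, dweight


x = paddle.randn([4, 8])
w = paddle.randn([8, 16])
x.stop_gradient = False
w.stop_gradient = False
MatmulSketch.apply(x, w).sum().backward()  # populates x.grad and w.grad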
@@ -269,6 +280,7 @@ def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_out
        logits = LMHeadFunction.apply(x, y, transpose_y=transpose_y)
        return logits

+
def scaled_dot_product_attention(
    query_states,
    config,
@@ -633,7 +645,9 @@ def _set_cos_sin_cache(self, seq_len):
        dim = self.dim

        freq_extra = 1.0 / (self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim))
-        freq_inter = 1.0 / (self.scaling_factor * self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim))
+        freq_inter = 1.0 / (
+            self.scaling_factor * self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)
+        )

        low, high = yarn_find_correction_range(
            self.beta_fast,
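The change above only splits the expression across lines; numerically, freq_inter is still freq_extra / scaling_factor. A quick standalone check (the base, dim, and scaling_factor values here are made up):

import paddle

base, dim, scaling_factor = 10000.0, 64, 4.0
exponent = paddle.arange(0, dim, 2, dtype=paddle.float32) / dim
freq_extra = 1.0 / (base**exponent)
freq_inter = 1.0 / (scaling_factor * base**exponent)
assert paddle.allclose(freq_inter, freq_extra / scaling_factor)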
@@ -1059,15 +1073,15 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
        )
        set_parameter_color([self.shared_experts.w1, self.shared_experts.w2], "shared_expert")

-    def fp8_quant_weight(self, batch_mode=False):
+    def fp8_quant_weight(self, batch_mode=False, quant_transpose=True):
        """Quantize weights in FP8 format.

        Args:
            batch_mode: If True, quantize all weights in batch mode using the first expert's weights.
                If False, quantize each expert's weights individually.
        """

-        def quantize_weights(weight_list, weight_obj=None):
+        def quantize_weights(weight_list, weight_obj=None, quant_transpose=True):
            """Helper function to quantize a list of weights."""
            if weight_obj is None:
                weight_obj = weight_list[0]
@@ -1081,31 +1095,32 @@ def quantize_weights(weight_list, weight_obj=None):
            setattr(weight_obj, "fp8_weight_stacked", fp8_weight)
            setattr(weight_obj, "fp8_scale_stacked", fp8_scale)

-            # Quantize with transpose
-            fp8_weight_t, fp8_scale_t = paddle.incubate.nn.functional.fused_stack_transpose_quant(
-                weight_list, transpose=True
-            )
-            setattr(weight_obj, "fp8_weight_stacked_transpose", fp8_weight_t)
-            setattr(weight_obj, "fp8_scale_stacked_transpose", fp8_scale_t)
+            if quant_transpose:
+                # Quantize with transpose
+                fp8_weight_t, fp8_scale_t = paddle.incubate.nn.functional.fused_stack_transpose_quant(
+                    weight_list, transpose=True
+                )
+                setattr(weight_obj, "fp8_weight_stacked_transpose", fp8_weight_t)
+                setattr(weight_obj, "fp8_scale_stacked_transpose", fp8_scale_t)

        if batch_mode:
            # Batch mode: process all experts' weights together
            expert_w1_list = [expert.w1 for expert in self.experts if expert is not None]
            expert_w2_list = [expert.w2 for expert in self.experts if expert is not None]

            if expert_w1_list:
-                quantize_weights(expert_w1_list, expert_w1_list[0])
+                quantize_weights(expert_w1_list, expert_w1_list[0], quant_transpose)
            if expert_w2_list:
-                quantize_weights(expert_w2_list, expert_w2_list[0])
+                quantize_weights(expert_w2_list, expert_w2_list[0], quant_transpose)
        else:
            # Individual mode: process each expert's weights separately
            for expert in self.experts:
                if expert is not None:
-                    quantize_weights([expert.w1])
-                    quantize_weights([expert.w1])
+                    quantize_weights([expert.w1], quant_transpose=quant_transpose)
+                    quantize_weights([expert.w2], quant_transpose=quant_transpose)

        if self.config.n_shared_experts is not None:
-            self.shared_experts.fp8_quant_weight()
+            self.shared_experts.fp8_quant_weight(quant_transpose)

    def forward(self, hidden_states):
        if self.using_post_norm_recompute:
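The new quant_transpose flag lets callers skip the transposed FP8 copy (the fused_stack_transpose_quant call and the fp8_weight_stacked_transpose / fp8_scale_stacked_transpose attributes) when it is not needed. An illustrative sketch of how the flag threads from a decoder layer down to each expert, using stand-in classes rather than the library's:

class ExpertSketch:
    def fp8_quant_weight(self, quant_transpose=True):
        # Stand-in for caching the FP8 copies; the real code calls the fused quant ops.
        self.cached_transpose_copy = quant_transpose


class MoESketch:
    def __init__(self):
        self.experts = [ExpertSketch(), ExpertSketch()]

    def fp8_quant_weight(self, batch_mode=False, quant_transpose=True):
        # Forward the flag to every expert, mirroring the diff's call pattern.
        for expert in self.experts:
            expert.fp8_quant_weight(quant_transpose=quant_transpose)


class LayerSketch:
    def __init__(self):
        self.mlp = MoESketch()

    def fp8_quant_weight(self, batch_mode=False, quant_transpose=True):
        self.mlp.fp8_quant_weight(batch_mode, quant_transpose=quant_transpose)


# Skipping the transposed copy everywhere with a single argument at the top:
LayerSketch().fp8_quant_weight(batch_mode=True, quant_transpose=False)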
@@ -1762,9 +1777,9 @@ def __init__(
        )
        set_parameter_color([self.q_up_weight, self.kv_up_weight], "memory_attn")

-    def fp8_quant_weight(self):
-        cache_fp8_weight(self.q_up_weight)
-        cache_fp8_weight(self.kv_up_weight)
+    def fp8_quant_weight(self, quant_transpose=True):
+        cache_fp8_weight(self.q_up_weight, quant_transpose=quant_transpose)
+        cache_fp8_weight(self.kv_up_weight, quant_transpose=quant_transpose)

    def forward(self, q_init, kv_init, position_ids):
@@ -1890,8 +1905,8 @@ def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None:
        self.eps = eps
        set_parameter_color([self.q_down_weight], "rms_linear")

-    def fp8_quant_weight(self):
-        cache_fp8_weight(self.q_down_weight)
+    def fp8_quant_weight(self, quant_transpose=True):
+        cache_fp8_weight(self.q_down_weight, quant_transpose=quant_transpose)

    def forward(self, x):
@@ -2053,12 +2068,12 @@ def linear_dtype_gaurd():

        self.attn_func = scaled_dot_product_attention

-    def fp8_quant_weight(self):
+    def fp8_quant_weight(self, quant_transpose=True):

        if DSV3_USE_ATTEN_RECOMPUTE:
-            self.o_proj.fp8_quant_weight()
-            self.memory_recompute_att.fp8_quant_weight()
-            self.fused_rms_norm_linear.fp8_quant_weight()
+            self.o_proj.fp8_quant_weight(quant_transpose=quant_transpose)
+            self.memory_recompute_att.fp8_quant_weight(quant_transpose=quant_transpose)
+            self.fused_rms_norm_linear.fp8_quant_weight(quant_transpose=quant_transpose)

    def _init_rope(self):
        if self.config.rope_scaling is None:
@@ -2279,16 +2294,16 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute
                else DeepseekV2MoE(config)
            )
        else:
-            self.mlp = DeepseekV2MLPClass(config)
+            self.mlp = DeepseekV2MLPClass(config, recompute_fwd_gate_up=True)

-    def fp8_quant_weight(self, batch_mode=False):
+    def fp8_quant_weight(self, batch_mode=False, quant_transpose=True):
        """fp8_quant_weight"""
        if isinstance(self.mlp, DeepseekV2MoE):
            # logger.info(f"fp8 quant weight for mlp {type(self.mlp)}")
-            self.mlp.fp8_quant_weight(batch_mode)
-            self.self_attn.fp8_quant_weight()
+            self.mlp.fp8_quant_weight(batch_mode, quant_transpose=quant_transpose)
+            self.self_attn.fp8_quant_weight(quant_transpose=quant_transpose)
        elif isinstance(self.mlp, FP8Mlp):
-            self.self_attn.fp8_quant_weight()
+            self.self_attn.fp8_quant_weight(quant_transpose=quant_transpose)

    def forward(
        self,
@@ -2496,9 +2511,9 @@ def forward(
    ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
        hidden_states = self.hnorm(hidden_states)
        nextn_hidden_state = self.enorm(nextn_hidden_state)
-
+
        concat_h = paddle.concat([hidden_states, nextn_hidden_state], axis=-1)
-        hidden_states = LMHeadFunction.apply( concat_h, self.eh_proj.weight, False)
+        hidden_states = LMHeadFunction.apply(concat_h, self.eh_proj.weight, False)

        layer_outputs = super(DeepseekV2MTPLayer, self).forward(
            hidden_states,
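A shape sketch for the MTP projection above, with illustrative sizes; the [2 * hidden, hidden] layout assumed for eh_proj.weight matches the transpose_y=False call but is an assumption. Concatenating the two normalized streams doubles the last dimension, and the projection maps it back to the hidden size:

import paddle

hidden_states = paddle.randn([2, 16, 1024])
nextn_hidden_state = paddle.randn([2, 16, 1024])
concat_h = paddle.concat([hidden_states, nextn_hidden_state], axis=-1)  # [2, 16, 2048]
eh_proj_weight = paddle.randn([2048, 1024])  # assumed [2 * hidden, hidden] layout
out = paddle.matmul(concat_h, eh_proj_weight)  # back to [2, 16, 1024]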
0 commit comments