
Commit 1ab02cd

quant transpost disable (#10987)
1 parent fb6d8af commit 1ab02cd

File tree

4 files changed: 105 additions, 73 deletions

paddlenlp/trainer/trainer_callback.py

Lines changed: 1 addition & 1 deletion
@@ -645,7 +645,7 @@ def on_step_begin(self, args, state, control, **kwargs):
         global skip_count
 
         if not g_shard_bypass_dygraph_optimizer or skip_count == 0:
-            model.fp8_quant_weight(True)
+            model.fp8_quant_weight(True, quant_transpose=False)
             optimizer.clear_param_storage("moe_expert")
             optimizer.clear_param_storage("rms_linear")
             optimizer.clear_param_storage("memory_attn")
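
For context, an illustrative sketch (not part of the diff, using hypothetical stand-in classes rather than the real PaddleNLP modules) of how the new quant_transpose flag threads from this callback down to every fp8_quant_weight method; the default of True in each signature keeps existing callers unchanged, and only this callback opts out of the transposed caches:

class _Attention:
    def fp8_quant_weight(self, quant_transpose=True):
        # The real code calls cache_fp8_weight(weight, quant_transpose=...) here.
        print(f"attn weights quantized, transpose={quant_transpose}")

class _DecoderLayer:
    def __init__(self):
        self.self_attn = _Attention()

    def fp8_quant_weight(self, batch_mode=False, quant_transpose=True):
        self.self_attn.fp8_quant_weight(quant_transpose=quant_transpose)

class _Model:
    def __init__(self):
        self.layers = [_DecoderLayer(), _DecoderLayer()]

    def fp8_quant_weight(self, batch_mode=False, quant_transpose=True):
        for layer in self.layers:
            layer.fp8_quant_weight(batch_mode, quant_transpose=quant_transpose)

# As in the updated callback: cache FP8 weights but skip the transposed copies.
_Model().fp8_quant_weight(True, quant_transpose=False)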

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 53 additions & 38 deletions
@@ -215,31 +215,42 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
 class LMHeadFunction(paddle.autograd.PyLayer):
     @staticmethod
     def forward(ctx, x, weight, transpose_y):
-        out = paddle.matmul(x, weight, transpose_y = transpose_y)
+        out = paddle.matmul(x, weight, transpose_y=transpose_y)
 
-        ctx.save_for_backward(x, weight, transpose_y)
+        ctx.save_for_backward(x, weight, transpose_y)
         return out
 
     @staticmethod
     def backward(ctx, dout):
         if dout.dtype == paddle.float32:
-            dout = dout.cast( paddle.bfloat16)
+            dout = dout.cast(paddle.bfloat16)
 
         x, weight, transpose_y = ctx.saved_tensor()
 
-        dx = paddle.matmul( dout, weight, transpose_y = not transpose_y)
+        dx = paddle.matmul(dout, weight, transpose_y=not transpose_y)
         if transpose_y:
             with paddle.amp.auto_cast(False):
                 paddle._C_ops.fused_linear_param_grad_add(
-                    dout.reshape( [-1, dout.shape[-1]]), x.reshape( [-1, x.shape[-1]]), weight.main_grad, None, True, False
-                )
+                    dout.reshape([-1, dout.shape[-1]]),
+                    x.reshape([-1, x.shape[-1]]),
+                    weight.main_grad,
+                    None,
+                    True,
+                    False,
+                )
         else:
             with paddle.amp.auto_cast(False):
                 paddle._C_ops.fused_linear_param_grad_add(
-                    x.reshape([-1, x.shape[-1]]), dout.reshape([-1, dout.shape[-1]]), weight.main_grad, None, True, False
-                )
+                    x.reshape([-1, x.shape[-1]]),
+                    dout.reshape([-1, dout.shape[-1]]),
+                    weight.main_grad,
+                    None,
+                    True,
+                    False,
+                )
         return dx, None
 
+
 def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
     is_fleet_init = True
     tensor_parallel_degree = 1
@@ -269,6 +280,7 @@ def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_out
         logits = LMHeadFunction.apply(x, y, transpose_y=transpose_y)
         return logits
 
+
 def scaled_dot_product_attention(
     query_states,
     config,
@@ -633,7 +645,9 @@ def _set_cos_sin_cache(self, seq_len):
         dim = self.dim
 
         freq_extra = 1.0 / (self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim))
-        freq_inter = 1.0 / (self.scaling_factor * self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim))
+        freq_inter = 1.0 / (
+            self.scaling_factor * self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)
+        )
 
         low, high = yarn_find_correction_range(
             self.beta_fast,
@@ -1059,15 +1073,15 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
         )
         set_parameter_color([self.shared_experts.w1, self.shared_experts.w2], "shared_expert")
 
-    def fp8_quant_weight(self, batch_mode=False):
+    def fp8_quant_weight(self, batch_mode=False, quant_transpose=True):
         """Quantize weights in FP8 format.
 
         Args:
             batch_mode: If True, quantize all weights in batch mode using the first expert's weights.
                         If False, quantize each expert's weights individually.
         """
 
-        def quantize_weights(weight_list, weight_obj=None):
+        def quantize_weights(weight_list, weight_obj=None, quant_transpose=True):
            """Helper function to quantize a list of weights."""
            if weight_obj is None:
                weight_obj = weight_list[0]
@@ -1081,31 +1095,32 @@ def quantize_weights(weight_list, weight_obj=None):
             setattr(weight_obj, "fp8_weight_stacked", fp8_weight)
             setattr(weight_obj, "fp8_scale_stacked", fp8_scale)
 
-            # Quantize with transpose
-            fp8_weight_t, fp8_scale_t = paddle.incubate.nn.functional.fused_stack_transpose_quant(
-                weight_list, transpose=True
-            )
-            setattr(weight_obj, "fp8_weight_stacked_transpose", fp8_weight_t)
-            setattr(weight_obj, "fp8_scale_stacked_transpose", fp8_scale_t)
+            if quant_transpose:
+                # Quantize with transpose
+                fp8_weight_t, fp8_scale_t = paddle.incubate.nn.functional.fused_stack_transpose_quant(
+                    weight_list, transpose=True
+                )
+                setattr(weight_obj, "fp8_weight_stacked_transpose", fp8_weight_t)
+                setattr(weight_obj, "fp8_scale_stacked_transpose", fp8_scale_t)
 
         if batch_mode:
             # Batch mode: process all experts' weights together
             expert_w1_list = [expert.w1 for expert in self.experts if expert is not None]
             expert_w2_list = [expert.w2 for expert in self.experts if expert is not None]
 
             if expert_w1_list:
-                quantize_weights(expert_w1_list, expert_w1_list[0])
+                quantize_weights(expert_w1_list, expert_w1_list[0], quant_transpose)
             if expert_w2_list:
-                quantize_weights(expert_w2_list, expert_w2_list[0])
+                quantize_weights(expert_w2_list, expert_w2_list[0], quant_transpose)
         else:
             # Individual mode: process each expert's weights separately
             for expert in self.experts:
                 if expert is not None:
-                    quantize_weights([expert.w1])
-                    quantize_weights([expert.w1])
+                    quantize_weights([expert.w1], quant_transpose=quant_transpose)
+                    quantize_weights([expert.w2], quant_transpose=quant_transpose)
 
         if self.config.n_shared_experts is not None:
-            self.shared_experts.fp8_quant_weight()
+            self.shared_experts.fp8_quant_weight(quant_transpose)
 
     def forward(self, hidden_states):
         if self.using_post_norm_recompute:
@@ -1762,9 +1777,9 @@ def __init__(
         )
         set_parameter_color([self.q_up_weight, self.kv_up_weight], "memory_attn")
 
-    def fp8_quant_weight(self):
-        cache_fp8_weight(self.q_up_weight)
-        cache_fp8_weight(self.kv_up_weight)
+    def fp8_quant_weight(self, quant_transpose=True):
+        cache_fp8_weight(self.q_up_weight, quant_transpose=quant_transpose)
+        cache_fp8_weight(self.kv_up_weight, quant_transpose=quant_transpose)
 
     def forward(self, q_init, kv_init, position_ids):
 
@@ -1890,8 +1905,8 @@ def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None:
         self.eps = eps
         set_parameter_color([self.q_down_weight], "rms_linear")
 
-    def fp8_quant_weight(self):
-        cache_fp8_weight(self.q_down_weight)
+    def fp8_quant_weight(self, quant_transpose=True):
+        cache_fp8_weight(self.q_down_weight, quant_transpose=quant_transpose)
 
     def forward(self, x):
 
@@ -2053,12 +2068,12 @@ def linear_dtype_gaurd():
 
         self.attn_func = scaled_dot_product_attention
 
-    def fp8_quant_weight(self):
+    def fp8_quant_weight(self, quant_transpose=True):
 
         if DSV3_USE_ATTEN_RECOMPUTE:
-            self.o_proj.fp8_quant_weight()
-            self.memory_recompute_att.fp8_quant_weight()
-            self.fused_rms_norm_linear.fp8_quant_weight()
+            self.o_proj.fp8_quant_weight(quant_transpose=quant_transpose)
+            self.memory_recompute_att.fp8_quant_weight(quant_transpose=quant_transpose)
+            self.fused_rms_norm_linear.fp8_quant_weight(quant_transpose=quant_transpose)
 
     def _init_rope(self):
         if self.config.rope_scaling is None:
@@ -2279,16 +2294,16 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute
                 else DeepseekV2MoE(config)
             )
         else:
-            self.mlp = DeepseekV2MLPClass(config)
+            self.mlp = DeepseekV2MLPClass(config, recompute_fwd_gate_up=True)
 
-    def fp8_quant_weight(self, batch_mode=False):
+    def fp8_quant_weight(self, batch_mode=False, quant_transpose=True):
         """fp8_quant_weight"""
         if isinstance(self.mlp, DeepseekV2MoE):
             # logger.info(f"fp8 quant weight for mlp {type(self.mlp)}")
-            self.mlp.fp8_quant_weight(batch_mode)
-            self.self_attn.fp8_quant_weight()
+            self.mlp.fp8_quant_weight(batch_mode, quant_transpose=quant_transpose)
+            self.self_attn.fp8_quant_weight(quant_transpose=quant_transpose)
         elif isinstance(self.mlp, FP8Mlp):
-            self.self_attn.fp8_quant_weight()
+            self.self_attn.fp8_quant_weight(quant_transpose=quant_transpose)
 
     def forward(
         self,
@@ -2496,9 +2511,9 @@ def forward(
     ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
         hidden_states = self.hnorm(hidden_states)
         nextn_hidden_state = self.enorm(nextn_hidden_state)
-
+
         concat_h = paddle.concat([hidden_states, nextn_hidden_state], axis=-1)
-        hidden_states = LMHeadFunction.apply( concat_h, self.eh_proj.weight, False)
+        hidden_states = LMHeadFunction.apply(concat_h, self.eh_proj.weight, False)
 
         layer_outputs = super(DeepseekV2MTPLayer, self).forward(
             hidden_states,

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 7 additions & 15 deletions
@@ -1164,9 +1164,7 @@ def __init__(self, forward_node, backward_node, name=""):
         assert isinstance(forward_node, FusionFp8DecoderLayerNode) or isinstance(
             backward_node, FusionFp8DecoderLayerNode
         )
-        assert isinstance(forward_node, DenseDecoderLayerNode) or isinstance(
-            backward_node, DenseDecoderLayerNode
-        )
+        assert isinstance(forward_node, DenseDecoderLayerNode) or isinstance(backward_node, DenseDecoderLayerNode)
         self.forward_node = forward_node
         self.backward_node = backward_node
         self.name = name
@@ -1231,9 +1229,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         paddle.base.core.nvprof_nvtx_pop()  # moe_mlp
 
         paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine")
-        inputs = self.forward_node.combine_forward(
-            inputs, async_finish=True, allocate_on_comm_stream=True
-        )
+        inputs = self.forward_node.combine_forward(inputs, async_finish=True, allocate_on_comm_stream=True)
         combine_fw_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
         output_grad = self.backward_node.attn_node.backward(output_grad)
         combine_fw_event.calc_stream_wait(self.forward_node.moe_group.id)
@@ -1252,7 +1248,7 @@ def build_overlapped_nodes(forward_chunk, backward_chunk):
 def build_overlapped_nodes(forward_chunk, backward_chunk):
     overlap_element_class = (
         FusionFp8DecoderLayerNode if DSV3_USE_FP8_GEMM else DecoderLayerNode,
-        DenseDecoderLayerNode
+        DenseDecoderLayerNode,
     )
     forward_decoder_layer_num = 0
     backward_decoder_layer_num = 0
@@ -1840,11 +1836,7 @@ def attn_compute_for_fusion(self, args):
     def build_schedule_node(self):
         if isinstance(self.mlp, DeepseekV2MoE):
             self.mlp.update_flex_token()
-            if (
-                self.mlp.using_flex_token and
-                DSV3_USE_FP8_GEMM and
-                self.config.num_nextn_predict_layers == 1
-            ):
+            if self.mlp.using_flex_token and DSV3_USE_FP8_GEMM and self.config.num_nextn_predict_layers == 1:
                 prev_send_mtp_embed = self.config.send_mtp_embed
                 self.config.send_mtp_embed = True  # must be True in MTP node
 
@@ -2108,7 +2100,7 @@ def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, reco
         # DON'T init PipelinePretrainedModel
         # PipelinePretrainedModel.__init__(self.super(), config=config)
 
-    def fp8_quant_weight(self, batch_mode=False):
+    def fp8_quant_weight(self, batch_mode=False, quant_transpose=True):
         """fp8_quant_weight"""
         with paddle.no_grad():
             for i, layer in self._sub_layers.items():
@@ -2117,9 +2109,9 @@ def fp8_quant_weight(self, batch_mode=False):
                 ):
                     for i, sub_layer in layer.named_sublayers():
                         if isinstance(sub_layer, DeepseekV2DecoderLayer) and hasattr(sub_layer, "fp8_quant_weight"):
-                            sub_layer.fp8_quant_weight(batch_mode)
+                            sub_layer.fp8_quant_weight(batch_mode, quant_transpose)
                 if isinstance(layer, DeepseekV2DecoderLayer) and hasattr(layer, "fp8_quant_weight"):
-                    layer.fp8_quant_weight(batch_mode)
+                    layer.fp8_quant_weight(batch_mode, quant_transpose)
 
     def get_loss_fn(self, config):
         return DeepseekV2PretrainingCriterionPipe(config)

paddlenlp/transformers/fp8_utils.py

Lines changed: 44 additions & 19 deletions
@@ -86,8 +86,12 @@ def _get_fp8_weight_and_scale(weight, stacked=False, transpose=False):
 
 
 def fused_stack_quant(expert_weight_list, transpose=False):
-    if hasattr(expert_weight_list[0], "fp8_weight_stacked"):
-        w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=transpose)
+    if transpose is False and hasattr(expert_weight_list[0], "fp8_weight_stacked"):
+        w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=False)
+    elif transpose is True and hasattr(expert_weight_list[0], "fp8_weight_stacked_transpose"):
+        w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=True)
+    elif transpose is True and hasattr(expert_weight_list[0], "fp8_weight_stacked"):
+        w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=False)
     else:
         w, scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_weight_list, transpose=transpose)
     return w, scale
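
A condensed sketch (illustrative, not part of the diff) of the new lookup order in fused_stack_quant: prefer the cached buffer that matches the requested layout, fall back to the non-transposed cache when only that exists, and otherwise quantize on the fly. The helper below only reports which branch would be taken:

def _stack_quant_source(weight, transpose):
    # Mirrors the branch order above; `weight` is any object that may carry
    # the cached FP8 attributes set by quantize_weights.
    if not transpose and hasattr(weight, "fp8_weight_stacked"):
        return "cached"                    # non-transposed cache hit
    if transpose and hasattr(weight, "fp8_weight_stacked_transpose"):
        return "cached_transpose"          # transposed cache hit
    if transpose and hasattr(weight, "fp8_weight_stacked"):
        return "cached_then_transpose"     # only the plain cache exists; fwd_gate_up/fwd_down transpose it later
    return "fused_stack_transpose_quant"   # no cache: quantize the weight list on the fly

class _W: pass
w = _W(); w.fp8_weight_stacked = object()
print(_stack_quant_source(w, transpose=True))  # cached_then_transpose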
@@ -97,6 +101,8 @@ def weight_quant(weight, transpose=False):
     if transpose:
         if hasattr(weight, "fp8_weight_transpose"):
             return weight.fp8_weight_transpose, weight.fp8_scale_transpose
+        elif hasattr(weight, "fp8_weight"):
+            return weight.fp8_weight.T.contiguous(), weight.fp8_scale.T.contiguous()
         else:
             return paddle.incubate.nn.functional.fp8_quant_blockwise(
                 weight,
@@ -590,21 +596,32 @@ def forward(self, x):
         return FP8LinearFunction.apply(x, self, keep_x=False)
 
 
-def cache_fp8_weight(weight):
+def cache_fp8_weight(weight, quant_transpose=True):
     if hasattr(weight, "fp8_weight"):
         return
-    w_fp8, w_scale, w_t_fp8, w_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        weight,
-        output_scale_transpose=False,
-        quant_method="128x128",
-        input_transpose=True,
-        return_transpose_only=False,
-    )
 
-    setattr(weight, "fp8_weight_transpose", w_t_fp8)
-    setattr(weight, "fp8_scale_transpose", w_t_scale)
-    setattr(weight, "fp8_weight", w_fp8)
-    setattr(weight, "fp8_scale", w_scale)
+    if quant_transpose:
+        w_fp8, w_scale, w_t_fp8, w_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            weight,
+            output_scale_transpose=False,
+            quant_method="128x128",
+            input_transpose=True,
+            return_transpose_only=False,
+        )
+        setattr(weight, "fp8_weight_transpose", w_t_fp8)
+        setattr(weight, "fp8_scale_transpose", w_t_scale)
+        setattr(weight, "fp8_weight", w_fp8)
+        setattr(weight, "fp8_scale", w_scale)
+    else:
+        w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            weight,
+            output_scale_transpose=False,
+            quant_method="128x128",
+            input_transpose=False,
+            return_transpose_only=False,
+        )
+        setattr(weight, "fp8_weight", w_fp8)
+        setattr(weight, "fp8_scale", w_scale)
 
 
 class FP8KeepXLinear(paddle.nn.Layer):
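
For reference, an illustrative sketch (not part of the diff) of the attribute layout left behind by the cache_fp8_weight hunk above: with quant_transpose=True the weight keeps four cached buffers, with False only the plain pair, which is where the memory saving comes from. Plain objects stand in for Paddle parameters:

class _FakeWeight:
    pass

def _cache_attrs(quant_transpose=True):
    # Names mirror the attributes set by cache_fp8_weight; values are placeholders.
    w = _FakeWeight()
    setattr(w, "fp8_weight", "w_fp8")
    setattr(w, "fp8_scale", "w_scale")
    if quant_transpose:
        setattr(w, "fp8_weight_transpose", "w_t_fp8")
        setattr(w, "fp8_scale_transpose", "w_t_scale")
    return sorted(a for a in vars(w) if a.startswith("fp8_"))

print(_cache_attrs(True))   # ['fp8_scale', 'fp8_scale_transpose', 'fp8_weight', 'fp8_weight_transpose']
print(_cache_attrs(False))  # ['fp8_scale', 'fp8_weight']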
@@ -619,8 +636,8 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
         )
         set_parameter_color([self.weight], "attn_out_project")
 
-    def fp8_quant_weight(self):
-        cache_fp8_weight(self.weight)
+    def fp8_quant_weight(self, quant_transpose=True):
+        cache_fp8_weight(self.weight, quant_transpose=quant_transpose)
 
     def forward(self, x):
         return FP8LinearFunction.apply(x, self, keep_x=True)
@@ -781,9 +798,9 @@ def __init__(
             is_bias=False,
         )
 
-    def fp8_quant_weight(self):
-        cache_fp8_weight(self.w1)
-        cache_fp8_weight(self.w2)
+    def fp8_quant_weight(self, quant_transpose=True):
+        cache_fp8_weight(self.w1, quant_transpose)
+        cache_fp8_weight(self.w2, quant_transpose)
 
     def forward(self, x):
         if self.using_post_norm_recompute:
@@ -865,6 +882,10 @@ def fwd_gate_up(self, x, expert_w1, num_expert, tokens_per_expert, m_indices=Non
         w1_t_quant = w1_t_quant.reshape([num_expert, -1, w1_t_quant.shape[-1]])
         w1_t_scale = w1_t_scale.reshape([num_expert, -1, w1_t_scale.shape[-1]])
 
+        if hasattr(expert_w1[0], "fp8_weight_stacked") and not hasattr(expert_w1[0], "fp8_weight_stacked_transpose"):
+            w1_t_quant = w1_t_quant.contiguous().transpose([0, 2, 1]).contiguous()
+            w1_t_scale = w1_t_scale.contiguous().transpose([0, 2, 1]).contiguous()
+
         if x is None:
             x_fp8, x_scale = self.input_fp8, self.input_scale
             assert x_fp8 is not None and x_scale is not None
@@ -914,6 +935,10 @@ def fwd_down(
         w2_quant = w2_quant.reshape([num_expert, -1, w2_quant.shape[-1]])
         w2_scale = w2_scale.reshape([num_expert, -1, w2_scale.shape[-1]])
 
+        if hasattr(expert_w2[0], "fp8_weight_stacked") and not hasattr(expert_w2[0], "fp8_weight_stacked_transpose"):
+            w2_quant = w2_quant.contiguous().transpose([0, 2, 1]).contiguous()
+            w2_scale = w2_scale.contiguous().transpose([0, 2, 1]).contiguous()
+
         # quant o2
         with paddle.amp.auto_cast(False):
             unzipped_probs = unzipped_probs.squeeze(-1)
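
The two new hasattr guards in fwd_gate_up and fwd_down cover the case where only the non-transposed stacked cache exists (transpose caching disabled): the cached per-expert blocks and their scales are transposed on the fly with perm [0, 2, 1], swapping the last two axes. A shape-only sketch with NumPy standing in for Paddle tensors (sizes here are made up for illustration):

import numpy as np

# Stand-ins for a stacked FP8 weight and its blockwise scales.
num_expert, rows, cols = 4, 256, 128
w_quant = np.zeros((num_expert, rows, cols))
w_scale = np.zeros((num_expert, rows // 128, cols // 128))

# Equivalent of w.contiguous().transpose([0, 2, 1]).contiguous() in the diff:
# swap the last two axes per expert, keeping the data contiguous.
w_quant_t = np.ascontiguousarray(w_quant.transpose(0, 2, 1))
w_scale_t = np.ascontiguousarray(w_scale.transpose(0, 2, 1))

print(w_quant.shape, "->", w_quant_t.shape)   # (4, 256, 128) -> (4, 128, 256)
print(w_scale.shape, "->", w_scale_t.shape)   # (4, 2, 1) -> (4, 1, 2)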
