
Commit dad134e

optimize scale transpose (#10810)
1 parent be62526 commit dad134e

File tree

1 file changed: +8 / -13 lines changed

paddlenlp/transformers/fp8_utils.py

Lines changed: 8 additions & 13 deletions
@@ -679,10 +679,9 @@ def fwd_gate_up(self, x_bf16, expert_w1, num_expert, tokens_per_expert):
         else:
             # quant x_bf16
             x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                x_bf16, output_scale_transpose=False, quant_method="1x128", input_transpose=False
+                x_bf16, output_scale_transpose=True, quant_method="1x128", input_transpose=False
             )
-
-            x_scale = paddle.transpose(paddle.transpose(x_scale, [1, 0]).contiguous(), [1, 0])
+            x_scale = x_scale.T
 
         # compute gemm
         o1 = paddle.empty([x_fp8.shape[0], w1_t_quant.shape[1]], dtype=expert_w1[0].dtype)
@@ -730,9 +729,8 @@ def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert, o3=None, clear_o1=
         o3_shape = [o2_fp8.shape[0], w2_quant.shape[1]]
         if o3 is not None:
             assert o3.shape == o3_shape, "{} vs {}".format(o3.shape, o3_shape)
-            o3.zero_()
         else:
-            o3 = paddle.zeros(o3_shape, dtype=o1.dtype)
+            o3 = paddle.empty(o3_shape, dtype=o1.dtype)
         if numpy.prod(o2_fp8.shape) != 0:
             if self.is_split_group_gemm:
                 split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_sacle, self.tokens_per_expert, o3)
@@ -756,18 +754,17 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, inplace_swiglu_prob=False
 
         # compute gemm
         unzipped_grad_fp8, unzipped_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            unzipped_grad, output_scale_transpose=False, quant_method="1x128", input_transpose=False
+            unzipped_grad, output_scale_transpose=True, quant_method="1x128", input_transpose=False
         )
+        unzipped_grad_scale = unzipped_grad_scale.T
+
         do2_s = paddle.empty([unzipped_grad_fp8.shape[0], bw_w2_quant.shape[1]], dtype=unzipped_grad.dtype)
         if numpy.prod(unzipped_grad_fp8.shape) != 0:
             if self.is_split_group_gemm:
                 split_group_gemm(
                     unzipped_grad_fp8, unzipped_grad_scale, bw_w2_quant, bw_w2_scale, self.tokens_per_expert, do2_s
                 )
             else:
-                unzipped_grad_scale = paddle.transpose(
-                    paddle.transpose(unzipped_grad_scale, [1, 0]).contiguous(), [1, 0]
-                )
                 deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
                     (unzipped_grad_fp8, unzipped_grad_scale),
                     (bw_w2_quant, bw_w2_scale),
@@ -801,21 +798,19 @@ def bwd_gate_up_input(self, do1, expert_w1, dx=None):
 
         # quant do1
         do1_fp8, do1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            do1, output_scale_transpose=False, quant_method="1x128", input_transpose=False
+            do1, output_scale_transpose=True, quant_method="1x128", input_transpose=False
         )
-
+        do1_scale = do1_scale.T
         # compute gemm
         dx_shape = [do1_fp8.shape[0], bw_w1_quant.shape[1]]
         if dx is None:
             dx = paddle.empty(shape=dx_shape, dtype=do1.dtype)
         else:
             assert dx.shape == dx_shape, f"{dx.shape} vs {dx_shape}"
-            dx.zero_()
         if numpy.prod(do1_fp8.shape) != 0:
             if self.is_split_group_gemm:
                 split_group_gemm(do1_fp8, do1_scale, bw_w1_quant, bw_w1_scale, self.tokens_per_expert, dx)
             else:
-                do1_scale = paddle.transpose(paddle.transpose(do1_scale, [1, 0]).contiguous(), [1, 0])
                 deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
                     (do1_fp8, do1_scale), (bw_w1_quant, bw_w1_scale), dx, m_indices=self.m_indices, num_sms=112
                 )
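
All four methods change in the same way: instead of taking an untransposed scale from fp8_quant_blockwise and then rewriting its layout with a transpose -> .contiguous() -> transpose round trip (an extra kernel launch plus a full copy), the quant kernel is asked to emit the scale already transposed (output_scale_transpose=True), so a single .T restores the logical orientation. A minimal sketch of the before/after layout handling on a dummy tensor; the shapes and the scale_pre_transposed stand-in are illustrative assumptions, not the kernel's actual output:

import paddle

# Hypothetical per-block quantization scale: one scale per 128-wide block of each row.
tokens, blocks = 512, 4
scale = paddle.rand([tokens, blocks], dtype="float32")

# Old pattern: the scale arrives as [tokens, blocks] and is forced into the layout
# the grouped GEMM expects via a double transpose with an explicit copy in between.
old_layout = paddle.transpose(paddle.transpose(scale, [1, 0]).contiguous(), [1, 0])

# New pattern: the quant kernel writes the scale out already transposed
# (output_scale_transpose=True); scale_pre_transposed stands in for that output here.
# A single .T then recovers the [tokens, blocks] view without the extra copy.
scale_pre_transposed = scale.T.contiguous()
new_layout = scale_pre_transposed.T

# Same values either way; only the number of copies differs.
assert paddle.allclose(old_layout, new_layout).item()

The commit also stops zero-initializing the GEMM output buffers (paddle.zeros becomes paddle.empty, and the o3.zero_() / dx.zero_() calls are dropped), presumably because the grouped GEMM overwrites every output row, making the pre-zeroing redundant.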
