diff --git a/paddlenlp/transformers/fp8_utils.py b/paddlenlp/transformers/fp8_utils.py
index 57fc729075e1..63930c9363dd 100644
--- a/paddlenlp/transformers/fp8_utils.py
+++ b/paddlenlp/transformers/fp8_utils.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+
 import numpy
 import paddle
 import paddle.nn.functional as F
@@ -26,8 +28,14 @@ def swiglu(x, y=None):
         return F.silu(x) * y
 
 
+USE_DS_GEMM = os.getenv("USE_DS_GEMM", "False").lower() == "true"
+
 try:
-    from paddle.incubate.fp8 import deep_gemm
+    if USE_DS_GEMM:
+        import deep_gemm
+    else:
+        from paddle.incubate.fp8 import deep_gemm
+
 except:
     pass
 
@@ -43,6 +51,12 @@ def swiglu(x, y=None):
 def kitchen_fp8_gemm(
     x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, out=None, rtn_dtype=paddle.bfloat16
 ):
+    if USE_DS_GEMM:
+        if out is None:
+            out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
+        deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out)
+        return out
+
     if out is not None:
         accumulate = True
         out_dtype = out.dtype
@@ -126,7 +140,7 @@ def forward(ctx, x, custom_map):
         )
 
         out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out, num_sms=112)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out)
         out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])
 
         # save for bwd
@@ -223,7 +237,7 @@ def forward(ctx, x, custom_map):
 
         # compute out = mm(x, w_t)
         out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out, num_sms=112)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out)
         out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])
 
         ctx.save_for_backward(x, weight)
@@ -263,7 +277,7 @@ def backward(ctx, dout):
         )
 
         dx = paddle.empty([dout_fp8.shape[0], w_fp8.shape[0]], dout.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_scale), dx, num_sms=112)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_scale), dx)
         dx = dx.reshape(dx_orig_shape)
 
         # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
@@ -384,7 +398,7 @@ def common_fp8_mlp_fwd(x, w1, w2):
         w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
     )
     o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=x.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1)
 
     # ===== o2 = swiglu(o1) =====
     o2 = swiglu(o1)
@@ -397,7 +411,7 @@ def common_fp8_mlp_fwd(x, w1, w2):
         w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
     )
     o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o1.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3)
 
     return x_fp8, x_scale, o3
 
@@ -406,7 +420,7 @@ def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_ba
         w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
     )
     o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1)
 
     # ===== [recompute] o2 = swiglu(o1) =====
     o2 = swiglu(o1)
@@ -428,7 +442,7 @@ def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_ba
         w2, output_scale_transpose=False, quant_method="128x128", input_transpose=False
     )
     do2 = paddle.empty([do3_fp8.shape[0], w2_fp8.shape[0]], do3.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale.T), (w2_fp8, w2_scale), do2, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale.T), (w2_fp8, w2_scale), do2)
 
     # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
     o2 = padding(o2, 0)
@@ -488,7 +502,7 @@ def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_ba
         w1, output_scale_transpose=False, quant_method="128x128", input_transpose=False
     )
     dx = paddle.empty([do1_fp8.shape[0], w1_fp8.shape[0]], do1.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_scale), dx, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_scale), dx)
 
     # ===== dw1 = deep_gemm(x_t_fp8, do1_t_fp8)
     if apply_backward_hook:
@@ -801,7 +815,7 @@ def fwd_gate_up(self, x, expert_w1, num_expert, tokens_per_expert):
             split_group_gemm(x_fp8, x_scale, w1_t_quant, w1_t_scale, tokens_per_expert, o1)
         else:
             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-                (x_fp8, x_scale), (w1_t_quant, w1_t_scale), o1, m_indices=self.m_indices, num_sms=112
+                (x_fp8, x_scale), (w1_t_quant, w1_t_scale), o1, m_indices=self.m_indices
             )
 
         self.input_fp8 = x_fp8
@@ -844,7 +858,7 @@ def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert, o3=None, clear_o1=
             split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_scale, self.tokens_per_expert, o3)
         else:
             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-                (o2_fp8, o2_scale), (w2_quant, w2_scale), o3, m_indices=self.m_indices, num_sms=112
+                (o2_fp8, o2_scale), (w2_quant, w2_scale), o3, m_indices=self.m_indices
             )
 
         return o3, unzipped_probs
@@ -882,7 +896,6 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, inplace_swiglu_prob=False
             (bw_w2_quant, bw_w2_scale),
             do2_s,
             m_indices=self.m_indices,
-            num_sms=112,
         )
 
         with paddle.amp.auto_cast(False):
@@ -924,7 +937,7 @@ def bwd_gate_up_input(self, do1, expert_w1, dx=None):
             split_group_gemm(do1_fp8, do1_scale, bw_w1_quant, bw_w1_scale, self.tokens_per_expert, dx)
         else:
             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-                (do1_fp8, do1_scale), (bw_w1_quant, bw_w1_scale), dx, m_indices=self.m_indices, num_sms=112
+                (do1_fp8, do1_scale), (bw_w1_quant, bw_w1_scale), dx, m_indices=self.m_indices
            )
 
         return dx
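A minimal sketch of how the new toggle is meant to be driven, assuming a `paddlenlp` build that includes this patch is importable. The key point shown is that `USE_DS_GEMM` is evaluated once, at module import time, so the environment variable must be set before `fp8_utils` is first imported in the process; the `"true"` value below is illustrative.

```python
import os

# The flag is read once when the module is imported (see the patch above),
# so export it before paddlenlp.transformers.fp8_utils is first imported.
os.environ["USE_DS_GEMM"] = "true"  # default "false" -> paddle.incubate.fp8.deep_gemm

from paddlenlp.transformers import fp8_utils

# With the toggle on, kitchen_fp8_gemm dispatches to
# deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt, which accumulates into `out`;
# this is why the patch allocates the buffer with paddle.zeros rather
# than paddle.empty when the caller does not supply one.
assert fp8_utils.USE_DS_GEMM
```

Dropping the hard-coded `num_sms=112` from every call site keeps the calls compatible with both backends, presumably leaving the SM-count choice to the selected DeepGEMM implementation rather than pinning it at the call site.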