use ds gemm #10879

Open · wants to merge 1 commit into base: dsv3_dev
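This PR adds a USE_DS_GEMM environment switch to paddlenlp/transformers/fp8_utils.py. When the flag is set to "true", the module imports DeepSeek's standalone deep_gemm package instead of paddle.incubate.fp8.deep_gemm, and kitchen_fp8_gemm routes weight-gradient GEMMs through deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt, accumulating into a zero-initialized output buffer. Independently, the hard-coded num_sms=112 argument is dropped from the gemm_fp8_fp8_bf16_nt and m_grouped_gemm_fp8_fp8_bf16_nt_contiguous calls, leaving SM selection to the backend's own default.

A minimal sketch of opting in (assumes the standalone deep_gemm package is installed; the flag is read once at import time, so it must be set before fp8_utils is imported):

    import os

    # Set the switch before the first import of fp8_utils; the module reads it
    # once at import time and defaults to the Paddle-bundled backend otherwise.
    os.environ["USE_DS_GEMM"] = "true"

    from paddlenlp.transformers import fp8_utils

    print(fp8_utils.USE_DS_GEMM)  # True when the flag was picked up
    print(fp8_utils.deep_gemm)    # the backend module that was selected
    # Note: deep_gemm is only bound if the backend import succeeded; the
    # module's try/except silently swallows import failures.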
paddlenlp/transformers/fp8_utils.py: 39 changes (26 additions, 13 deletions)
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+
 import numpy
 import paddle
 import paddle.nn.functional as F
@@ -26,8 +28,14 @@ def swiglu(x, y=None):
     return F.silu(x) * y
 
 
+USE_DS_GEMM = os.getenv("USE_DS_GEMM", "False").lower() == "true"
+
 try:
-    from paddle.incubate.fp8 import deep_gemm
+    if USE_DS_GEMM:
+        import deep_gemm
+    else:
+        from paddle.incubate.fp8 import deep_gemm
+
 except:
     pass
 
@@ -43,6 +51,12 @@ def swiglu(x, y=None):
 def kitchen_fp8_gemm(
     x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, out=None, rtn_dtype=paddle.bfloat16
 ):
+    if USE_DS_GEMM:
+        if out is None:
+            out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
+        deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out)
+        return out
+
     if out is not None:
         accumulate = True
         out_dtype = out.dtype
@@ -126,7 +140,7 @@ def forward(ctx, x, custom_map):
         )
 
         out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out, num_sms=112)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out)
         out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])
 
         # save for bwd
@@ -223,7 +237,7 @@ def forward(ctx, x, custom_map):
 
         # compute out = mm(x, w_t)
         out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out, num_sms=112)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out)
         out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])
 
         ctx.save_for_backward(x, weight)
@@ -263,7 +277,7 @@ def backward(ctx, dout):
         )
 
         dx = paddle.empty([dout_fp8.shape[0], w_fp8.shape[0]], dout.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_scale), dx, num_sms=112)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_scale), dx)
         dx = dx.reshape(dx_orig_shape)
 
         # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
@@ -384,7 +398,7 @@ def common_fp8_mlp_fwd(x, w1, w2):
         w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
     )
     o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=x.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1)
 
     # ===== o2 = swiglu(o1) =====
     o2 = swiglu(o1)
@@ -397,7 +411,7 @@ def common_fp8_mlp_fwd(x, w1, w2):
         w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
     )
     o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o1.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3)
     return x_fp8, x_scale, o3
 
 
@@ -406,7 +420,7 @@ def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_ba
         w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
     )
     o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1)
 
     # ===== [recompute] o2 = swiglu(o1) =====
     o2 = swiglu(o1)
@@ -428,7 +442,7 @@ def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_ba
         w2, output_scale_transpose=False, quant_method="128x128", input_transpose=False
     )
     do2 = paddle.empty([do3_fp8.shape[0], w2_fp8.shape[0]], do3.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale.T), (w2_fp8, w2_scale), do2, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale.T), (w2_fp8, w2_scale), do2)
 
     # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
     o2 = padding(o2, 0)
@@ -488,7 +502,7 @@ def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_ba
         w1, output_scale_transpose=False, quant_method="128x128", input_transpose=False
     )
     dx = paddle.empty([do1_fp8.shape[0], w1_fp8.shape[0]], do1.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_scale), dx, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_scale), dx)
 
     # ===== dw1 = deep_gemm(x_t_fp8, do1_t_fp8)
     if apply_backward_hook:
@@ -801,7 +815,7 @@ def fwd_gate_up(self, x, expert_w1, num_expert, tokens_per_expert):
             split_group_gemm(x_fp8, x_scale, w1_t_quant, w1_t_scale, tokens_per_expert, o1)
         else:
             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-                (x_fp8, x_scale), (w1_t_quant, w1_t_scale), o1, m_indices=self.m_indices, num_sms=112
+                (x_fp8, x_scale), (w1_t_quant, w1_t_scale), o1, m_indices=self.m_indices
             )
 
         self.input_fp8 = x_fp8
@@ -844,7 +858,7 @@ def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert, o3=None, clear_o1=
             split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_scale, self.tokens_per_expert, o3)
         else:
             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-                (o2_fp8, o2_scale), (w2_quant, w2_scale), o3, m_indices=self.m_indices, num_sms=112
+                (o2_fp8, o2_scale), (w2_quant, w2_scale), o3, m_indices=self.m_indices
            )
         return o3, unzipped_probs
 
@@ -882,7 +896,6 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, inplace_swiglu_prob=False
                 (bw_w2_quant, bw_w2_scale),
                 do2_s,
                 m_indices=self.m_indices,
-                num_sms=112,
             )
 
         with paddle.amp.auto_cast(False):
@@ -924,7 +937,7 @@ def bwd_gate_up_input(self, do1, expert_w1, dx=None):
             split_group_gemm(do1_fp8, do1_scale, bw_w1_quant, bw_w1_scale, self.tokens_per_expert, dx)
         else:
             deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-                (do1_fp8, do1_scale), (bw_w1_quant, bw_w1_scale), dx, m_indices=self.m_indices, num_sms=112
+                (do1_fp8, do1_scale), (bw_w1_quant, bw_w1_scale), dx, m_indices=self.m_indices
             )
 
         return dx
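For reference, a minimal sketch of the calling pattern the new DS-GEMM branch in kitchen_fp8_gemm enables: on the first call the function zero-initializes the output buffer, and wgrad_gemm_fp8_fp8_fp32_nt accumulates into whatever buffer is passed back in, so reusing the returned tensor sums weight gradients across micro-batches (the pre-existing path likewise sets accumulate=True whenever out is provided). The quantized inputs and the micro_batches iterable below are hypothetical placeholders; the FP8 tensors and block scales would come from the quantization helpers elsewhere in this file.

    import paddle

    from paddlenlp.transformers.fp8_utils import kitchen_fp8_gemm

    # x_fp8/x_scale and w_fp8/w_scale are assumed to be FP8 tensors with block
    # scales produced by the quantization helpers in fp8_utils; micro_batches is
    # a hypothetical iterable yielding one quantized activation chunk per step.
    dw = None
    for x_fp8, x_scale in micro_batches:
        dw = kitchen_fp8_gemm(
            x_fp8,
            x_scale,
            w_fp8,
            w_scale,
            is_a_1d_scaled=True,   # illustrative scale-granularity flags
            is_b_1d_scaled=False,
            out=dw,                # reuse the returned buffer so results accumulate
            rtn_dtype=paddle.float32,  # FP32 accumulation on the DS-GEMM path
        )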