Skip to content

Commit fe5e87a

Browse files
authored
lock gemm sm 112 (#10805)
1 parent e3f1b62 commit fe5e87a

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

paddlenlp/transformers/fp8_utils.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def forward(ctx, x, weight):
120120
)
121121

122122
out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype)
123-
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_sacle), out)
123+
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_sacle), out, num_sms=112)
124124
out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])
125125

126126
# save for bwd
@@ -194,7 +194,7 @@ def forward(ctx, x, weight):
194194

195195
# compute out = mm(x, w_t)
196196
out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype)
197-
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_sacle), out)
197+
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_sacle), out, num_sms=112)
198198
out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])
199199

200200
ctx.save_for_backward(x, weight)
@@ -230,7 +230,7 @@ def backward(ctx, dout):
230230
)
231231

232232
dx = paddle.empty([dout_fp8.shape[0], w_fp8.shape[0]], dout.dtype)
233-
deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_sacle), dx)
233+
deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_sacle), dx, num_sms=112)
234234
dx = dx.reshape(dx_orig_shape)
235235

236236
# ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
@@ -267,7 +267,7 @@ def fp8_mlp_fwd(x, w1, w2):
267267
w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True
268268
)
269269
o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=x.dtype)
270-
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1)
270+
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1, num_sms=112)
271271

272272
# ===== o2 = swiglu(o1) =====
273273
o2 = swiglu(o1)
@@ -280,7 +280,7 @@ def fp8_mlp_fwd(x, w1, w2):
280280
w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True
281281
)
282282
o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o1.dtype)
283-
deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3)
283+
deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3, num_sms=112)
284284
if len(x_orig_shape) > 2:
285285
o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])
286286

@@ -297,7 +297,7 @@ def fp8_mlp_bwd(do3, x_fp8, x_scale, w1, w2):
297297
w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True
298298
)
299299
o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
300-
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1)
300+
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1, num_sms=112)
301301

302302
x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous())
303303
x_dequant_fp16 = padding(x_dequant_fp16, 0)
@@ -326,7 +326,7 @@ def fp8_mlp_bwd(do3, x_fp8, x_scale, w1, w2):
326326
w2, output_scale_transpose=False, quant_method="128x128", input_transpose=False
327327
)
328328
do2 = paddle.empty([do3_fp8.shape[0], w2_fp8.shape[0]], do3.dtype)
329-
deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale.T), (w2_fp8, w2_scale), do2)
329+
deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale.T), (w2_fp8, w2_scale), do2, num_sms=112)
330330

331331
# ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
332332
o2 = padding(o2, 0)
@@ -383,7 +383,7 @@ def fp8_mlp_bwd(do3, x_fp8, x_scale, w1, w2):
383383
w1, output_scale_transpose=False, quant_method="128x128", input_transpose=False
384384
)
385385
dx = paddle.empty([do1_fp8.shape[0], w1_fp8.shape[0]], do1.dtype)
386-
deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_sacle), dx)
386+
deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_sacle), dx, num_sms=112)
387387
if len(x_orig_shape) > 2:
388388
dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
389389

@@ -577,7 +577,7 @@ def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out
577577
x_scale_tma_align = x_scale[start_idx:end_idx].T.contiguous().T
578578

579579
deep_gemm.gemm_fp8_fp8_bf16_nt(
580-
(x_fp8[start_idx:end_idx], x_scale_tma_align), (w_fp8[i], w_scale[i]), gemm_out[start_idx:end_idx]
580+
(x_fp8[start_idx:end_idx], x_scale_tma_align), (w_fp8[i], w_scale[i]), gemm_out[start_idx:end_idx], num_sms=112
581581
)
582582

583583
start_idx = end_idx
@@ -681,7 +681,7 @@ def fwd_gate_up(self, x_bf16, expert_w1, num_expert, tokens_per_expert):
681681
split_group_gemm(x_fp8, x_scale, w1_t_quant, w1_t_scale, tokens_per_expert, o1)
682682
else:
683683
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
684-
(x_fp8, x_scale), (w1_t_quant, w1_t_scale), o1, m_indices=self.m_indices
684+
(x_fp8, x_scale), (w1_t_quant, w1_t_scale), o1, m_indices=self.m_indices, num_sms=112
685685
)
686686

687687
if self.dequant_input:
@@ -728,7 +728,7 @@ def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert, o3=None, clear_o1=
728728
split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_sacle, self.tokens_per_expert, o3)
729729
else:
730730
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
731-
(o2_fp8, o2_scale), (w2_quant, w2_sacle), o3, m_indices=self.m_indices
731+
(o2_fp8, o2_scale), (w2_quant, w2_sacle), o3, m_indices=self.m_indices, num_sms=112
732732
)
733733
return o3, unzipped_probs
734734

@@ -763,6 +763,7 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, inplace_swiglu_prob=False
763763
(bw_w2_quant, bw_w2_scale),
764764
do2_s,
765765
m_indices=self.m_indices,
766+
num_sms=112
766767
)
767768

768769
with paddle.amp.auto_cast(False):
@@ -806,7 +807,7 @@ def bwd_gate_up_input(self, do1, expert_w1, dx=None):
806807
else:
807808
do1_scale = paddle.transpose(paddle.transpose(do1_scale, [1, 0]).contiguous(), [1, 0])
808809
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
809-
(do1_fp8, do1_scale), (bw_w1_quant, bw_w1_scale), dx, m_indices=self.m_indices
810+
(do1_fp8, do1_scale), (bw_w1_quant, bw_w1_scale), dx, m_indices=self.m_indices, num_sms=112
810811
)
811812

812813
return dx

0 commit comments

Comments (0)