Skip to content

Commit f9c97b3

Browse files
authored
[kernels] revert bias subtiling changes (#7232)
regresses moe matmul performance
1 parent 68a24ff commit f9c97b3

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,17 @@ def _p_matmul_ogs(
384384
block_shape=[BLOCK_M, OUT_BLOCK_N],
385385
)
386386

387+
# bias + scale
388+
offs_y_n = off_n1 + tl.arange(0, BLOCK_N)
389+
mask_n = offs_y_n < N
390+
if B is not None:
391+
BPtrs = B + expt_id1 * stride_b_e + offs_y_n
392+
if pid_k1 == 0:
393+
bias = tl.load(BPtrs, mask=mask_n, other=0)
394+
else:
395+
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
396+
else:
397+
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
387398
if Betas is not None:
388399
betas = tl.load(Betas + start_m1 + offs_m, mask=mask_m, other=0.0)
389400
else:
@@ -399,15 +410,21 @@ def _p_matmul_ogs(
399410
w_scale = load_scale(WScale)
400411

401412
accs = (acc,)
413+
biases = (bias,)
402414

403415
if SUBTILE_FACTOR >= 2:
404416
acc0, acc1 = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1).split()
405417
accs = (acc0, acc1)
418+
bias0, bias1 = bias.reshape(2, BLOCK_N // 2).permute(1, 0).split()
419+
biases = (bias0, bias1)
406420

407421
if SUBTILE_FACTOR >= 4:
408422
acc00, acc01 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
409423
acc10, acc11 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
410424
accs = (acc00, acc01, acc10, acc11)
425+
bias00, bias01 = bias0.reshape(2, BLOCK_N // 4).permute(1, 0).split()
426+
bias10, bias11 = bias1.reshape(2, BLOCK_N // 4).permute(1, 0).split()
427+
biases = (bias00, bias01, bias10, bias11)
411428

412429
tl.static_assert(EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR)
413430
tl.static_assert(len(accs) == SUBTILE_FACTOR)
@@ -419,18 +436,7 @@ def _p_matmul_ogs(
419436
if SWAP_XW:
420437
acc_tile = acc_tile.T
421438

422-
if B is not None:
423-
offs_y_n = off_n1 + EPILOGUE_BLOCK_N * a_i + tl.arange(0, EPILOGUE_BLOCK_N)
424-
mask_n = offs_y_n < N
425-
BPtrs = B + expt_id1 * stride_b_e + offs_y_n
426-
if pid_k1 == 0:
427-
bias = tl.load(BPtrs, mask=mask_n, other=0)
428-
else:
429-
bias = tl.full([EPILOGUE_BLOCK_N], 0, dtype=tl.float32)
430-
else:
431-
bias = tl.full([EPILOGUE_BLOCK_N], 0, dtype=tl.float32)
432-
433-
acc_tile = acc_tile + bias[None, :] * betas[:, None]
439+
acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None]
434440
if out_alpha is not None:
435441
acc_tile *= out_alpha
436442

0 commit comments

Comments
 (0)