Commit 7ad7cee

[BENCH] Remove TMA workaround in swiglu (triton-lang#6711)
TMA is no longer needed in this kernel after the convert-layout cost model added in triton-lang#6699. Also fix test_swiglu.py, which broke after triton-lang#6703.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because it already has test coverage.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent e3f9f43 commit 7ad7cee
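With the TMA branch gone, the kernel always writes its output tile with a plain masked `tl.store` (see the diffs below). For orientation only, here is a minimal, hypothetical Triton SwiGLU kernel built around that same masked-store pattern. It is not this repository's `_swiglu`: it uses the generic `silu(gate) * up` formulation, and the real kernel's `alpha`, `limit` clamping, flexpoint scaling, and expert metadata are all omitted.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _swiglu_sketch(Out, A, N, stride_am, stride_outm, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    off_n = tl.arange(0, BLOCK_N)
    mask = off_n < N
    # each row of A is assumed to be packed as [gate | up], each half of width N, unit stride in n
    gate = tl.load(A + pid_m * stride_am + off_n, mask=mask, other=0.).to(tl.float32)
    up = tl.load(A + pid_m * stride_am + N + off_n, mask=mask, other=0.).to(tl.float32)
    out = gate * tl.sigmoid(gate) * up  # silu(gate) * up
    # plain masked store: the code path this commit keeps after dropping the TMA descriptors
    tl.store(Out + pid_m * stride_outm + off_n, out.to(Out.dtype.element_ty), mask=mask)


def swiglu_sketch(a: torch.Tensor) -> torch.Tensor:
    M, two_n = a.shape
    N = two_n // 2
    out = torch.empty((M, N), device=a.device, dtype=a.dtype)
    _swiglu_sketch[(M,)](out, a, N, a.stride(0), out.stride(0), BLOCK_N=triton.next_power_of_2(N))
    return out
```

Per the commit message, the convert-layout cost model from triton-lang#6699 is what makes a direct store like this acceptable again: the deleted `TensorDescriptor` path existed only to stop layout-conversion elimination from duplicating compute around the store.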

File tree

- bench/tests/test_swiglu.py
- bench/triton_bench/swiglu.py
- bench/triton_bench/swiglu_details/_swiglu.py

3 files changed: +6 −29 lines changed


bench/tests/test_swiglu.py

Lines changed: 1 addition & 3 deletions
@@ -5,7 +5,6 @@
 import pytest

 from .test_routing import init_data as init_routing_data
-from .test_routing import ref_expt_data

 # ---------------
 # initialize data
@@ -33,8 +32,7 @@ def test_op(M, N, limit, device, alpha=0.5):
     n_expts_act = 2
     logits = init_routing_data(M, n_expts_tot).detach()
     routing_data, _, _ = routing_torch(logits, n_expts_act)
-    expt_data = ref_expt_data(routing_data, M * n_expts_act, block_m=128)
-    n_tokens = expt_data[2 * n_expts_tot].sum()
+    n_tokens = routing_data.expt_hist.sum()

     # initialize data
     x = alloc_rand([n_tokens, N], device=device, dtype=torch.bfloat16)
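On the test change above: `routing_data.expt_hist` appears to be the per-expert histogram of routed rows, so summing it gives the same token count the deleted `ref_expt_data` indexing produced. A small, hypothetical torch sketch of that invariant; the `topk`/`bincount` construction below is illustrative and not the repo's routing code.

```python
import torch

M, n_expts_tot, n_expts_act = 1024, 8, 2
# each token picks n_expts_act distinct experts
expt_idx = torch.rand(M, n_expts_tot).topk(n_expts_act, dim=1).indices
expt_hist = torch.bincount(expt_idx.flatten(), minlength=n_expts_tot)
# every (token, expert) assignment lands in exactly one bucket
assert expt_hist.sum().item() == M * n_expts_act
```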

bench/triton_bench/swiglu.py

Lines changed: 0 additions & 14 deletions
@@ -2,7 +2,6 @@
 from triton_bench.numerics import InFlexData, OutFlexData
 import torch
 import triton
-from triton.tools.tensor_descriptor import TensorDescriptor
 from .swiglu_details._swiglu import _swiglu
 from triton_bench import target_info
 from .matmul_ogs_details.metadata import compute_metadata
@@ -35,17 +34,6 @@ def forward(ctx, a, alpha, precision_config, routing_data, num_experts):
         BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
         num_warps = 4
         kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
-        # TMA descriptors
-        out_desc = None
-        a_desc = None
-        if target_info.cuda_capability_geq(9, 0) and flex_ctx.out_data.actual_scale is not None:
-            # We need TMA to store the outputs otherwise Triton will aggressively removing layout conversions at
-            # the cost of duplicating too much compute. With TMA, the layout conversion gets folded into the TMA store,
-            # and the duplication doesn't occur.
-            assert out.shape[-1] * out.element_size() % 16 == 0
-            out_desc = TensorDescriptor.from_tensor(out, (BLOCK_M, BLOCK_N))
-            assert a.shape[-1] * a.element_size() % 16 == 0
-            a_desc = TensorDescriptor.from_tensor(a, (BLOCK_M, 2 * BLOCK_N))
         # launch semi-persistent kernel
         N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
         num_sms = target_info.num_sms()
@@ -64,12 +52,10 @@ def forward(ctx, a, alpha, precision_config, routing_data, num_experts):
         if routing_data is not None:
             expt_data = compute_metadata(routing_data, M, BLOCK_M).buffer
         _swiglu[grid](
-            out_desc,
             flex_ctx.out_data.reinterpret(out),
             flex_ctx.out_data.expected_scale,
             flex_ctx.out_data.actual_scale,
             flex_ctx.out_data.checksum_scale,
-            a_desc,
             flex_ctx.inp_data.reinterpret(a),
             flex_ctx.inp_data.scale,
             alpha,
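The deleted asserts point at one constraint that disappears along with the TMA descriptors: `TensorDescriptor.from_tensor` was only reached after checking that the tensor's last dimension spans a whole number of 16-byte chunks, while a plain masked store needs no such guarantee. A throwaway sketch of that check; the shapes below are made up for illustration.

```python
import torch

for n_cols in (128, 100):
    t = torch.empty(64, n_cols, dtype=torch.bfloat16)
    row_bytes = t.shape[-1] * t.element_size()
    # 128 bf16 columns -> 256 bytes (passes the removed assert); 100 -> 200 bytes (would not)
    print(n_cols, row_bytes, row_bytes % 16 == 0)
```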

bench/triton_bench/swiglu_details/_swiglu.py

Lines changed: 5 additions & 12 deletions
@@ -35,10 +35,9 @@ def swiglu_launch_metadata(grid, kernel, args):


 @triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata)
-def _swiglu(out_desc, Out, OutExpectedScale, OutActualScale, OutChecksumScale, a_desc, A, AScale, alpha, M, N,
-            stride_am, stride_an, stride_outm, stride_outn, limit: tl.constexpr, ExptData, NUM_EXPERTS: tl.constexpr,
-            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, EVEN_N: tl.constexpr, M_BLOCKS, N_BLOCKS,
-            flexpoint_saturate_inf: tl.constexpr):
+def _swiglu(Out, OutExpectedScale, OutActualScale, OutChecksumScale, A, AScale, alpha, M, N, stride_am, stride_an,
+            stride_outm, stride_outn, limit: tl.constexpr, ExptData, NUM_EXPERTS: tl.constexpr, BLOCK_M: tl.constexpr,
+            BLOCK_N: tl.constexpr, EVEN_N: tl.constexpr, M_BLOCKS, N_BLOCKS, flexpoint_saturate_inf: tl.constexpr):
     if ExptData is not None:
         M = tl.load(ExptData + 2 * NUM_EXPERTS)
         M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M
@@ -61,8 +60,6 @@ def _swiglu(out_desc, Out, OutExpectedScale, OutActualScale, OutChecksumScale, a
         # load a
         packed_off_n = pid_n * 2 * BLOCK_N + tl.arange(0, 2 * BLOCK_N)
         packed_offs = off_m[:, None] * stride_am + packed_off_n[None, :] * stride_an
-        if a_desc is not None:
-            a_packed = a_desc.load([pid_m * BLOCK_M, pid_n * 2 * BLOCK_N])
         if EVEN_N:
             a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.)
         else:
@@ -91,11 +88,7 @@ def _swiglu(out_desc, Out, OutExpectedScale, OutActualScale, OutChecksumScale, a
         out = float_to_flex(out, out_expected_scale,
                             None,  # ActualScale: local absmax is tracked and updated after the loop
                             OutChecksumScale, None, Out, flexpoint_saturate_inf)
-
-        if out_desc is not None:
-            out_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], out.to(Out.dtype.element_ty))
-        else:
-            mask = mask_m[:, None] if EVEN_N else mask_m[:, None] and mask_n[None, :]
-            tl.store(Out + off_m[:, None] * stride_outm + off_n[None, :] * stride_outn, out, mask)
+        mask = mask_m[:, None] if EVEN_N else mask_m[:, None] and mask_n[None, :]
+        tl.store(Out + off_m[:, None] * stride_outm + off_n[None, :] * stride_outn, out, mask)

     update_scale(local_max, OutActualScale, Out)
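Two patterns in the unchanged context are worth noting: the kernel re-reads its effective row count `M` from the `ExptData` buffer at run time, and, per the "launch semi-persistent kernel" comment in swiglu.py, a bounded number of programs loop over tiles. A self-contained sketch of that combination; the kernel, the one-element metadata buffer, and every name below are hypothetical stand-ins, not the real `ExptData` layout.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _persistent_fill(Out, MetaM, stride_m, N,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PROGRAMS: tl.constexpr):
    M = tl.load(MetaM)  # dynamic row count read from device memory, like ExptData above
    m_blocks = tl.cdiv(M, BLOCK_M)
    n_blocks = tl.cdiv(N, BLOCK_N)
    # each program strides over the tile grid, so the launch size can stay fixed
    for tile in range(tl.program_id(0), m_blocks * n_blocks, NUM_PROGRAMS):
        pid_m = tile // n_blocks
        pid_n = tile % n_blocks
        off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = (off_m[:, None] < M) & (off_n[None, :] < N)
        ones = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + 1.0
        # assumes Out is float32 and contiguous in its last dimension
        tl.store(Out + off_m[:, None] * stride_m + off_n[None, :], ones, mask=mask)


def fill_valid_rows(out: torch.Tensor, m: int, num_programs: int = 4) -> None:
    meta = torch.tensor([m], device=out.device, dtype=torch.int32)
    _persistent_fill[(num_programs,)](out, meta, out.stride(0), out.shape[1],
                                      BLOCK_M=32, BLOCK_N=32, NUM_PROGRAMS=num_programs)
```

Capping the grid on the host (num_sms in swiglu.py) while resolving the exact amount of work on the device is presumably what lets a data-dependent `M` be handled without reading it back to the host.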
