Commit adad9c3

[tilelang] Add gemm and rms_norm kernels (#514)
1 parent: 2093c4a

5 files changed: +231 lines, -0 lines

tritonbench/operators/gemm/operator.py

Lines changed: 13 additions & 0 deletions
@@ -26,9 +26,16 @@ def _tlx_matmul(*args, **kwargs):
         raise RuntimeError("TLX not available in this Triton version")
 
 
+from tritonbench.utils.python_utils import try_import
+
+with try_import("HAS_TILELANG"):
+    from .tilelang import tilelang_matmul_func
+
+
 from tritonbench.utils.data_utils import get_production_shapes
 from tritonbench.utils.env_utils import (
     get_nvidia_gpu_model,
+    is_cu130,
     is_cuda,
     is_fbcode,
     supports_tma,
@@ -472,6 +479,12 @@ def tlx_matmul(self, a, b, bias) -> Callable:
         else:
             return lambda: _tlx_matmul(a, b)
 
+    @register_benchmark(enabled=HAS_TILELANG and is_cu130())
+    def tilelang_blackwell_matmul(self, a, b, bias) -> Callable:
+        assert bias is None, "Tilelang does not support bias"
+        assert a.dtype == torch.bfloat16, "Tilelang only supports bf16"
+        return tilelang_matmul_func(a, b)
+
     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
         # x-value: computation intensity
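
The `with try_import("HAS_TILELANG"):` guard makes the TileLang import optional, so the operator module still loads when tilelang is not installed and the benchmark is only registered when `HAS_TILELANG` is true. The helper itself is not part of this diff; the sketch below is a minimal, assumed stand-in for `tritonbench.utils.python_utils.try_import` (the real implementation may differ) illustrating the behavior the call sites rely on.

import sys


class try_import:
    """Assumed stand-in: set a flag in the caller's module indicating whether
    the imports inside the `with` block succeeded, and swallow ImportError."""

    def __init__(self, flag_name: str):
        self.flag_name = flag_name

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # __exit__ runs in the frame that executed the `with` statement, so
        # _getframe(1) is the importing module's global namespace.
        caller_globals = sys._getframe(1).f_globals
        caller_globals[self.flag_name] = exc_type is None
        # Suppress only import failures; let other exceptions propagate.
        return exc_type is not None and issubclass(exc_type, ImportError)

With this behavior, `@register_benchmark(enabled=HAS_TILELANG and is_cu130())` cleanly disables `tilelang_blackwell_matmul` on machines without tilelang or without CUDA 13.0.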
tritonbench/operators/gemm/tilelang.py (new file)

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
# Original source: https://github.com/tile-ai/tilelang/blob/main/examples/gemm_sm100/gemm_tcgen5mma.py
import tilelang
import tilelang.language as T
import torch

tilelang.disable_cache()


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(
            T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
        ) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
            mbar = T.alloc_barrier(1)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_tmem,
                    trans_A,
                    trans_B,
                    mbar=mbar,
                    wg_wait=-1,
                    clear_accum=k == 0,
                )
                T.mbarrier_wait_parity(mbar, k % 2)

            T.copy(C_tmem, C_local)
            T.copy(C_local, C_shared)

            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


TILELANG_DTYPE_MAP = {
    torch.bfloat16: "bfloat16",
    torch.float16: "float16",
    torch.float32: "float",
}


def tilelang_matmul_func(a, b):
    M, K = a.size()
    K, N = b.size()
    b_T = b.T.contiguous()
    block_M, block_N, block_K = 128, 256, 128
    trans_A, trans_B = False, True
    in_dtype = TILELANG_DTYPE_MAP[a.dtype]
    out_dtype = TILELANG_DTYPE_MAP[a.dtype]
    accum_dtype = "float"
    num_stages = 2
    threads = 256
    func = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        accum_dtype,
        num_stages,
        threads,
    )
    jit_kernel = tilelang.compile(
        func,
        out_idx=[2],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    return lambda: jit_kernel(a, b_T)
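
A hedged usage sketch (not part of the commit) for `tilelang_matmul_func`, assuming a Blackwell-class GPU with CUDA 13.0 and tilelang installed (matching the `HAS_TILELANG and is_cu130()` gate in the operator) and assuming the module path `tritonbench.operators.gemm.tilelang`. The function compiles the kernel up front and returns a zero-argument callable; `out_idx=[2]` makes the compiled kernel allocate and return `C`.

import torch

from tritonbench.operators.gemm.tilelang import tilelang_matmul_func  # path assumed

M, N, K = 4096, 4096, 4096
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

fn = tilelang_matmul_func(a, b)  # compiles once, captures a and b.T.contiguous()
c = fn()                         # each call runs jit_kernel(a, b_T) and returns C

# Loose tolerances: bf16 inputs with float32 accumulation over K = 4096.
torch.testing.assert_close(c, torch.matmul(a, b), rtol=2e-2, atol=2e-2)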

tritonbench/operators/rms_norm/operator.py

Lines changed: 10 additions & 0 deletions
@@ -4,6 +4,7 @@
 import torch
 
 from tritonbench.utils.env_utils import is_hip
+from tritonbench.utils.python_utils import try_import
 
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
@@ -31,6 +32,9 @@
 except ModuleNotFoundError:
     QuackRMSNorm = None
 
+with try_import("HAS_TILELANG"):
+    from .tilelang import TileLangRMSNorm
+
 
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
@@ -153,6 +157,12 @@ def aiter(self, H, input, weight) -> Callable:
         self.aiter_rms_op = module
         return lambda: module(input)
 
+    @register_benchmark(enabled=HAS_TILELANG)
+    def tilelang(self, H, input, weight) -> Callable:
+        module = TileLangRMSNorm(hidden_size=H, eps=self.eps).to(self.device)
+        module.weight = weight
+        return module(input)
+
     @register_x_val(label="(M, H)")
     def get_x_val(self, example_inputs) -> Tuple[int, int]:
         H = example_inputs[0]
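
Note that the new `tilelang` benchmark returns `module(input)` directly rather than wrapping it in a lambda as the `aiter` benchmark does: `TileLangRMSNorm.forward` (see the new file below) already returns a zero-argument lambda that launches the compiled kernel, so `module(input)` is itself the Callable that `@register_benchmark` expects. A hedged illustration, using the `H` and `input` arguments of the benchmark method and assuming a CUDA device:

# Illustration only (not part of the commit).
module = TileLangRMSNorm(hidden_size=H, eps=1e-6).to("cuda")
fn = module(input)   # forward() compiles the kernel and returns `lambda: jit_kernel(input)`
out = fn()           # launches the TileLang RMSNorm kernel and returns the output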
tritonbench/operators/rms_norm/tilelang.py (new file)

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
# Original source:
# https://github.com/tile-ai/tilelang/blob/main/examples/norm/test_rms_norm.py
import tilelang
import tilelang.language as T
import torch

tilelang.disable_cache()


def rms_norm_splitk(M, N, blk_m, blk_k):
    dtype = "float"

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
            A_shared = T.alloc_shared((blk_m, blk_k), dtype)
            A_local = T.alloc_fragment((blk_m, blk_k), dtype)
            A_powsum = T.alloc_fragment((blk_m,), dtype)

            num_k_step = T.ceildiv(N, blk_k)
            T.clear(A_local)
            for k in range(num_k_step):
                T.copy(A[bx * blk_m, k * blk_k], A_shared)
                for i, j in T.Parallel(blk_m, blk_k):
                    A_local[i, j] += A_shared[i, j] * A_shared[i, j]
            T.reduce_sum(A_local, A_powsum, dim=1)
            for i in T.Parallel(blk_m):
                A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12

            for k in range(num_k_step):
                # reverse, better cache hit rate
                T.copy(A[bx * blk_m, (num_k_step - 1 - k) * blk_k], A_shared)
                for i, j in T.Parallel(blk_m, blk_k):
                    A_shared[i, j] *= A_powsum[i]
                T.copy(A_shared, B[bx * blk_m, (num_k_step - 1 - k) * blk_k])

    return main


def rms_norm(M, N, blk_m, dtype, variance_epsilon=1e-12):
    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
            A_shared = T.alloc_shared((blk_m, N), dtype)
            A_pow_local = T.alloc_fragment((blk_m, N), dtype)
            A_local = T.alloc_fragment((blk_m, N), dtype)
            A_powsum = T.alloc_fragment((blk_m,), dtype)

            T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared)
            T.copy(A_shared, A_local)
            for i, j in T.Parallel(blk_m, N):
                A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
            T.reduce_sum(A_pow_local, A_powsum, dim=1)
            for i in T.Parallel(blk_m):
                A_powsum[i] = T.rsqrt(A_powsum[i] / N) + variance_epsilon
            for i, j in T.Parallel(blk_m, N):
                A_local[i, j] *= A_powsum[i]
            T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :])

    return main


TILELANG_DTYPE_MAP = {
    torch.bfloat16: "bfloat16",
    torch.float16: "float16",
    torch.float32: "float",
}


class TileLangRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        TileLangRMSNorm
        """
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        M, N = hidden_states.size()
        dtype = TILELANG_DTYPE_MAP[hidden_states.dtype]
        blk_m = 1
        blk_k = 512

        kernel = rms_norm(M, N, blk_m, dtype, self.variance_epsilon)
        jit_kernel = tilelang.compile(
            kernel,
            out_idx=[-1],
            target="cuda",
            pass_configs={
                tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            },
        )
        return lambda: jit_kernel(hidden_states)
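
A hedged verification sketch (not part of the commit) for `TileLangRMSNorm`, assuming a CUDA machine with tilelang installed and the module path `tritonbench.operators.rms_norm.tilelang`. The eager reference mirrors the kernel's arithmetic, `x * (rsqrt(mean(x^2)) + eps)`; note the module's `weight` parameter is not applied inside the kernel, and the tolerances are loose because the kernel computes in the input dtype.

import torch

from tritonbench.operators.rms_norm.tilelang import TileLangRMSNorm  # path assumed

M, H = 4096, 4096
x = torch.randn(M, H, device="cuda", dtype=torch.float16)

module = TileLangRMSNorm(hidden_size=H, eps=1e-6).to("cuda")
out = module(x)()  # forward() returns a zero-arg lambda that launches the kernel

# Eager reference matching the kernel: normalize by rsqrt(mean(x^2)), with eps
# added outside the rsqrt and the (unused) weight skipped.
ref = x * (torch.rsqrt(x.pow(2).mean(-1, keepdim=True)) + 1e-6)
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)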

tritonbench/utils/env_utils.py

Lines changed: 4 additions & 0 deletions
@@ -75,6 +75,10 @@ def supports_tma():
     return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
 
 
+def is_cu130():
+    return is_cuda() and torch.version.cuda == "13.0"
+
+
 def set_env():
     # set cutlass dir
     # by default we use the cutlass version built with pytorch
