# Original source:
# https://github.com/tile-ai/tilelang/blob/main/examples/norm/test_rms_norm.py
import tilelang
import tilelang.language as T
import torch

tilelang.disable_cache()


def rms_norm_splitk(M, N, blk_m, blk_k):
    """Split-K RMS norm: each block normalizes blk_m rows, streaming the N
    columns through shared memory in blk_k-wide tiles."""
    dtype = "float"

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
            A_shared = T.alloc_shared((blk_m, blk_k), dtype)
            A_local = T.alloc_fragment((blk_m, blk_k), dtype)
            A_powsum = T.alloc_fragment((blk_m,), dtype)

            num_k_step = T.ceildiv(N, blk_k)
            T.clear(A_local)
            # Accumulate per-element squares tile by tile, then reduce over K.
            for k in range(num_k_step):
                T.copy(A[bx * blk_m, k * blk_k], A_shared)
                for i, j in T.Parallel(blk_m, blk_k):
                    A_local[i, j] += A_shared[i, j] * A_shared[i, j]
            T.reduce_sum(A_local, A_powsum, dim=1)
            for i in T.Parallel(blk_m):
                A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12

            # Walk the K tiles in reverse so the tile loaded last in the first
            # pass is reused first here, improving the cache hit rate.
            for k in range(num_k_step):
                T.copy(A[bx * blk_m, (num_k_step - 1 - k) * blk_k], A_shared)
                for i, j in T.Parallel(blk_m, blk_k):
                    A_shared[i, j] *= A_powsum[i]
                T.copy(A_shared, B[bx * blk_m, (num_k_step - 1 - k) * blk_k])

    return main


def rms_norm(M, N, blk_m, dtype, variance_epsilon=1e-12):
    """Non-split RMS norm: each block loads blk_m full rows into shared memory,
    computes the row-wise RMS, and writes the normalized rows back."""

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
            A_shared = T.alloc_shared((blk_m, N), dtype)
            A_pow_local = T.alloc_fragment((blk_m, N), dtype)
            A_local = T.alloc_fragment((blk_m, N), dtype)
            A_powsum = T.alloc_fragment((blk_m,), dtype)

            T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared)
            T.copy(A_shared, A_local)
            for i, j in T.Parallel(blk_m, N):
                A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
            T.reduce_sum(A_pow_local, A_powsum, dim=1)
            for i in T.Parallel(blk_m):
                A_powsum[i] = T.rsqrt(A_powsum[i] / N) + variance_epsilon
            for i, j in T.Parallel(blk_m, N):
                A_local[i, j] *= A_powsum[i]
            T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :])

    return main


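# For reference, a minimal plain-PyTorch sketch of what both kernels above
# compute (note that eps is added after the rsqrt, matching the kernels).
# rms_norm_ref is not part of the upstream example; it only serves as a
# correctness baseline.
def rms_norm_ref(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # y[i, j] = x[i, j] * (rsqrt(mean_j(x[i, j]^2)) + eps)
    return x * (torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True)) + eps)

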
TILELANG_DTYPE_MAP = {
    torch.bfloat16: "bfloat16",
    torch.float16: "float16",
    torch.float32: "float",
}


class TileLangRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """RMS normalization module backed by the TileLang `rms_norm` kernel.

        The weight parameter is kept for API parity with typical RMSNorm
        modules but is initialized to ones and is not applied by the kernel.
        """
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        M, N = hidden_states.size()
        dtype = TILELANG_DTYPE_MAP[hidden_states.dtype]
        blk_m = 1  # one row per thread block

        # Build and JIT-compile the kernel for this shape/dtype on each call;
        # out_idx=[-1] makes the compiled kernel allocate and return B.
        kernel = rms_norm(M, N, blk_m, dtype, self.variance_epsilon)
        jit_kernel = tilelang.compile(
            kernel,
            out_idx=[-1],
            target="cuda",
            pass_configs={
                tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            },
        )
        # Return a zero-argument callable (benchmark-style) rather than the
        # normalized tensor itself.
        return lambda: jit_kernel(hidden_states)


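# A minimal usage sketch, not part of the upstream example: it assumes a CUDA
# device is available and that forward() is meant to be driven benchmark-style,
# i.e. it returns a zero-argument callable that launches the compiled kernel.
# The result is checked against the plain-PyTorch baseline rms_norm_ref above.
if __name__ == "__main__":
    x = torch.randn(1024, 4096, dtype=torch.float32, device="cuda")
    norm = TileLangRMSNorm(hidden_size=4096).cuda()
    run = norm(x)  # compiles the kernel and returns a callable
    out = run()    # launches the TileLang kernel, returns the normalized tensor
    ref = rms_norm_ref(x, eps=norm.variance_epsilon)
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)
    print("TileLang RMS norm matches the PyTorch reference.")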