|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | + |
| 5 | + |
| 6 | +@triton.jit |
| 7 | +def _fwd_fused_add_rmsnorm( |
| 8 | + original, |
| 9 | + residual, |
| 10 | + weight, |
| 11 | + original_stride0, |
| 12 | + original_stride1, |
| 13 | + residual_stride0, |
| 14 | + residual_stride1, |
| 15 | + N, # number of columns in X |
| 16 | + eps, |
| 17 | + BLOCK_SIZE: tl.constexpr, |
| 18 | +): |
| 19 | + block_id = tl.program_id(0) |
| 20 | + # data's base address of this block |
| 21 | + _original = original + block_id * original_stride0 |
| 22 | + _residual = residual + block_id * residual_stride0 |
| 23 | + |
| 24 | + # avoid repeat loading from gmem to smem |
| 25 | + # in some very large size, have better performance |
| 26 | + if N <= BLOCK_SIZE: |
| 27 | + # data's offset address of this block |
| 28 | + range = tl.arange(0, BLOCK_SIZE) |
| 29 | + _original_offset = range * original_stride1 |
| 30 | + _residual_offset = range * residual_stride1 |
| 31 | + _weight_offset = range |
| 32 | + |
| 33 | + # data's pointers of this block |
| 34 | + _original_ptr = _original + _original_offset |
| 35 | + _residual_ptr = _residual + _residual_offset |
| 36 | + _weight_ptr = weight + _weight_offset |
| 37 | + |
| 38 | + # load data from memory |
| 39 | + mask = range < N |
| 40 | + original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32) |
| 41 | + residual_cache = tl.load(_residual_ptr, mask=mask, other=0.0).to(tl.float32) |
| 42 | + weight_cache = tl.load(_weight_ptr, mask=mask, other=0.0).to(tl.float32) |
| 43 | + |
| 44 | + # store (original + residual) to original |
| 45 | + original_cache = original_cache + residual_cache |
| 46 | + tl.store(_original_ptr, original_cache.to(original.dtype.element_ty), mask=mask) |
| 47 | + |
| 48 | + # compute variance |
| 49 | + var = tl.sum(original_cache * original_cache) / N |
| 50 | + rstd = 1 / tl.sqrt(var + eps) |
| 51 | + residual_cache = original_cache * rstd * weight_cache |
| 52 | + |
| 53 | + # store rmsnorm(original + residual) back to residual |
| 54 | + tl.store(_residual_ptr, residual_cache.to(residual.dtype.element_ty), mask=mask) |
| 55 | + else: |
| 56 | + sum_of_squares = tl.zeros([], dtype=tl.float32) |
| 57 | + for block_offset in range(0, N, BLOCK_SIZE): |
| 58 | + # data's offset address of this block |
| 59 | + range = tl.arange(0, BLOCK_SIZE) + block_offset |
| 60 | + _original_offset = range * original_stride1 |
| 61 | + _residual_offset = range * residual_stride1 |
| 62 | + |
| 63 | + # data's pointers of this block |
| 64 | + _original_ptr = _original + _original_offset |
| 65 | + _residual_ptr = _residual + _residual_offset |
| 66 | + |
| 67 | + # load data from memory |
| 68 | + mask = range < N |
| 69 | + original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32) |
| 70 | + residual_cache = tl.load(_residual_ptr, mask=mask, other=0.0).to(tl.float32) |
| 71 | + |
| 72 | + # store (original + residual) to original |
| 73 | + original_cache = original_cache + residual_cache |
| 74 | + tl.store(_original_ptr, original_cache.to(original.dtype.element_ty), mask=mask) |
| 75 | + |
| 76 | + # compute sum_of_squares |
| 77 | + sum_of_squares += tl.sum(original_cache * original_cache) |
| 78 | + |
| 79 | + # compute variance |
| 80 | + var = sum_of_squares / N |
| 81 | + rstd = 1 / tl.sqrt(var + eps) |
| 82 | + |
| 83 | + for block_offset in range(0, N, BLOCK_SIZE): |
| 84 | + # data's offset address of this block |
| 85 | + range = tl.arange(0, BLOCK_SIZE) + block_offset |
| 86 | + _original_offset = range * original_stride1 |
| 87 | + _residual_offset = range * residual_stride1 |
| 88 | + _weight_offset = range |
| 89 | + |
| 90 | + # data's pointers of this block |
| 91 | + _original_ptr = _original + _original_offset |
| 92 | + _residual_ptr = _residual + _residual_offset |
| 93 | + _weight_ptr = weight + _weight_offset |
| 94 | + |
| 95 | + # load data from memory |
| 96 | + mask = range < N |
| 97 | + original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32) |
| 98 | + weight_cache = tl.load(_weight_ptr, mask=mask, other=0.0).to(tl.float32) |
| 99 | + |
| 100 | + # apply rmsnorm using pre-computed rstd |
| 101 | + original_cache = original_cache * rstd * weight_cache |
| 102 | + |
| 103 | + # store rmsnorm(original) back to residual |
| 104 | + tl.store(_residual_ptr, original_cache.to(residual.dtype.element_ty), mask=mask) |
| 105 | + |
| 106 | + |
| 107 | +def fused_add_rmsnorm_inplace( |
| 108 | + original: torch.Tensor, # [num_tokens, hidden_size] |
| 109 | + residual: torch.Tensor, |
| 110 | + weight: torch.Tensor, |
| 111 | + eps: float, |
| 112 | +): |
| 113 | + """ |
| 114 | + Perform fused add & rmsnorm |
| 115 | +
|
| 116 | + suppose the skip connection result is H(x) = F(x) + x, |
| 117 | + then F(x) is the residual, x is the original. |
| 118 | + Here original will be (residual + original), residual will be rmsnorm(residual + original) |
| 119 | + At first Layer, residual should be all zeros. |
| 120 | + """ |
| 121 | + # reshape input data into 2D tensor |
| 122 | + original_arg = original.view(-1, original.shape[-1]) |
| 123 | + residual_arg = residual.view(-1, residual.shape[-1]) |
| 124 | + |
| 125 | + assert original.data_ptr() == original_arg.data_ptr() |
| 126 | + assert residual.data_ptr() == residual_arg.data_ptr() |
| 127 | + |
| 128 | + M, N = original_arg.shape |
| 129 | + # Less than 64KB per feature: enqueue fused kernel |
| 130 | + MAX_FUSED_SIZE = 65536 // original.element_size() |
| 131 | + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) |
| 132 | + |
| 133 | + if N > BLOCK_SIZE: |
| 134 | + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") |
| 135 | + |
| 136 | + # heuristics for number of warps |
| 137 | + num_warps = min(max(BLOCK_SIZE // 256, 1), 4) |
| 138 | + num_warps = triton.next_power_of_2(num_warps) |
| 139 | + if BLOCK_SIZE > 16384: |
| 140 | + BLOCK_SIZE = 16384 |
| 141 | + |
| 142 | + # enqueue kernel |
| 143 | + _fwd_fused_add_rmsnorm[(M,)]( |
| 144 | + original_arg, |
| 145 | + residual_arg, |
| 146 | + weight, |
| 147 | + original_arg.stride(0), |
| 148 | + original_arg.stride(1), |
| 149 | + residual_arg.stride(0), |
| 150 | + residual_arg.stride(1), |
| 151 | + N, # number of columns in X |
| 152 | + eps, |
| 153 | + BLOCK_SIZE=BLOCK_SIZE, |
| 154 | + num_warps=num_warps, |
| 155 | + ) |
0 commit comments