Skip to content

Commit 9679203

Browse files
add rsigma to the output (#700)
* add rsigma to the output * make sure rsigma is fp32 * add rsigma to pytest
1 parent 1006241 commit 9679203

File tree

1 file changed

+29
-19
lines changed

1 file changed

+29
-19
lines changed

python/perf-kernels/rmsnorm.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def get_autotune_config():
4545

4646
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
4747
@triton.jit
48-
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon,
48+
def rms_kernel(output_ptr, input_ptr, g_ptr, rsigma_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon,
4949
BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr):
5050
row_start = tl.program_id(0)
5151
col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -61,8 +61,8 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
6161
row_output_ptr = output_ptr + row_idx * output_row_stride
6262

6363
# Accumulate sum of squares
64-
sum_squares = tl.zeros([1], dtype=tl.float32)
6564
n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1
65+
sum_squares: tl.float32 = 0.
6666
for blk_idx in tl.range(0, n_cols_blks, num_stages=2):
6767
cols = blk_idx * BLOCK_SIZE + col_offsets
6868
input_ptrs = row_input_ptr + cols
@@ -82,6 +82,9 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
8282
mean_square = sum_squares / n_cols
8383
norm_factor = tl.rsqrt(mean_square + epsilon)
8484

85+
# Store rsigma (norm_factor)
86+
tl.store(rsigma_ptr + row_idx, norm_factor)
87+
8588
# Normalize and write output
8689
for blk_idx in tl.range(0, n_cols_blks, num_stages=2):
8790
cols = blk_idx * BLOCK_SIZE + col_offsets
@@ -114,30 +117,33 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
114117
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
115118
row_norm = row * row
116119
row_norm = tl.sum(row_norm, axis=-1)
117-
row_norm = tl.math.rsqrt((row_norm / n_cols) + epsilon)
118-
rms_norm = row * row_norm * g
120+
norm_factor = tl.math.rsqrt((row_norm / n_cols) + epsilon)
121+
122+
# Store rsigma (norm_factor)
123+
rsigma_output_ptr = rsigma_ptr + row_idx
124+
tl.store(rsigma_output_ptr, norm_factor)
125+
126+
rms_norm = row * norm_factor * g
119127

120128
output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
121129
output_ptrs = tl.multiple_of(output_ptrs, (16, ))
122130
tl.store(output_ptrs, rms_norm.to(output_ptr.type.element_ty), mask=mask)
123131

124132

125-
def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS, epsilon=1e-6):
133+
def triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS, epsilon=1e-6):
126134
grid = lambda meta: (NUM_PRGMS, )
127-
rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, blk_size, USE_BLOCKED, NUM_PRGMS)
135+
rms_kernel[grid](y, x, g, rsigma, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, blk_size, USE_BLOCKED,
136+
NUM_PRGMS)
128137

129-
return y
138+
return y, rsigma
130139

131140

132-
def torch_rmsnorm(x, g):
141+
def torch_rmsnorm(x, g, epsilon=1e-6):
133142
M, N = x.shape
134-
if hasattr(torch.nn, 'RMSNorm'):
135-
rms_norm = torch.nn.RMSNorm(N, device='cuda')
136-
return rms_norm(x)
137-
else:
138-
rms = torch.sqrt(torch.sum(x * x, dim=-1) * 1 / N)
139-
rms_norm = torch.div(x, rms.unsqueeze(1).repeat(1, N)) * g
140-
return rms_norm
143+
rms = torch.sqrt(torch.sum(x * x, dim=-1) * 1 / N)
144+
rsigma = 1.0 / rms
145+
rms_norm = x * rsigma.unsqueeze(1) * g
146+
return rms_norm, rsigma
141147

142148

143149
@pytest.mark.parametrize('M, N', [
@@ -154,17 +160,19 @@ def test_rmsnorm(M, N):
154160
torch.manual_seed(0)
155161
x = torch.randn(M, N, device='cuda')
156162
y = torch.zeros_like(x, device='cuda')
163+
rsigma = torch.empty((M, ), device='cuda', dtype=torch.float32)
157164
n_rows, n_cols = x.shape
158165
MAX_FUSED_SIZE = 65536 // x.element_size()
159166
blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
160167
USE_BLOCKED = n_cols > blk_size
161168
NUM_PRGMS = min(n_rows, get_num_sms())
162169
g = torch.ones((1, N), device='cuda')
163-
y_triton = triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS)
170+
y_triton, rsigma_triton = triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS)
164171

165-
y_torch = torch_rmsnorm(x, g)
172+
y_torch, rsigma_torch = torch_rmsnorm(x, g)
166173

167174
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
175+
assert torch.allclose(rsigma_triton, rsigma_torch), (rsigma_triton, rsigma_torch)
168176

169177

170178
#Benchmark
@@ -232,6 +240,7 @@ def run_benchmark(args):
232240
def benchmark(M, N, provider, model=None):
233241
x = torch.randn(M, N, device='cuda', dtype=dtype)
234242
y = torch.zeros_like(x, device='cuda')
243+
rsigma = torch.empty((M, ), device='cuda', dtype=torch.float32)
235244
n_rows, n_cols = x.shape
236245
MAX_FUSED_SIZE = 65536 // x.element_size()
237246
blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
@@ -244,7 +253,7 @@ def benchmark(M, N, provider, model=None):
244253
ms = triton.testing.do_bench(lambda: torch_rmsnorm(x, g))
245254
if provider == 'triton':
246255
ms = triton.testing.do_bench(
247-
lambda: triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS))
256+
lambda: triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS))
248257
global verbose
249258
if verbose:
250259
print(f'SIZE: {N} Best tuning config: ({rms_kernel.best_config})')
@@ -293,13 +302,14 @@ def main():
293302
if args.no_benchmark:
294303
x = torch.randn(args.M_start, args.N_start, device='cuda')
295304
y = torch.zeros_like(x, device='cuda')
305+
rsigma = torch.empty((args.M_start, ), device='cuda', dtype=torch.float32)
296306
n_rows, n_cols = x.shape
297307
MAX_FUSED_SIZE = 65536 // x.element_size()
298308
blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
299309
USE_BLOCKED = n_cols > blk_size
300310
NUM_PRGMS = min(n_rows, get_num_sms())
301311
g = torch.ones((1, args.N_start), device='cuda')
302-
triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS)
312+
triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS)
303313
else:
304314
verbose = args.v
305315
run_benchmark(args)

0 commit comments

Comments
 (0)