Commit 2f471c0

Rmsnorm zero-centered gamma (#703)
* add zero_centered_gamma as tl.constexpr for TE and test it
* remove comments
* add the missing torch parameter
1 parent d1cd40b commit 2f471c0
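
Background: "zero-centered gamma" is the convention (the "TE" in the commit message presumably refers to Transformer Engine) in which the learnable RMSNorm weight is stored as gamma - 1, so a zero-initialized weight corresponds to gamma = 1 and the kernel adds 1 back before scaling. A rough PyTorch sketch of the behaviour this change adds to the reference path; the name rmsnorm_reference is illustrative and not part of the file:

import torch

def rmsnorm_reference(x, g, zero_centered_gamma):
    # Mirrors torch_rmsnorm from the diff below: rsigma = 1 / sqrt(mean(x^2)).
    M, N = x.shape
    rsigma = 1.0 / torch.sqrt(torch.sum(x * x, dim=-1) / N)
    if zero_centered_gamma:
        g = g + 1  # stored weight is gamma - 1; recover the effective gamma
    return (x * rsigma.unsqueeze(1) * g).to(x.dtype), rsigma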

File tree

1 file changed: +33 -15 lines changed


python/perf-kernels/rmsnorm.py

Lines changed: 33 additions & 15 deletions
@@ -46,12 +46,14 @@ def get_autotune_config():
 @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
 @triton.jit
 def rms_kernel(output_ptr, input_ptr, g_ptr, rsigma_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon,
-               BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr):
+               ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr,
+               NUM_PRGMS: tl.constexpr):
     row_start = tl.program_id(0)
     col_offsets = tl.arange(0, BLOCK_SIZE)
-    tl.assume(input_row_stride >= 0)
-    tl.assume(output_row_stride >= 0)
-    tl.assume(row_start >= 0)
+    # Older Triton versions don't support tl.assume and buffer ops, so these are commented out for now.
+    # tl.assume(input_row_stride >= 0)
+    # tl.assume(output_row_stride >= 0)
+    # tl.assume(row_start >= 0)

     if USE_BLOCKED:

@@ -93,6 +95,8 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, rsigma_ptr, input_row_stride, outpu
                 x = tl.load(input_ptrs).to(tl.float32)
                 g_ptrs = g_ptr + cols
                 g = tl.load(g_ptrs).to(tl.float32)
+                if (ZERO_CENTERED_GAMMA):
+                    g += 1
                 rms_norm = x * norm_factor * g
                 output_ptrs = row_output_ptr + cols
                 tl.store(output_ptrs, rms_norm.to(output_ptr.type.element_ty))
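
Because ZERO_CENTERED_GAMMA is a tl.constexpr, Triton compiles a separate specialization per value, so the "if (ZERO_CENTERED_GAMMA): g += 1" branches cost nothing in the False case. A minimal, self-contained sketch of that pattern (the hypothetical scale_kernel below is not part of this file):

import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, g_ptr, out_ptr, n, ADD_ONE: tl.constexpr, BLOCK: tl.constexpr):
    # ADD_ONE is a compile-time constant; the branch below is resolved during
    # compilation, so the False specialization carries no extra work.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    g = tl.load(g_ptr + offs, mask=mask)
    if ADD_ONE:
        g += 1
    tl.store(out_ptr + offs, x * g, mask=mask)

# Example launch (assumes 1-D CUDA/ROCm tensors x, g, out of length n):
#   grid = (triton.cdiv(n, 1024), )
#   scale_kernel[grid](x, g, out, n, ADD_ONE=True, BLOCK=1024)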
@@ -104,6 +108,8 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, rsigma_ptr, input_row_stride, outpu
             x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32)
             g_ptrs = g_ptr + cols
             g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32)
+            if (ZERO_CENTERED_GAMMA):
+                g += 1
             rms_norm = x * norm_factor * g
             output_ptrs = row_output_ptr + cols
             tl.store(output_ptrs, rms_norm.to(output_ptr.type.element_ty), mask=mask)
@@ -123,29 +129,36 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, rsigma_ptr, input_row_stride, outpu
             rsigma_output_ptr = rsigma_ptr + row_idx
             tl.store(rsigma_output_ptr, norm_factor)

+            if (ZERO_CENTERED_GAMMA):
+                g += 1
             rms_norm = row * norm_factor * g

             output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
             output_ptrs = tl.multiple_of(output_ptrs, (16, ))
             tl.store(output_ptrs, rms_norm.to(output_ptr.type.element_ty), mask=mask)


-def triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS, epsilon=1e-6):
+def triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED, NUM_PRGMS,
+                   epsilon=1e-6):
     grid = lambda meta: (NUM_PRGMS, )
-    rms_kernel[grid](y, x, g, rsigma, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, blk_size, USE_BLOCKED,
-                     NUM_PRGMS)
+    rms_kernel[grid](y, x, g, rsigma, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, ZERO_CENTERED_GAMMA, blk_size,
+                     USE_BLOCKED, NUM_PRGMS)

     return y, rsigma


-def torch_rmsnorm(x, g, epsilon=1e-6):
+def torch_rmsnorm(x, g, ZERO_CENTERED_GAMMA, epsilon=1e-6):
     M, N = x.shape
     rms = torch.sqrt(torch.sum(x * x, dim=-1) * 1 / N)
     rsigma = 1.0 / rms
+    if (ZERO_CENTERED_GAMMA):
+        g += 1
     rms_norm = x * rsigma.unsqueeze(1) * g
+    rms_norm = rms_norm.to(x.dtype)
     return rms_norm, rsigma


+@pytest.mark.parametrize('ZERO_CENTERED_GAMMA', [True, False])
 @pytest.mark.parametrize('M, N', [
     (1, 4),
     (2, 10),
@@ -156,20 +169,23 @@ def torch_rmsnorm(x, g, epsilon=1e-6):
     (3, 65536),
     (873, 1245),
 ])
-def test_rmsnorm(M, N):
+def test_rmsnorm(M, N, ZERO_CENTERED_GAMMA):
     torch.manual_seed(0)
     x = torch.randn(M, N, device='cuda')
     y = torch.zeros_like(x, device='cuda')
     rsigma = torch.empty((M, ), device='cuda', dtype=torch.float32)
+
     n_rows, n_cols = x.shape
     MAX_FUSED_SIZE = 65536 // x.element_size()
     blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
     USE_BLOCKED = n_cols > blk_size
     NUM_PRGMS = min(n_rows, get_num_sms())
     g = torch.ones((1, N), device='cuda')
-    y_triton, rsigma_triton = triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS)

-    y_torch, rsigma_torch = torch_rmsnorm(x, g)
+    y_triton, rsigma_triton = triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size,
+                                             USE_BLOCKED, NUM_PRGMS)
+
+    y_torch, rsigma_torch = torch_rmsnorm(x, g, ZERO_CENTERED_GAMMA)

     assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
     assert torch.allclose(rsigma_triton, rsigma_torch), (rsigma_triton, rsigma_torch)
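
For reference, a host-side usage sketch of the updated wrapper signature, following the same setup as the test above. It assumes triton_rmsnorm and get_num_sms are importable from python/perf-kernels/rmsnorm.py and that a CUDA/ROCm device is available:

import torch
import triton

M, N = 8, 1024
x = torch.randn(M, N, device='cuda')
y = torch.zeros_like(x)
rsigma = torch.empty((M, ), device='cuda', dtype=torch.float32)
g = torch.zeros((1, N), device='cuda')  # zero-centered storage of gamma == 1
blk_size = min(65536 // x.element_size(), triton.next_power_of_2(N))
y, rsigma = triton_rmsnorm(x, y, g, rsigma, M, N, True,  # ZERO_CENTERED_GAMMA=True
                           blk_size, N > blk_size, min(M, get_num_sms()))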
@@ -249,11 +265,12 @@ def benchmark(M, N, provider, model=None):
     stream = torch.cuda.Stream()
     torch.cuda.set_stream(stream)
     g = torch.ones((1, N), device='cuda')
+    ZERO_CENTERED_GAMMA = False
     if provider == 'torch':
-        ms = triton.testing.do_bench(lambda: torch_rmsnorm(x, g))
+        ms = triton.testing.do_bench(lambda: torch_rmsnorm(x, g, ZERO_CENTERED_GAMMA))
     if provider == 'triton':
-        ms = triton.testing.do_bench(
-            lambda: triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS))
+        ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, ZERO_CENTERED_GAMMA,
+                                                            blk_size, USE_BLOCKED, NUM_PRGMS))
     global verbose
     if verbose:
         print(f'SIZE: {N} Best tuning config: ({rms_kernel.best_config})')
@@ -309,7 +326,8 @@ def main():
         USE_BLOCKED = n_cols > blk_size
         NUM_PRGMS = min(n_rows, get_num_sms())
         g = torch.ones((1, args.N_start), device='cuda')
-        triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS)
+        ZERO_CENTERED_GAMMA = True
+        triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED, NUM_PRGMS)
     else:
         verbose = args.v
         run_benchmark(args)
