Skip to content

Commit 9cdcf1d

Browse files
implement persistent loop based rmsnorm kernel (#676)
* implement persistent loop based rmsnorm kernel * remove comment in autotune configs and always assume Triton kernel use GPU * remove comments * slightly improved perf * add get_num_sms functions, and change triton_rmsnorn call interface * num_stages =1 is the best * update format
1 parent 40a9963 commit 9cdcf1d

File tree

1 file changed

+93
-90
lines changed

1 file changed

+93
-90
lines changed

python/perf-kernels/rmsnorm.py

Lines changed: 93 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ def is_hip():
1616
return triton.runtime.driver.active.get_current_target().backend == "hip"
1717

1818

19+
def get_num_sms():
20+
current_device_index = torch.cuda.current_device()
21+
current_device = torch.cuda.get_device_properties(current_device_index)
22+
num_sms = current_device.multi_processor_count
23+
return num_sms
24+
25+
1926
def get_cuda_autotune_config():
2027
return [
2128
triton.Config({}, num_warps=4, num_stages=1),
@@ -25,9 +32,7 @@ def get_cuda_autotune_config():
2532

2633

2734
def get_hip_autotune_config():
28-
return [
29-
triton.Config({'waves_per_eu': we}, num_warps=nw, num_stages=2) for (we, nw) in product([0, 1, 2, 4], [8, 16])
30-
]
35+
return [triton.Config({'waves_per_eu': we}, num_warps=nw) for (we, nw) in product([0, 1, 2, 4], [4, 8, 16])]
3136

3237

3338
def get_autotune_config():
@@ -37,102 +42,92 @@ def get_autotune_config():
3742
return get_hip_autotune_config()
3843

3944

40-
# accumulate sum of squares for a row in a blocked manner
41-
@triton.jit
42-
def accumulate_sum_squares(input_ptr, input_row_stride, n_cols, BLOCK_SIZE, row_idx):
43-
col_offsets = tl.arange(0, BLOCK_SIZE)
44-
sum_squares = tl.zeros([1], dtype=tl.float32)
45-
row_input_ptr = input_ptr + row_idx * input_row_stride
46-
47-
n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1
48-
for start in range(0, n_cols_blks * BLOCK_SIZE, BLOCK_SIZE):
49-
cols = start + col_offsets
50-
input_ptrs = row_input_ptr + cols
51-
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
52-
x = tl.load(input_ptrs)
53-
sum_squares += tl.sum(x * x, axis=0)
54-
55-
# loop peeling for mask
56-
cols = n_cols_blks * BLOCK_SIZE + col_offsets
57-
mask = cols < n_cols
58-
input_ptrs = row_input_ptr + cols
59-
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
60-
x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
61-
sum_squares += tl.sum(x * x, axis=0)
62-
63-
return sum_squares
64-
65-
66-
# apply normalization to each block of the row
67-
@triton.jit
68-
def apply_normalization(input_ptr, output_ptr, g_ptr, input_row_stride, output_row_stride, n_cols, norm_factor,
69-
BLOCK_SIZE, row_idx):
70-
col_offsets = tl.arange(0, BLOCK_SIZE)
71-
row_input_ptr = input_ptr + row_idx * input_row_stride
72-
row_output_ptr = output_ptr + row_idx * output_row_stride
73-
74-
for start in range(0, n_cols, BLOCK_SIZE):
75-
cols = start + col_offsets
76-
mask = cols < n_cols
77-
input_ptrs = row_input_ptr + cols
78-
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
79-
g_ptrs = g_ptr + cols
80-
output_ptrs = row_output_ptr + cols
81-
x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
82-
g = tl.load(g_ptrs, mask=mask, other=0.0)
83-
rms_norm = x * norm_factor * g
84-
tl.store(output_ptrs, rms_norm, mask=mask)
85-
86-
87-
# Main kernel with both blocked and non-blocked versions based on BLOCK_SIZE
8845
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
8946
@triton.jit
9047
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon,
9148
BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr):
92-
row_idx = tl.program_id(0) # Each program instance handles one row
49+
row_start = tl.program_id(0)
9350
col_offsets = tl.arange(0, BLOCK_SIZE)
51+
tl.assume(input_row_stride >= 0)
52+
tl.assume(output_row_stride >= 0)
9453

9554
if USE_BLOCKED:
96-
# Blocked Approach: Accumulate sum of squares and normalize in chunks
97-
sum_squares = accumulate_sum_squares(input_ptr, input_row_stride, n_cols, BLOCK_SIZE, row_idx)
98-
mean_square = sum_squares / n_cols
99-
norm_factor = tl.rsqrt(mean_square + epsilon)
10055

101-
# Apply normalization
102-
apply_normalization(input_ptr, output_ptr, g_ptr, input_row_stride, output_row_stride, n_cols, norm_factor,
103-
BLOCK_SIZE, row_idx)
56+
# Persistent loop for rows
57+
for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=1):
58+
row_input_ptr = input_ptr + row_idx * input_row_stride
59+
row_output_ptr = output_ptr + row_idx * output_row_stride
60+
61+
# Accumulate sum of squares
62+
sum_squares = tl.zeros([1], dtype=tl.float32)
63+
n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1
64+
for blk_idx in range(n_cols_blks, num_stages=1):
65+
cols = blk_idx * BLOCK_SIZE + col_offsets
66+
input_ptrs = row_input_ptr + cols
67+
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
68+
x = tl.load(input_ptrs)
69+
sum_squares += tl.sum(x * x, axis=0)
70+
71+
# Handle remainder
72+
cols = n_cols_blks * BLOCK_SIZE + col_offsets
73+
mask = cols < n_cols
74+
input_ptrs = row_input_ptr + cols
75+
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
76+
x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
77+
sum_squares += tl.sum(x * x, axis=0)
78+
79+
# Compute normalization factor
80+
mean_square = sum_squares / n_cols
81+
norm_factor = tl.rsqrt(mean_square + epsilon)
82+
83+
# Normalize and write output
84+
for blk_idx in range(n_cols_blks, num_stages=1):
85+
cols = blk_idx * BLOCK_SIZE + col_offsets
86+
input_ptrs = row_input_ptr + cols
87+
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
88+
x = tl.load(input_ptrs)
89+
g_ptrs = g_ptr + cols
90+
g = tl.load(g_ptrs)
91+
rms_norm = x * norm_factor * g
92+
output_ptrs = row_output_ptr + cols
93+
tl.store(output_ptrs, rms_norm)
94+
95+
# Handle remainder
96+
cols = n_cols_blks * BLOCK_SIZE + col_offsets
97+
mask = cols < n_cols
98+
input_ptrs = row_input_ptr + cols
99+
x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
100+
g_ptrs = g_ptr + cols
101+
g = tl.load(g_ptrs, mask=mask, other=0.0)
102+
rms_norm = x * norm_factor * g
103+
output_ptrs = row_output_ptr + cols
104+
tl.store(output_ptrs, rms_norm, mask=mask)
104105

105106
else:
106107
mask = col_offsets < n_cols
107-
tl.assume(input_row_stride >= 0)
108-
tl.assume(output_row_stride >= 0)
109-
row_start_ptr = input_ptr + row_idx * input_row_stride
110-
input_ptrs = row_start_ptr + col_offsets
111-
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
112-
row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
113-
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
114-
row_norm = row * row
115-
row_norm = tl.sum(row_norm, axis=-1)
116-
row_norm = row_norm / n_cols
117-
row_norm = row_norm + epsilon
118-
row_norm = tl.rsqrt(row_norm)
119-
rms_norm = row * row_norm
120-
rms_norm = rms_norm * g
121-
122-
output_row_start_ptr = output_ptr + row_idx * output_row_stride
123-
output_ptrs = output_row_start_ptr + col_offsets
124-
output_ptrs = tl.multiple_of(output_ptrs, (16, ))
125-
tl.store(output_ptrs, rms_norm, mask=mask)
126-
127-
128-
def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, epsilon=1e-6):
129-
BLOCK_SIZE = blk_size
130-
# Use blocked approach if BLOCK_SIZE larger than 65536 // x.element_size()
131-
USE_BLOCKED = n_cols > BLOCK_SIZE
132-
133-
NUM_PRGMS = n_rows
108+
for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=1):
109+
row_start_ptr = input_ptr + row_idx * input_row_stride
110+
input_ptrs = row_start_ptr + col_offsets
111+
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
112+
row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
113+
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
114+
row_norm = row * row
115+
row_norm = tl.sum(row_norm, axis=-1)
116+
row_norm = row_norm / n_cols
117+
row_norm = row_norm + epsilon
118+
row_norm = tl.rsqrt(row_norm)
119+
rms_norm = row * row_norm
120+
rms_norm = rms_norm * g
121+
122+
output_row_start_ptr = output_ptr + row_idx * output_row_stride
123+
output_ptrs = output_row_start_ptr + col_offsets
124+
output_ptrs = tl.multiple_of(output_ptrs, (16, ))
125+
tl.store(output_ptrs, rms_norm, mask=mask)
126+
127+
128+
def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS, epsilon=1e-6):
134129
grid = lambda meta: (NUM_PRGMS, )
135-
rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, USE_BLOCKED, NUM_PRGMS)
130+
rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, blk_size, USE_BLOCKED, NUM_PRGMS)
136131

137132
return y
138133

@@ -165,8 +160,10 @@ def test_rmsnorm(M, N):
165160
n_rows, n_cols = x.shape
166161
MAX_FUSED_SIZE = 65536 // x.element_size()
167162
blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
163+
USE_BLOCKED = n_cols > blk_size
164+
NUM_PRGMS = min(n_rows, get_num_sms())
168165
g = torch.ones((1, N), device='cuda')
169-
y_triton = triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)
166+
y_triton = triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS)
170167

171168
y_torch = torch_rmsnorm(x, g)
172169

@@ -219,16 +216,20 @@ def benchmark(M, N, provider):
219216
n_rows, n_cols = x.shape
220217
MAX_FUSED_SIZE = 65536 // x.element_size()
221218
blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
219+
USE_BLOCKED = n_cols > blk_size
220+
NUM_PRGMS = min(n_rows, get_num_sms())
222221
stream = torch.cuda.Stream()
223222
torch.cuda.set_stream(stream)
224223
g = torch.ones((1, N), device='cuda')
225224
if provider == 'torch':
226225
ms = triton.testing.do_bench(lambda: torch_rmsnorm(x, g))
227226
if provider == 'triton':
228-
ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size))
227+
ms = triton.testing.do_bench(
228+
lambda: triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS))
229229
global verbose
230230
if verbose:
231231
print(f'SIZE: {N} Best tuning config: ({rms_kernel.best_config})')
232+
print(f'time: {ms}')
232233
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
233234
return gbps(ms)
234235

@@ -266,8 +267,10 @@ def main():
266267
n_rows, n_cols = x.shape
267268
MAX_FUSED_SIZE = 65536 // x.element_size()
268269
blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
270+
USE_BLOCKED = n_cols > blk_size
271+
NUM_PRGMS = min(n_rows, get_num_sms())
269272
g = torch.ones((1, args.N_start), device='cuda')
270-
triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)
273+
triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS)
271274
else:
272275
verbose = args.v
273276
run_benchmark(args)

0 commit comments

Comments
 (0)