Skip to content

Commit e1245da

Browse files
enable stream pipeline for persistent rmsnorm kernel (#686)
* explicit type conversion for load/store, and enable stream pipeline * tidy up non blocked kernel
1 parent c086d08 commit e1245da

File tree

1 file changed

+20
-24
lines changed

1 file changed

+20
-24
lines changed

python/perf-kernels/rmsnorm.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
5050
col_offsets = tl.arange(0, BLOCK_SIZE)
5151
tl.assume(input_row_stride >= 0)
5252
tl.assume(output_row_stride >= 0)
53+
tl.assume(row_start >= 0)
5354

5455
if USE_BLOCKED:
5556

@@ -61,68 +62,63 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
6162
# Accumulate sum of squares
6263
sum_squares = tl.zeros([1], dtype=tl.float32)
6364
n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1
64-
for blk_idx in range(n_cols_blks, num_stages=1):
65+
for blk_idx in tl.range(0, n_cols_blks, num_stages=2):
6566
cols = blk_idx * BLOCK_SIZE + col_offsets
6667
input_ptrs = row_input_ptr + cols
6768
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
68-
x = tl.load(input_ptrs)
69+
x = tl.load(input_ptrs).to(tl.float32)
6970
sum_squares += tl.sum(x * x, axis=0)
7071

7172
# Handle remainder
7273
cols = n_cols_blks * BLOCK_SIZE + col_offsets
7374
mask = cols < n_cols
7475
input_ptrs = row_input_ptr + cols
7576
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
76-
x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
77+
x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32)
7778
sum_squares += tl.sum(x * x, axis=0)
7879

7980
# Compute normalization factor
8081
mean_square = sum_squares / n_cols
8182
norm_factor = tl.rsqrt(mean_square + epsilon)
8283

8384
# Normalize and write output
84-
for blk_idx in range(n_cols_blks, num_stages=1):
85+
for blk_idx in tl.range(0, n_cols_blks, num_stages=2):
8586
cols = blk_idx * BLOCK_SIZE + col_offsets
8687
input_ptrs = row_input_ptr + cols
8788
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
88-
x = tl.load(input_ptrs)
89+
x = tl.load(input_ptrs).to(tl.float32)
8990
g_ptrs = g_ptr + cols
90-
g = tl.load(g_ptrs)
91+
g = tl.load(g_ptrs).to(tl.float32)
9192
rms_norm = x * norm_factor * g
9293
output_ptrs = row_output_ptr + cols
93-
tl.store(output_ptrs, rms_norm)
94+
tl.store(output_ptrs, rms_norm.to(output_ptr.type.element_ty))
9495

9596
# Handle remainder
9697
cols = n_cols_blks * BLOCK_SIZE + col_offsets
9798
mask = cols < n_cols
9899
input_ptrs = row_input_ptr + cols
99-
x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
100+
x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32)
100101
g_ptrs = g_ptr + cols
101-
g = tl.load(g_ptrs, mask=mask, other=0.0)
102+
g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32)
102103
rms_norm = x * norm_factor * g
103104
output_ptrs = row_output_ptr + cols
104-
tl.store(output_ptrs, rms_norm, mask=mask)
105+
tl.store(output_ptrs, rms_norm.to(output_ptr.type.element_ty), mask=mask)
105106

106107
else:
107108
mask = col_offsets < n_cols
108-
for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=1):
109-
row_start_ptr = input_ptr + row_idx * input_row_stride
110-
input_ptrs = row_start_ptr + col_offsets
109+
for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2):
110+
input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets
111111
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
112-
row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
113-
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
112+
row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32)
113+
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
114114
row_norm = row * row
115115
row_norm = tl.sum(row_norm, axis=-1)
116-
row_norm = row_norm / n_cols
117-
row_norm = row_norm + epsilon
118-
row_norm = tl.rsqrt(row_norm)
119-
rms_norm = row * row_norm
120-
rms_norm = rms_norm * g
121-
122-
output_row_start_ptr = output_ptr + row_idx * output_row_stride
123-
output_ptrs = output_row_start_ptr + col_offsets
116+
row_norm = tl.math.rsqrt((row_norm / n_cols) + epsilon)
117+
rms_norm = row * row_norm * g
118+
119+
output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
124120
output_ptrs = tl.multiple_of(output_ptrs, (16, ))
125-
tl.store(output_ptrs, rms_norm, mask=mask)
121+
tl.store(output_ptrs, rms_norm.to(output_ptr.type.element_ty), mask=mask)
126122

127123

128124
def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, USE_BLOCKED, NUM_PRGMS, epsilon=1e-6):

0 commit comments

Comments
 (0)