@@ -45,7 +45,7 @@ def get_autotune_config():
4545
4646@triton .autotune (configs = get_autotune_config (), key = ['n_rows' , 'n_cols' ], use_cuda_graph = True )
4747@triton .jit
48- def rms_kernel (output_ptr , input_ptr , g_ptr , input_row_stride , output_row_stride , n_rows , n_cols , epsilon ,
48+ def rms_kernel (output_ptr , input_ptr , g_ptr , rsigma_ptr , input_row_stride , output_row_stride , n_rows , n_cols , epsilon ,
4949 BLOCK_SIZE : tl .constexpr , USE_BLOCKED : tl .constexpr , NUM_PRGMS : tl .constexpr ):
5050 row_start = tl .program_id (0 )
5151 col_offsets = tl .arange (0 , BLOCK_SIZE )
@@ -61,8 +61,8 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
6161 row_output_ptr = output_ptr + row_idx * output_row_stride
6262
6363 # Accumulate sum of squares
64- sum_squares = tl .zeros ([1 ], dtype = tl .float32 )
6564 n_cols_blks = tl .cdiv (n_cols , BLOCK_SIZE ) - 1
65+ sum_squares : tl .float32 = 0.
6666 for blk_idx in tl .range (0 , n_cols_blks , num_stages = 2 ):
6767 cols = blk_idx * BLOCK_SIZE + col_offsets
6868 input_ptrs = row_input_ptr + cols
@@ -82,6 +82,9 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
8282 mean_square = sum_squares / n_cols
8383 norm_factor = tl .rsqrt (mean_square + epsilon )
8484
85+ # Store rsigma (norm_factor)
86+ tl .store (rsigma_ptr + row_idx , norm_factor )
87+
8588 # Normalize and write output
8689 for blk_idx in tl .range (0 , n_cols_blks , num_stages = 2 ):
8790 cols = blk_idx * BLOCK_SIZE + col_offsets
@@ -114,30 +117,33 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
114117 g = tl .load (g_ptr + col_offsets , mask = mask , other = 0.0 ).to (tl .float32 )
115118 row_norm = row * row
116119 row_norm = tl .sum (row_norm , axis = - 1 )
117- row_norm = tl .math .rsqrt ((row_norm / n_cols ) + epsilon )
118- rms_norm = row * row_norm * g
120+ norm_factor = tl .math .rsqrt ((row_norm / n_cols ) + epsilon )
121+
122+ # Store rsigma (norm_factor)
123+ rsigma_output_ptr = rsigma_ptr + row_idx
124+ tl .store (rsigma_output_ptr , norm_factor )
125+
126+ rms_norm = row * norm_factor * g
119127
120128 output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
121129 output_ptrs = tl .multiple_of (output_ptrs , (16 , ))
122130 tl .store (output_ptrs , rms_norm .to (output_ptr .type .element_ty ), mask = mask )
123131
124132
125- def triton_rmsnorm (x , y , g , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS , epsilon = 1e-6 ):
133+ def triton_rmsnorm (x , y , g , rsigma , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS , epsilon = 1e-6 ):
126134 grid = lambda meta : (NUM_PRGMS , )
127- rms_kernel [grid ](y , x , g , x .stride (0 ), y .stride (0 ), n_rows , n_cols , epsilon , blk_size , USE_BLOCKED , NUM_PRGMS )
135+ rms_kernel [grid ](y , x , g , rsigma , x .stride (0 ), y .stride (0 ), n_rows , n_cols , epsilon , blk_size , USE_BLOCKED ,
136+ NUM_PRGMS )
128137
129- return y
138+ return y , rsigma
130139
131140
132- def torch_rmsnorm (x , g ):
141+ def torch_rmsnorm (x , g , epsilon = 1e-6 ):
133142 M , N = x .shape
134- if hasattr (torch .nn , 'RMSNorm' ):
135- rms_norm = torch .nn .RMSNorm (N , device = 'cuda' )
136- return rms_norm (x )
137- else :
138- rms = torch .sqrt (torch .sum (x * x , dim = - 1 ) * 1 / N )
139- rms_norm = torch .div (x , rms .unsqueeze (1 ).repeat (1 , N )) * g
140- return rms_norm
143+ rms = torch .sqrt (torch .sum (x * x , dim = - 1 ) * 1 / N )
144+ rsigma = 1.0 / rms
145+ rms_norm = x * rsigma .unsqueeze (1 ) * g
146+ return rms_norm , rsigma
141147
142148
143149@pytest .mark .parametrize ('M, N' , [
@@ -154,17 +160,19 @@ def test_rmsnorm(M, N):
154160 torch .manual_seed (0 )
155161 x = torch .randn (M , N , device = 'cuda' )
156162 y = torch .zeros_like (x , device = 'cuda' )
163+ rsigma = torch .empty ((M , ), device = 'cuda' , dtype = torch .float32 )
157164 n_rows , n_cols = x .shape
158165 MAX_FUSED_SIZE = 65536 // x .element_size ()
159166 blk_size = min (MAX_FUSED_SIZE , triton .next_power_of_2 (n_cols ))
160167 USE_BLOCKED = n_cols > blk_size
161168 NUM_PRGMS = min (n_rows , get_num_sms ())
162169 g = torch .ones ((1 , N ), device = 'cuda' )
163- y_triton = triton_rmsnorm (x , y , g , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS )
170+ y_triton , rsigma_triton = triton_rmsnorm (x , y , g , rsigma , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS )
164171
165- y_torch = torch_rmsnorm (x , g )
172+ y_torch , rsigma_torch = torch_rmsnorm (x , g )
166173
167174 assert torch .allclose (y_triton , y_torch ), (y_triton , y_torch )
175+ assert torch .allclose (rsigma_triton , rsigma_torch ), (rsigma_triton , rsigma_torch )
168176
169177
170178#Benchmark
@@ -232,6 +240,7 @@ def run_benchmark(args):
232240 def benchmark (M , N , provider , model = None ):
233241 x = torch .randn (M , N , device = 'cuda' , dtype = dtype )
234242 y = torch .zeros_like (x , device = 'cuda' )
243+ rsigma = torch .empty ((M , ), device = 'cuda' , dtype = torch .float32 )
235244 n_rows , n_cols = x .shape
236245 MAX_FUSED_SIZE = 65536 // x .element_size ()
237246 blk_size = min (MAX_FUSED_SIZE , triton .next_power_of_2 (n_cols ))
@@ -244,7 +253,7 @@ def benchmark(M, N, provider, model=None):
244253 ms = triton .testing .do_bench (lambda : torch_rmsnorm (x , g ))
245254 if provider == 'triton' :
246255 ms = triton .testing .do_bench (
247- lambda : triton_rmsnorm (x , y , g , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS ))
256+ lambda : triton_rmsnorm (x , y , g , rsigma , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS ))
248257 global verbose
249258 if verbose :
250259 print (f'SIZE: { N } Best tuning config: ({ rms_kernel .best_config } )' )
@@ -293,13 +302,14 @@ def main():
293302 if args .no_benchmark :
294303 x = torch .randn (args .M_start , args .N_start , device = 'cuda' )
295304 y = torch .zeros_like (x , device = 'cuda' )
305+ rsigma = torch .empty ((args .M_start , ), device = 'cuda' , dtype = torch .float32 )
296306 n_rows , n_cols = x .shape
297307 MAX_FUSED_SIZE = 65536 // x .element_size ()
298308 blk_size = min (MAX_FUSED_SIZE , triton .next_power_of_2 (n_cols ))
299309 USE_BLOCKED = n_cols > blk_size
300310 NUM_PRGMS = min (n_rows , get_num_sms ())
301311 g = torch .ones ((1 , args .N_start ), device = 'cuda' )
302- triton_rmsnorm (x , y , g , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS )
312+ triton_rmsnorm (x , y , g , rsigma , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS )
303313 else :
304314 verbose = args .v
305315 run_benchmark (args )
0 commit comments