@@ -64,7 +64,10 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, rsigma_ptr, input_row_stride, outpu
 
     # Accumulate sum of squares
     n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1
-    sum_squares: tl.float32 = 0.
+    # older versions of Triton don't accept the annotated init below
+    # sum_squares: tl.float32 = 0.
+    # however, under Triton's type-promotion rules, sum_squares is always fp32 with the init below
+    sum_squares = 0.
     for blk_idx in tl.range(0, n_cols_blks, num_stages=2):
         cols = blk_idx * BLOCK_SIZE + col_offsets
         input_ptrs = row_input_ptr + cols
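
The comment in the hunk above leans on Triton's implicit type promotion: an unannotated float literal accumulator ends up fp32 once a fp32 value is added to it. A minimal standalone sketch (not part of this diff; the kernel name is made up) that checks the accumulator dtype at compile time, assuming a recent Triton with `tl.static_print`:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _promotion_check_kernel(x_ptr, BLOCK_SIZE: tl.constexpr):
    offs = tl.arange(0, BLOCK_SIZE)
    # load fp16 data and cast to fp32, as the rms_kernel does
    x = tl.load(x_ptr + offs).to(tl.float32)
    acc = 0.  # plain float literal, no tl.float32 annotation
    acc += tl.sum(x * x, axis=0)
    # should print fp32 at compile time if the promotion rule holds
    tl.static_print(acc.dtype)


x = torch.randn(128, device='cuda', dtype=torch.float16)
_promotion_check_kernel[(1, )](x, BLOCK_SIZE=128)
```
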
@@ -147,54 +150,71 @@ def triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_siz
     return y, rsigma
 
 
-def torch_rmsnorm(x, g, ZERO_CENTERED_GAMMA, epsilon=1e-6):
+def torch_rmsnorm(x, g, ZERO_CENTERED_GAMMA, out_dtype=torch.float16, epsilon=1e-6):
     M, N = x.shape
-    rms = torch.sqrt(torch.sum(x * x, dim=-1) * 1 / N)
+    # cast to float32 to match the accumulation dtype of the Triton kernel
+    x_f32 = x.float()
+    g_f32 = g.float()
+    rms = torch.sqrt(torch.sum(x_f32 * x_f32, dim=-1) * 1 / N)
     rsigma = 1.0 / rms
     if (ZERO_CENTERED_GAMMA):
-        g += 1
-    rms_norm = x * rsigma.unsqueeze(1) * g
-    rms_norm = rms_norm.to(x.dtype)
+        g_f32 += 1
+    rms_norm_f32 = x_f32 * rsigma.unsqueeze(1) * g_f32
+    rms_norm = rms_norm_f32.to(out_dtype)
     return rms_norm, rsigma
 
 
+arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}
+
+
+@pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"])
+@pytest.mark.parametrize("out_dtype_str", ["fp32", "fp16", "bf16"])
 @pytest.mark.parametrize('ZERO_CENTERED_GAMMA', [True, False])
 @pytest.mark.parametrize('M, N', [
     (1, 4),
     (2, 10),
     (8192, 4096),
     (4096, 8192),
-    (1, 8192),
     (1, 31744),
     (3, 65536),
     (873, 1245),
 ])
-def test_rmsnorm(M, N, ZERO_CENTERED_GAMMA):
+def test_rmsnorm(M, N, ZERO_CENTERED_GAMMA, in_dtype_str, out_dtype_str):
+    in_dtype = arg_to_torch_dtype[in_dtype_str]
+    out_dtype = arg_to_torch_dtype[out_dtype_str]
     torch.manual_seed(0)
-    x = torch.randn(M, N, device='cuda')
-    y = torch.zeros_like(x, device='cuda')
+    x = torch.randn(M, N, device='cuda', dtype=in_dtype)
+    y = torch.zeros_like(x, device='cuda', dtype=out_dtype)
     rsigma = torch.empty((M, ), device='cuda', dtype=torch.float32)
 
     n_rows, n_cols = x.shape
     MAX_FUSED_SIZE = 65536 // x.element_size()
     blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
     USE_BLOCKED = n_cols > blk_size
     NUM_PRGMS = min(n_rows, get_num_sms())
-    g = torch.ones((1, N), device='cuda')
+    g = torch.ones((1, N), device='cuda', dtype=in_dtype)
 
     y_triton, rsigma_triton = triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size,
                                              USE_BLOCKED, NUM_PRGMS)
 
-    y_torch, rsigma_torch = torch_rmsnorm(x, g, ZERO_CENTERED_GAMMA)
+    y_torch, rsigma_torch = torch_rmsnorm(x, g, ZERO_CENTERED_GAMMA, out_dtype)
 
-    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
-    assert torch.allclose(rsigma_triton, rsigma_torch), (rsigma_triton, rsigma_torch)
+    if out_dtype in (torch.float16, torch.bfloat16):
+        atol, rtol = 1e-3, 1e-2
+    else:
+        # float32 results can be checked with tighter tolerances
+        atol, rtol = 1e-5, 1e-5
 
+    assert y_triton.dtype == out_dtype, f"y_triton has dtype={y_triton.dtype}, expected {out_dtype}"
+    assert y_torch.dtype == out_dtype, f"y_torch has dtype={y_torch.dtype}, expected {out_dtype}"
 
-#Benchmark
-arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}
+    assert torch.allclose(y_triton, y_torch, atol=atol, rtol=rtol), \
+        f"Mismatch in 'y' (in={in_dtype_str}, out={out_dtype_str})"
+    assert torch.allclose(rsigma_triton, rsigma_torch, atol=atol, rtol=rtol), \
+        f"Mismatch in 'rsigma' (in={in_dtype_str}, out={out_dtype_str})"
 
 
+#Benchmark
 def model_benchmark_configs(args):
     config_file = args.model_configs
     configs = get_model_configs(config_path=config_file, model_families=["llama3"], model=args.model)
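
For reference, the fp32-accumulate-then-cast pattern that both `torch_rmsnorm` and the kernel now follow can be exercised on its own. A minimal sketch (the `rmsnorm_ref` helper below is illustrative, not part of the diff); it runs on CPU and reuses the reduced-precision tolerances from the test:

```python
import torch


def rmsnorm_ref(x, g, out_dtype):
    # accumulate in float32 regardless of input dtype, as the kernel does
    x_f32, g_f32 = x.float(), g.float()
    # rsqrt(mean(x^2)) == 1 / sqrt(sum(x^2) / N)
    rsigma = torch.rsqrt(x_f32.pow(2).mean(dim=-1))
    return (x_f32 * rsigma.unsqueeze(-1) * g_f32).to(out_dtype), rsigma


x = torch.randn(8, 4096, dtype=torch.float16)
g = torch.ones(1, 4096, dtype=torch.float16)
y16, _ = rmsnorm_ref(x, g, torch.float16)
y32, _ = rmsnorm_ref(x, g, torch.float32)
# the fp16 output only matches under the looser (atol=1e-3, rtol=1e-2) tolerances
assert torch.allclose(y16.float(), y32, atol=1e-3, rtol=1e-2)
```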