@@ -46,12 +46,14 @@ def get_autotune_config():
4646@triton .autotune (configs = get_autotune_config (), key = ['n_rows' , 'n_cols' ], use_cuda_graph = True )
4747@triton .jit
4848def rms_kernel (output_ptr , input_ptr , g_ptr , rsigma_ptr , input_row_stride , output_row_stride , n_rows , n_cols , epsilon ,
49- BLOCK_SIZE : tl .constexpr , USE_BLOCKED : tl .constexpr , NUM_PRGMS : tl .constexpr ):
49+ ZERO_CENTERED_GAMMA : tl .constexpr , BLOCK_SIZE : tl .constexpr , USE_BLOCKED : tl .constexpr ,
50+ NUM_PRGMS : tl .constexpr ):
5051 row_start = tl .program_id (0 )
5152 col_offsets = tl .arange (0 , BLOCK_SIZE )
52- tl .assume (input_row_stride >= 0 )
53- tl .assume (output_row_stride >= 0 )
54- tl .assume (row_start >= 0 )
53+ # as older version Triton doesn't support tl.assume and BUFF OPS, comment out for now
54+ # tl.assume(input_row_stride >= 0)
55+ # tl.assume(output_row_stride >= 0)
56+ # tl.assume(row_start >= 0)
5557
5658 if USE_BLOCKED :
5759
@@ -93,6 +95,8 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, rsigma_ptr, input_row_stride, outpu
9395 x = tl .load (input_ptrs ).to (tl .float32 )
9496 g_ptrs = g_ptr + cols
9597 g = tl .load (g_ptrs ).to (tl .float32 )
98+ if (ZERO_CENTERED_GAMMA ):
99+ g += 1
96100 rms_norm = x * norm_factor * g
97101 output_ptrs = row_output_ptr + cols
98102 tl .store (output_ptrs , rms_norm .to (output_ptr .type .element_ty ))
@@ -104,6 +108,8 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, rsigma_ptr, input_row_stride, outpu
104108 x = tl .load (input_ptrs , mask = mask , other = 0.0 , cache_modifier = ".cg" ).to (tl .float32 )
105109 g_ptrs = g_ptr + cols
106110 g = tl .load (g_ptrs , mask = mask , other = 0.0 ).to (tl .float32 )
111+ if (ZERO_CENTERED_GAMMA ):
112+ g += 1
107113 rms_norm = x * norm_factor * g
108114 output_ptrs = row_output_ptr + cols
109115 tl .store (output_ptrs , rms_norm .to (output_ptr .type .element_ty ), mask = mask )
@@ -123,29 +129,36 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, rsigma_ptr, input_row_stride, outpu
123129 rsigma_output_ptr = rsigma_ptr + row_idx
124130 tl .store (rsigma_output_ptr , norm_factor )
125131
132+ if (ZERO_CENTERED_GAMMA ):
133+ g += 1
126134 rms_norm = row * norm_factor * g
127135
128136 output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
129137 output_ptrs = tl .multiple_of (output_ptrs , (16 , ))
130138 tl .store (output_ptrs , rms_norm .to (output_ptr .type .element_ty ), mask = mask )
131139
132140
def triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED, NUM_PRGMS,
                   epsilon=1e-6):
    """Launch the Triton RMSNorm kernel on `x`, writing into the preallocated buffers.

    `y` receives the normalized rows and `rsigma` the per-row reciprocal RMS;
    both are returned for convenience. `ZERO_CENTERED_GAMMA` makes the kernel
    treat `g` as (gamma - 1), i.e. it adds 1 to `g` before scaling.
    """
    in_row_stride = x.stride(0)
    out_row_stride = y.stride(0)
    # Persistent-kernel launch: NUM_PRGMS programs, each looping over its rows.
    grid = lambda meta: (NUM_PRGMS, )
    rms_kernel[grid](y, x, g, rsigma, in_row_stride, out_row_stride, n_rows, n_cols, epsilon, ZERO_CENTERED_GAMMA,
                     blk_size, USE_BLOCKED, NUM_PRGMS)
    return y, rsigma
139148
140149
def torch_rmsnorm(x, g, ZERO_CENTERED_GAMMA, epsilon=1e-6):
    """Reference RMSNorm in plain PyTorch.

    Normalizes each row of `x` by its root-mean-square and scales by `g`
    (broadcast over rows). When `ZERO_CENTERED_GAMMA` is true, `g` holds
    (gamma - 1), so 1 is added before scaling — matching the Triton kernel.

    Returns (rms_norm, rsigma): the normalized output cast back to `x`'s
    dtype, and the per-row reciprocal RMS.

    NOTE(review): `epsilon` is accepted for signature parity with the Triton
    path but is never applied here — confirm whether the kernel adds it under
    the sqrt and whether this reference should too.
    """
    N = x.shape[-1]
    rms = torch.sqrt(torch.sum(x * x, dim=-1) * 1 / N)
    rsigma = 1.0 / rms
    if ZERO_CENTERED_GAMMA:
        # Bug fix: the original `g += 1` mutated the caller's tensor in place,
        # so repeated calls (e.g. the do_bench benchmark loop) shifted gamma
        # by +1 every iteration. Rebind out-of-place instead.
        g = g + 1
    rms_norm = x * rsigma.unsqueeze(1) * g
    rms_norm = rms_norm.to(x.dtype)
    return rms_norm, rsigma
147159
148160
161+ @pytest .mark .parametrize ('ZERO_CENTERED_GAMMA' , [True , False ])
149162@pytest .mark .parametrize ('M, N' , [
150163 (1 , 4 ),
151164 (2 , 10 ),
@@ -156,20 +169,23 @@ def torch_rmsnorm(x, g, epsilon=1e-6):
156169 (3 , 65536 ),
157170 (873 , 1245 ),
158171])
def test_rmsnorm(M, N, ZERO_CENTERED_GAMMA):
    """Compare the Triton RMSNorm kernel against the PyTorch reference."""
    torch.manual_seed(0)

    # GPU inputs and preallocated output buffers.
    x = torch.randn(M, N, device='cuda')
    y = torch.zeros_like(x, device='cuda')
    rsigma = torch.empty((M, ), device='cuda', dtype=torch.float32)
    g = torch.ones((1, N), device='cuda')

    n_rows, n_cols = x.shape
    # Cap the block at 64 KiB of elements, rounded up to a power of two;
    # fall back to the blocked path only when a row exceeds one block.
    max_fused_size = 65536 // x.element_size()
    blk_size = min(max_fused_size, triton.next_power_of_2(n_cols))
    use_blocked = n_cols > blk_size
    num_prgms = min(n_rows, get_num_sms())

    y_triton, rsigma_triton = triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size,
                                             use_blocked, num_prgms)

    y_torch, rsigma_torch = torch_rmsnorm(x, g, ZERO_CENTERED_GAMMA)

    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
    assert torch.allclose(rsigma_triton, rsigma_torch), (rsigma_triton, rsigma_torch)
@@ -249,11 +265,12 @@ def benchmark(M, N, provider, model=None):
249265 stream = torch .cuda .Stream ()
250266 torch .cuda .set_stream (stream )
251267 g = torch .ones ((1 , N ), device = 'cuda' )
268+ ZERO_CENTERED_GAMMA = False
252269 if provider == 'torch' :
253- ms = triton .testing .do_bench (lambda : torch_rmsnorm (x , g ))
270+ ms = triton .testing .do_bench (lambda : torch_rmsnorm (x , g , ZERO_CENTERED_GAMMA ))
254271 if provider == 'triton' :
255- ms = triton .testing .do_bench (
256- lambda : triton_rmsnorm ( x , y , g , rsigma , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS ))
272+ ms = triton .testing .do_bench (lambda : triton_rmsnorm ( x , y , g , rsigma , n_rows , n_cols , ZERO_CENTERED_GAMMA ,
273+ blk_size , USE_BLOCKED , NUM_PRGMS ))
257274 global verbose
258275 if verbose :
259276 print (f'SIZE: { N } Best tuning config: ({ rms_kernel .best_config } )' )
@@ -309,7 +326,8 @@ def main():
309326 USE_BLOCKED = n_cols > blk_size
310327 NUM_PRGMS = min (n_rows , get_num_sms ())
311328 g = torch .ones ((1 , args .N_start ), device = 'cuda' )
312- triton_rmsnorm (x , y , g , rsigma , n_rows , n_cols , blk_size , USE_BLOCKED , NUM_PRGMS )
329+ ZERO_CENTERED_GAMMA = True
330+ triton_rmsnorm (x , y , g , rsigma , n_rows , n_cols , ZERO_CENTERED_GAMMA , blk_size , USE_BLOCKED , NUM_PRGMS )
313331 else :
314332 verbose = args .v
315333 run_benchmark (args )
0 commit comments