@@ -24,6 +24,22 @@ def get_num_sms():
     return num_sms


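+# Launch/layout heuristics shared by the forward and backward paths:
+#   num_programs: one persistent program per row, capped at the number of SMs.
+#   block_size:   widest power-of-two column tile, capped at 64 KiB of elements per row.
+#   use_blocked:  rows wider than one tile take the blocked (looped) path.
+#   dg_tmp_rows:  rows of the partial-dg buffer (one per row when blocked, else one per program).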
+def num_programs(x):
+    return min(x.shape[0], get_num_sms())
+
+
+def block_size(x):
+    return min(65536 // x.element_size(), triton.next_power_of_2(x.shape[1]))
+
+
+def use_blocked(x):
+    return x.shape[1] > block_size(x)
+
+
+def dg_tmp_rows(x):
+    return x.shape[0] if use_blocked(x) else num_programs(x)
+
+
 def get_cuda_autotune_config():
     return [
         triton.Config({}, num_warps=4, num_stages=1),
@@ -245,11 +261,12 @@ def rms_bwd_kernel(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr
 
     else:
         mask = col_offsets < n_cols
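+        # Each program accumulates its per-column dg contribution in fp32 across the rows
+        # it visits and writes a single partial row at the end, instead of one row per input row.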
+        dg_col_redux = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
+
         for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2):
             input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets
             grad_output_ptrs = grad_output_ptr + row_idx * output_row_stride + col_offsets
             dx_ptrs = dx_ptr + row_idx * input_row_stride + col_offsets
-            dg_ptrs = dg_ptr + row_idx * input_row_stride + col_offsets
 
             input_ptrs = tl.multiple_of(input_ptrs, (16, ))
             grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, ))
@@ -269,7 +286,9 @@ def rms_bwd_kernel(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr
             tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask)
 
             dg = grad_output * x * norm_factor
-            tl.store(dg_ptrs, dg.to(tl.float32), mask=mask)
+            dg_col_redux += dg.to(tl.float32)
+
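+        # One partial-dg row per program; _rmsnorm_bwd_dg_reduce sums them column-wise
+        # (when n_rows == 1 the caller passes dg itself here, so no reduction is needed).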
+        tl.store(dg_ptr + tl.program_id(0) * input_row_stride + col_offsets, dg_col_redux, mask=mask)
 
 
 @triton.jit
@@ -285,7 +304,7 @@ def _rmsnorm_bwd_dg_reduce(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n_cols,
         rows = i + tl.arange(0, BLOCK_SIZE_M)
         mask = (rows[:, None] < n_rows) & (cols[None, :] < n_cols)
         offs = rows[:, None] * n_cols + cols[None, :]
-        acc += tl.load(dg_in_ptr + offs, mask=mask, other=0.).to(tl.float32)
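+        # ".cg" caches the partial rows at the global (L2) level only; each value is read just once.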
+        acc += tl.load(dg_in_ptr + offs, mask=mask, other=0., cache_modifier=".cg").to(tl.float32)
 
     sum_dg = tl.sum(acc, axis=0)
     tl.store(dg_out_ptr + cols, sum_dg.to(dg_out_ptr.type.element_ty), mask=cols < n_cols)
@@ -294,10 +313,13 @@ def _rmsnorm_bwd_dg_reduce(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n_cols,
 class RMSNorm(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, x, g, y, rsigma, dx, dg, dg_tmp, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED,
-                NUM_PRGMS, epsilon=1e-6):
-        # heuristics for number of warps
-        # num_warps = min(max(blk_size // 256, 1), 8)
+    def forward(ctx, x, g, y, rsigma, dx, dg, dg_tmp, ZERO_CENTERED_GAMMA, epsilon=1e-6):
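+        # y, rsigma, dx, dg and dg_tmp are caller-allocated buffers; the launch parameters
+        # (block size, blocked mode, number of programs) are now derived from x here.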
+        n_rows, n_cols = x.shape
+        blk_size = block_size(x)
+        USE_BLOCKED = use_blocked(x)
+        NUM_PRGMS = num_programs(x)
+        # heuristics for number of warps:
+        # num_warps = min(max(blk_size // 256, 1), 8)
         num_warps = 8
         grid = lambda meta: (NUM_PRGMS, )
         rms_kernel[grid](y, x, g, rsigma, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, ZERO_CENTERED_GAMMA,
@@ -330,17 +352,19 @@ def backward(ctx, grad_output):
         blk_size = ctx.blk_size
         USE_BLOCKED = ctx.USE_BLOCKED
         NUM_PRGMS = ctx.NUM_PRGMS
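+        # With a single row there is nothing to reduce: rms_bwd_kernel writes dg directly.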
+        need_reduction = n_rows > 1
 
         grid_bwd = lambda meta: (NUM_PRGMS, )
-        rms_bwd_kernel[grid_bwd](grad_output, x, g, rsigma, dx, dg_tmp, x.stride(0), grad_output.stride(0), n_rows,
-                                 n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED, NUM_PRGMS, num_warps=ctx.num_warps)
+        rms_bwd_kernel[grid_bwd](grad_output, x, g, rsigma, dx, dg_tmp if need_reduction else dg, x.stride(0),
+                                 grad_output.stride(0), n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED,
+                                 NUM_PRGMS, num_warps=ctx.num_warps)
 
-        # grid_reduce = lambda meta: (triton.cdiv(n_cols, blk_size), )
-        grid_reduce = lambda meta: [triton.cdiv(n_cols, meta['BLOCK_SIZE_N'])]
-        _rmsnorm_bwd_dg_reduce[grid_reduce](dg_tmp, dg, dg_tmp.stride(0), n_rows, n_cols, BLOCK_SIZE_M=128,
-                                            BLOCK_SIZE_N=64)
+        if need_reduction:
+            grid_reduce = lambda meta: [triton.cdiv(n_cols, meta['BLOCK_SIZE_N'])]
+            _rmsnorm_bwd_dg_reduce[grid_reduce](dg_tmp, dg, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1],
+                                                BLOCK_SIZE_M=128, BLOCK_SIZE_N=64)
 
-        return dx, dg, None, None, None, None, None, None, None, None, None, None, None
+        return dx, dg, None, None, None, None, None, None, None
 
 
 rmsnorm = RMSNorm.apply
@@ -351,8 +375,8 @@ def torch_rmsnorm_fwd(x, g, ZERO_CENTERED_GAMMA, out_dtype=torch.float16, epsilo
     # cast to float32 as the triton kernel
     x_f32 = x.float()
     g_f32 = g.float()
-    rms = torch.sqrt(torch.sum(x_f32 * x_f32, dim=-1) * 1 / N)
-    rsigma = 1.0 / rms
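+    # Match the kernel's definition: rsigma = 1 / sqrt(mean(x^2) + eps).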
+    mean_sq_x = torch.mean(x_f32 * x_f32, dim=-1)
+    rsigma = torch.rsqrt(mean_sq_x + epsilon)
     if (ZERO_CENTERED_GAMMA):
         g_f32 = g_f32 + 1
     rms_norm_f32 = x_f32 * rsigma.unsqueeze(1) * g_f32
@@ -363,11 +387,14 @@ def torch_rmsnorm_fwd(x, g, ZERO_CENTERED_GAMMA, out_dtype=torch.float16, epsilo
 arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}
 
 
-#@pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"])
-#@pytest.mark.parametrize("out_dtype_str", ["fp32", "fp16", "bf16"])
+# FIXME: Some `fp32` test cases are failing in the backward pass.
+# A few of the failures involve `dx`, but most involve `dg`.
+# @pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"])
+# @pytest.mark.parametrize("out_dtype_str", ["fp32", "fp16", "bf16"])
 @pytest.mark.parametrize("in_dtype_str", ["fp16", "bf16"])
 @pytest.mark.parametrize("out_dtype_str", ["fp16", "bf16"])
 @pytest.mark.parametrize('ZERO_CENTERED_GAMMA', [True, False])
+# yapf: disable
 @pytest.mark.parametrize('M, N', [
     (1, 4),
     (2, 10),
@@ -376,7 +403,20 @@ def torch_rmsnorm_fwd(x, g, ZERO_CENTERED_GAMMA, out_dtype=torch.float16, epsilo
     (1, 31744),
     (8192, 65536),
     (873, 1245),
+    # Shapes suggested by TE team:
+    (4096, 5120),
+    (8192, 8192),
+    # TE UT shapes:
+    (2048, 4096),
+    (768, 2048),
+    (256, 1024),
+    (128, 768),
+    (64, 512),
+    (173, 409),
+    (71, 3571),
+    (29, 17389),
 ])
+# yapf: enable
 def test_rmsnorm(M, N, ZERO_CENTERED_GAMMA, in_dtype_str, out_dtype_str):
     in_dtype = arg_to_torch_dtype[in_dtype_str]
     out_dtype = arg_to_torch_dtype[out_dtype_str]
@@ -389,16 +429,9 @@ def test_rmsnorm(M, N, ZERO_CENTERED_GAMMA, in_dtype_str, out_dtype_str):
 
     dx = torch.empty_like(x, dtype=in_dtype, requires_grad=False)
     dg = torch.empty_like(g, dtype=in_dtype, requires_grad=False)
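+    # Partial-dg buffer consumed by _rmsnorm_bwd_dg_reduce; its row count depends on the
+    # blocked vs. non-blocked path (see dg_tmp_rows).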
-    dg_tmp = torch.zeros(M, N, device='cuda', dtype=torch.float32, requires_grad=False)
+    dg_tmp = torch.empty(dg_tmp_rows(x), N, device='cuda', dtype=torch.float32, requires_grad=False) if N > 1 else None
 
-    n_rows, n_cols = x.shape
-    MAX_FUSED_SIZE = 65536 // x.element_size()
-    blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
-    USE_BLOCKED = n_cols > blk_size
-    NUM_PRGMS = min(n_rows, get_num_sms())
-
-    y_triton = rmsnorm(x, g, y, rsigma, dx, dg, dg_tmp, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED,
-                       NUM_PRGMS)
+    y_triton = rmsnorm(x, g, y, rsigma, dx, dg, dg_tmp, ZERO_CENTERED_GAMMA)
 
     y_torch, rsigma_torch = torch_rmsnorm_fwd(x, g, ZERO_CENTERED_GAMMA, out_dtype)
 
@@ -438,11 +471,11 @@ def test_rmsnorm(M, N, ZERO_CENTERED_GAMMA, in_dtype_str, out_dtype_str):
 
     dx_b = torch.empty_like(x_triton, dtype=in_dtype, requires_grad=False)
     dg_b = torch.empty_like(g_triton, dtype=in_dtype, requires_grad=False)
-    dg_tmp_b = torch.zeros(M, N, device=x_triton.device, dtype=torch.float32, requires_grad=False)
+    dg_tmp_b = torch.empty(dg_tmp_rows(x_triton), N, device=x_triton.device, dtype=torch.float32,
+                           requires_grad=False) if N > 1 else None
 
     # Run Triton forward pass to build the graph for backward.
-    y_triton = rmsnorm(x_triton, g_triton, y_triton_buf, rsigma_triton, dx_b, dg_b, dg_tmp_b, n_rows, n_cols,
-                       ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED, NUM_PRGMS)
+    y_triton = rmsnorm(x_triton, g_triton, y_triton_buf, rsigma_triton, dx_b, dg_b, dg_tmp_b, ZERO_CENTERED_GAMMA)
     y_triton.backward(grad_output, retain_graph=True)
     grad_x_triton = x_triton.grad.to(out_dtype)
     grad_g_triton = g_triton.grad.to(out_dtype)
@@ -526,22 +559,16 @@ def benchmark(M, N, provider, model=None):
     rsigma = torch.empty((M, ), device='cuda', dtype=torch.float32)
     dx = torch.empty(M, N, device='cuda', dtype=dtype, requires_grad=False)
     dg = torch.empty((1, N), device='cuda', dtype=dtype, requires_grad=False)
-    dg_tmp = torch.zeros(M, N, device='cuda', dtype=torch.float32, requires_grad=False)
-    n_rows, n_cols = x.shape
-    # MAX_FUSED_SIZE = 65536 // x.element_size()
-    # blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
-    blk_size = 1024
-    USE_BLOCKED = n_cols > blk_size
-    NUM_PRGMS = min(n_rows, get_num_sms())
+    dg_tmp = torch.empty(dg_tmp_rows(x), N, device='cuda', dtype=torch.float32,
+                         requires_grad=False) if N > 1 else None
     stream = torch.cuda.Stream()
     torch.cuda.set_stream(stream)
     g = torch.ones((1, N), device='cuda')
     ZERO_CENTERED_GAMMA = False
 
     def rms_fwd():
         if provider == 'triton':
-            return rmsnorm(x, g, y, rsigma, dx, dg, dg_tmp, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size,
-                           USE_BLOCKED, NUM_PRGMS)
+            return rmsnorm(x, g, y, rsigma, dx, dg, dg_tmp, ZERO_CENTERED_GAMMA)
         if provider == 'torch':
             return torch_rmsnorm_fwd(x, g, ZERO_CENTERED_GAMMA)
 
@@ -555,20 +582,19 @@ def rms_fwd():
         y_ = torch.zeros_like(x_, dtype=dtype)
         rsigma_ = torch.empty((M, ), device='cuda', dtype=torch.float32)
         dx_ = torch.empty_like(x_, dtype=dtype)
-        dg_tmp_ = torch.empty_like(x_, dtype=torch.float32)
         dg_ = torch.empty_like(g_, dtype=dtype)
+        dg_tmp_ = torch.empty(dg_tmp_rows(x_), N, device='cuda', dtype=torch.float32) if N > 1 else None
         grad_out = torch.randn_like(y_)
 
-        y_out = rmsnorm(x_, g_, y_, rsigma_, dx_, dg_, dg_tmp_, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size,
-                        USE_BLOCKED, NUM_PRGMS)
+        y_out = rmsnorm(x_, g_, y_, rsigma_, dx_, dg_, dg_tmp_, ZERO_CENTERED_GAMMA)
 
         ms = triton.testing.do_bench(lambda: y_out.backward(grad_out, retain_graph=True), grad_to_none=[x_, g_])
     else:
         raise ValueError(f"mode {mode} is not supported!")
 
     global verbose
     if verbose:
-        print(f'SIZE: {N} Best tuning config: ({rms_kernel.best_config})')
+        print(f'SIZE: {N} Best forward tuning config: ({rms_kernel.best_config})')
         print(f'time: {ms}')
     gbps = lambda ms_val: 2 * x.nelement() * x.element_size() * 1e-9 / (ms_val * 1e-3)
     return gbps(ms)
@@ -599,8 +625,8 @@ def parse_args():
     parser.add_argument('-Ns', "--N_step", default="1024", type=int)
     parser.add_argument('-Ne', "--N_end", default="32768", type=int)
 
-    parser.add_argument('-d', "--dtype", default="fp16")
-    parser.add_argument('-nb', "--no_benchmark", default=False, type=bool)
+    parser.add_argument('-d', "--dtype", type=str, choices=list(arg_to_torch_dtype.keys()), default="fp16")
+    parser.add_argument('-nb', "--no_benchmark", action="store_true", default=False)
     parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
     parser.add_argument("--mode", type=str, choices=["fwd", "bwd"], default="fwd",
                         help="Benchmark mode: forward only, backward only, or both.")
@@ -611,21 +637,33 @@ def parse_args():
 def main():
     args = parse_args()
     global verbose
+
     if args.no_benchmark:
-        x = torch.randn(args.M_start, args.N_start, device='cuda', dtype=args.dtype)
-        y = torch.zeros_like(x, device='cuda')
-        rsigma = torch.empty((args.M_start, ), device='cuda', dtype=torch.float32)
-        dx = torch.empty(args.M_start, args.N_start, device='cuda', dtype=args.dtype, requires_grad=False)
-        dg = torch.empty((1, args.N_start), device='cuda', dtype=args.dtype, requires_grad=False)
-        dg_tmp = torch.zeros(args.M_start, args.N_start, device='cuda', dtype=torch.float32, requires_grad=False)
-        n_rows, n_cols = x.shape
-        MAX_FUSED_SIZE = 65536 // x.element_size()
-        blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
-        USE_BLOCKED = n_cols > blk_size
-        NUM_PRGMS = min(n_rows, get_num_sms())
-        g = torch.ones((1, args.N_start), device='cuda', dtype=args.dtype)
-        ZERO_CENTERED_GAMMA = True
-        rmsnorm(x, y, g, rsigma, dx, dg, dg_tmp, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED, NUM_PRGMS)
+        in_dtype_str = out_dtype_str = args.dtype
+        M, N = args.M_start, args.N_start
+        ZERO_CENTERED_GAMMA = False
+
+        # Run the kernel as done in the test:
+        in_dtype = arg_to_torch_dtype[in_dtype_str]
+        out_dtype = arg_to_torch_dtype[out_dtype_str]
+        torch.manual_seed(0)
+
+        x = torch.randn(M, N, device='cuda', dtype=in_dtype, requires_grad=True)
+        g = torch.ones((1, N), device='cuda', dtype=in_dtype, requires_grad=True)
+        y = torch.zeros_like(x, device='cuda', dtype=out_dtype)
+        rsigma = torch.empty((M, ), device=x.device, dtype=torch.float32)
+
+        dx = torch.empty_like(x, dtype=in_dtype, requires_grad=False)
+        dg = torch.empty_like(g, dtype=in_dtype, requires_grad=False)
+        dg_tmp = torch.empty(dg_tmp_rows(x), N, device='cuda', dtype=torch.float32,
+                             requires_grad=False) if N > 1 else None
+
+        y_triton = rmsnorm(x, g, y, rsigma, dx, dg, dg_tmp, ZERO_CENTERED_GAMMA)
+
+        if args.mode == "bwd":
+            grad_output = torch.randn_like(y_triton)
+            y_triton.backward(grad_output, retain_graph=True)
+
     else:
         verbose = args.v
         run_benchmark(args)