@@ -297,7 +297,8 @@ class RMSNorm(torch.autograd.Function):
     def forward(ctx, x, g, y, rsigma, dx, dg, dg_tmp, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED,
                 NUM_PRGMS, epsilon=1e-6):
         # heuristics for number of warps
-        num_warps = min(max(blk_size // 256, 1), 8)
+        # num_warps = min(max(blk_size // 256, 1), 8)
+        num_warps = 8
         grid = lambda meta: (NUM_PRGMS, )
         rms_kernel[grid](y, x, g, rsigma, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, ZERO_CENTERED_GAMMA,
                          blk_size, USE_BLOCKED, NUM_PRGMS)
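The hunk above drops the warp heuristic and pins `num_warps = 8`. A minimal plain-Python sketch of the two policies (the helper name is ours, not from the patch): the old rule gave roughly one warp per 256 elements of `blk_size`, clamped to [1, 8], so the two only differ when `blk_size` is below 2048.

```python
# Sketch only: old warp heuristic vs. the new fixed value of 8.
def heuristic_num_warps(blk_size: int) -> int:
    # old policy: ~1 warp per 256 elements, clamped to [1, 8]
    return min(max(blk_size // 256, 1), 8)

for blk_size in (256, 512, 1024, 2048, 4096):
    print(blk_size, heuristic_num_warps(blk_size), 8)  # old vs. new
```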
@@ -336,8 +337,8 @@ def backward(ctx, grad_output):
 
         # grid_reduce = lambda meta: (triton.cdiv(n_cols, blk_size), )
         grid_reduce = lambda meta: [triton.cdiv(n_cols, meta['BLOCK_SIZE_N'])]
-        _rmsnorm_bwd_dg_reduce[grid_reduce](dg_tmp, dg, dg_tmp.stride(0), n_rows, n_cols, BLOCK_SIZE_M=32,
-                                            BLOCK_SIZE_N=128)
+        _rmsnorm_bwd_dg_reduce[grid_reduce](dg_tmp, dg, dg_tmp.stride(0), n_rows, n_cols, BLOCK_SIZE_M=128,
+                                            BLOCK_SIZE_N=64)
 
         return dx, dg, None, None, None, None, None, None, None, None, None, None, None
 
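Note that `grid_reduce` is `triton.cdiv(n_cols, meta['BLOCK_SIZE_N'])`, so shrinking `BLOCK_SIZE_N` from 128 to 64 doubles the number of reduce programs, while the larger `BLOCK_SIZE_M` (32 to 128) lets each program accumulate more rows of `dg_tmp` per loop iteration. A no-GPU sketch of the grid arithmetic, assuming only that `triton.cdiv` is ceiling division (it is); the example `n_cols` is ours:

```python
# Ceiling division, matching triton.cdiv.
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

n_cols = 8192  # example width, not from the patch
print(cdiv(n_cols, 128))  # old BLOCK_SIZE_N -> 64 programs
print(cdiv(n_cols, 64))   # new BLOCK_SIZE_N -> 128 programs
```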
@@ -527,8 +528,9 @@ def benchmark(M, N, provider, model=None):
     dg = torch.empty((1, N), device='cuda', dtype=dtype, requires_grad=False)
     dg_tmp = torch.zeros(M, N, device='cuda', dtype=torch.float32, requires_grad=False)
     n_rows, n_cols = x.shape
-    MAX_FUSED_SIZE = 65536 // x.element_size()
-    blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
+    # MAX_FUSED_SIZE = 65536 // x.element_size()
+    # blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
+    blk_size = 1024
     USE_BLOCKED = n_cols > blk_size
     NUM_PRGMS = min(n_rows, get_num_sms())
     stream = torch.cuda.Stream()
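Before this change the benchmark sized `blk_size` to cover a whole row in one block, capped so a block stays within 65536 bytes; pinning `blk_size = 1024` instead means `USE_BLOCKED` now becomes true for any row wider than 1024 columns. A rough sketch of that effect, with a local `next_power_of_2` standing in for `triton.next_power_of_2` and an assumed fp16 element size:

```python
# Sketch: old adaptive block sizing vs. the new fixed blk_size = 1024.
def next_power_of_2(n: int) -> int:  # stand-in for triton.next_power_of_2
    return 1 << (n - 1).bit_length()

element_size = 2  # assumed fp16; old cap was 65536 bytes per block
MAX_FUSED_SIZE = 65536 // element_size
for n_cols in (768, 1024, 4096, 32768):
    old_blk = min(MAX_FUSED_SIZE, next_power_of_2(n_cols))
    print(n_cols, old_blk, n_cols > old_blk,  # old: blocked only past the cap
          1024, n_cols > 1024)                # new: blocked past 1024 cols
```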