Commit 2a90b5b

tuned few configs by hand (#733)
1 parent: 845d46e

1 file changed: python/perf-kernels/rmsnorm.py (7 additions, 5 deletions)

@@ -297,7 +297,8 @@ class RMSNorm(torch.autograd.Function):
     def forward(ctx, x, g, y, rsigma, dx, dg, dg_tmp, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED,
                 NUM_PRGMS, epsilon=1e-6):
         # heuristics for number of warps
-        num_warps = min(max(blk_size // 256, 1), 8)
+        # num_warps = min(max(blk_size // 256, 1), 8)
+        num_warps = 8
         grid = lambda meta: (NUM_PRGMS, )
         rms_kernel[grid](y, x, g, rsigma, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, ZERO_CENTERED_GAMMA,
                          blk_size, USE_BLOCKED, NUM_PRGMS)
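
Note on the first hunk: the forward launch previously derived num_warps from the block size (one warp per 256 elements, clamped to the range 1 to 8) and now pins it at 8. The snippet below is an illustrative sketch of what the removed heuristic would have chosen for a few block sizes; the heuristic_num_warps helper and the sample block sizes are not part of the commit.

# Illustrative sketch only: values the removed heuristic would have picked,
# compared with the hand-tuned constant now used in forward().
def heuristic_num_warps(blk_size: int) -> int:
    # old rule: one warp per 256 elements of the block, clamped to [1, 8]
    return min(max(blk_size // 256, 1), 8)

for blk_size in (256, 1024, 4096):
    print(blk_size, heuristic_num_warps(blk_size))  # 256 -> 1, 1024 -> 4, 4096 -> 8

# The commit replaces this with a fixed num_warps = 8, independent of blk_size.
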
@@ -336,8 +337,8 @@ def backward(ctx, grad_output):
 
         # grid_reduce = lambda meta: (triton.cdiv(n_cols, blk_size), )
         grid_reduce = lambda meta: [triton.cdiv(n_cols, meta['BLOCK_SIZE_N'])]
-        _rmsnorm_bwd_dg_reduce[grid_reduce](dg_tmp, dg, dg_tmp.stride(0), n_rows, n_cols, BLOCK_SIZE_M=32,
-                                            BLOCK_SIZE_N=128)
+        _rmsnorm_bwd_dg_reduce[grid_reduce](dg_tmp, dg, dg_tmp.stride(0), n_rows, n_cols, BLOCK_SIZE_M=128,
+                                            BLOCK_SIZE_N=64)
 
         return dx, dg, None, None, None, None, None, None, None, None, None, None, None
 
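Note on the second hunk: the dG reduction keeps the same launch grid formula, one program per BLOCK_SIZE_N-wide column tile, while the tile shape is hand-tuned from 32x128 to 128x64. The sketch below only illustrates how the grid size follows from n_cols and BLOCK_SIZE_N; cdiv mirrors triton.cdiv (ceiling division), and n_cols = 8192 is an assumed example value, not taken from the commit.

# Illustrative sketch only: how the reduce kernel's launch grid scales with
# the hand-tuned BLOCK_SIZE_N. cdiv mirrors triton.cdiv (ceiling division).
def cdiv(x: int, y: int) -> int:
    return (x + y - 1) // y

n_cols = 8192             # assumed example width, not from the commit
print(cdiv(n_cols, 128))  # old BLOCK_SIZE_N=128 -> 64 programs
print(cdiv(n_cols, 64))   # new BLOCK_SIZE_N=64  -> 128 programs (BLOCK_SIZE_M=128 sets the row chunk each reduces)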

@@ -527,8 +528,9 @@ def benchmark(M, N, provider, model=None):
     dg = torch.empty((1, N), device='cuda', dtype=dtype, requires_grad=False)
     dg_tmp = torch.zeros(M, N, device='cuda', dtype=torch.float32, requires_grad=False)
     n_rows, n_cols = x.shape
-    MAX_FUSED_SIZE = 65536 // x.element_size()
-    blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
+    # MAX_FUSED_SIZE = 65536 // x.element_size()
+    # blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
+    blk_size = 1024
     USE_BLOCKED = n_cols > blk_size
     NUM_PRGMS = min(n_rows, get_num_sms())
     stream = torch.cuda.Stream()
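
Note on the third hunk: the benchmark no longer derives the block size from a 64 KiB budget divided by the element size, capped at the next power of two above n_cols; it now uses a fixed blk_size of 1024, so the blocked path (USE_BLOCKED) is taken for any n_cols above 1024. The sketch below contrasts the two choices; the 2-byte element size (fp16) and n_cols = 6144 are assumed example values, not taken from the commit.

# Illustrative sketch only: old derived block size vs. the new fixed blk_size,
# assuming fp16 inputs (element size 2 bytes) and an example n_cols.
import triton

element_size = 2                                          # e.g. torch.float16
n_cols = 6144                                             # assumed example width

MAX_FUSED_SIZE = 65536 // element_size                    # 32768
old_blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))  # 8192
new_blk_size = 1024

print(n_cols > old_blk_size)   # False: old setting took the fully fused (non-blocked) path
print(n_cols > new_blk_size)   # True: with blk_size=1024 the blocked kernel path is used
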
