
Commit 44313cf

Optimize RMSNorm backward pass (#769)

Applied optimizations:

* Partial ∂G reduction in the non-blocked version of the 1st kernel, as done by Liger Kernel, with a ∂G_tmp buffer of shape (number of CUs, N).
* Use of `cache_modifier=".cg"` in the `tl.load`s of the 2nd (reduction) kernel, as suggested by Xiaohu.
* Minor: do not allocate ∂G_tmp and do not launch the 2nd reduction kernel when M = 1, since there is nothing to reduce.

Other changes:

* Remove the `n_cols`, `n_rows`, `blk_size`, `USE_BLOCKED` and `NUM_PRGMS` arguments from the `torch.autograd.Function` forward wrapper. `n_cols` and `n_rows` are easily obtained from the input tensor, and the logic to compute the other arguments was duplicated all over the place.
* Change the PyTorch forward reference implementation to use `torch.mean` and `torch.rsqrt`.
* Add more test shapes, taken from Transformer Engine tests and from real-world use cases suggested by the Transformer Engine team.
* Improve the command line interface: add choices for `--dtype` and make `--no_benchmark` a proper Boolean flag.
* Fix the standalone kernel launcher: a segmentation fault occurred when running the kernel in isolation (outside the unit test and benchmark contexts).
1 parent 32a8fdd commit 44313cf
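For context, here is a minimal PyTorch sketch of the two-stage ∂G scheme the first bullet describes, with illustrative names (the real work happens inside the Triton kernels in the diff below):

```python
import torch

def stage1_partial_dg(grad_output, x, rsigma, num_programs):
    # Stage 1 (inside rms_bwd_kernel): each of the P programs walks rows
    # pid, pid + P, pid + 2P, ... and accumulates their column-wise dG
    # contributions into its own row, so the scratch buffer is (P, N)
    # rather than (M, N).
    M, N = x.shape
    dg_tmp = torch.zeros(num_programs, N, dtype=torch.float32)
    for pid in range(num_programs):
        for r in range(pid, M, num_programs):  # grid-stride loop over rows
            dg_tmp[pid] += (grad_output[r] * x[r] * rsigma[r]).float()
    return dg_tmp

def stage2_reduce(dg_tmp):
    # Stage 2 (_rmsnorm_bwd_dg_reduce): collapse the P partial rows.
    return dg_tmp.sum(dim=0)
```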

python/perf-kernels/rmsnorm.py

Lines changed: 97 additions & 59 deletions
@@ -24,6 +24,22 @@ def get_num_sms():
     return num_sms
 
 
+def num_programs(x):
+    return min(x.shape[0], get_num_sms())
+
+
+def block_size(x):
+    return min(65536 // x.element_size(), triton.next_power_of_2(x.shape[1]))
+
+
+def use_blocked(x):
+    return x.shape[1] > block_size(x)
+
+
+def dg_tmp_rows(x):
+    return x.shape[0] if use_blocked(x) else num_programs(x)
+
+
 def get_cuda_autotune_config():
     return [
         triton.Config({}, num_warps=4, num_stages=1),
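The four new helpers centralize launch-parameter logic that was previously duplicated across the wrapper, the test, the benchmark and the standalone launcher. A quick illustration of what they return, using a shape from the existing test list (the SM/CU count is device-dependent):

```python
# Illustrative values for an fp16 tensor; element_size() == 2.
x = torch.empty(873, 1245, device='cuda', dtype=torch.float16)
block_size(x)    # min(65536 // 2, next_power_of_2(1245)) = min(32768, 2048) = 2048
use_blocked(x)   # 1245 > 2048 -> False: the non-blocked kernel path is taken
num_programs(x)  # min(873, get_num_sms())
dg_tmp_rows(x)   # non-blocked, so num_programs(x) rows instead of all 873
```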
@@ -245,11 +261,12 @@ def rms_bwd_kernel(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr
 
     else:
         mask = col_offsets < n_cols
+        dg_col_redux = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
+
         for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2):
             input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets
             grad_output_ptrs = grad_output_ptr + row_idx * output_row_stride + col_offsets
             dx_ptrs = dx_ptr + row_idx * input_row_stride + col_offsets
-            dg_ptrs = dg_ptr + row_idx * input_row_stride + col_offsets
 
             input_ptrs = tl.multiple_of(input_ptrs, (16, ))
             grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, ))
@@ -269,7 +286,9 @@ def rms_bwd_kernel(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr
             tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask)
 
             dg = grad_output * x * norm_factor
-            tl.store(dg_ptrs, dg.to(tl.float32), mask=mask)
+            dg_col_redux += dg.to(tl.float32)
+
+        tl.store(dg_ptr + tl.program_id(0) * input_row_stride + col_offsets, dg_col_redux, mask=mask)
 
 
 @triton.jit
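The payoff of the `dg_col_redux` accumulator: in the non-blocked path, ∂G partials are now written once per program instead of once per row. Back-of-the-envelope numbers (the CU count of 104 is an assumption for a CDNA2-class GPU; substitute your device's):

```python
M, N = 8192, 8192
P = 104                       # assumed number of CUs / launched programs
old_elems = M * N             # old scheme: one fp32 partial row per input row
new_elems = P * N             # new scheme: one fp32 partial row per program
print(old_elems / new_elems)  # ~78.8x fewer fp32 partials to write and re-read
```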
@@ -285,7 +304,7 @@ def _rmsnorm_bwd_dg_reduce(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n_cols,
         rows = i + tl.arange(0, BLOCK_SIZE_M)
         mask = (rows[:, None] < n_rows) & (cols[None, :] < n_cols)
         offs = rows[:, None] * n_cols + cols[None, :]
-        acc += tl.load(dg_in_ptr + offs, mask=mask, other=0.).to(tl.float32)
+        acc += tl.load(dg_in_ptr + offs, mask=mask, other=0., cache_modifier=".cg").to(tl.float32)
 
     sum_dg = tl.sum(acc, axis=0)
     tl.store(dg_out_ptr + cols, sum_dg.to(dg_out_ptr.type.element_ty), mask=cols < n_cols)
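`cache_modifier=".cg"` asks for the load to be cached at the global (L2) level and bypass L1 (this is PTX `ld.global.cg` on NVIDIA; other backends map it to their closest equivalent). That suits this kernel: every `dg_tmp` element is read exactly once, so keeping it out of L1 avoids evicting data that could be reused. A standalone illustration with a hypothetical copy kernel:

```python
import triton
import triton.language as tl

@triton.jit
def streaming_copy(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # Each element is touched exactly once: hint the cache hierarchy accordingly.
    val = tl.load(src_ptr + offs, mask=mask, cache_modifier=".cg")
    tl.store(dst_ptr + offs, val, mask=mask)
```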
@@ -294,10 +313,13 @@ def _rmsnorm_bwd_dg_reduce(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n_cols,
 class RMSNorm(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, x, g, y, rsigma, dx, dg, dg_tmp, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED,
-                NUM_PRGMS, epsilon=1e-6):
-        # heuristics for number of warps
-        # num_warps = min(max(blk_size // 256, 1), 8)
+    def forward(ctx, x, g, y, rsigma, dx, dg, dg_tmp, ZERO_CENTERED_GAMMA, epsilon=1e-6):
+        n_rows, n_cols = x.shape
+        blk_size = block_size(x)
+        USE_BLOCKED = use_blocked(x)
+        NUM_PRGMS = num_programs(x)
+        # heuristics for number of warps:
+        # num_warps = min(max(blk_size // 256, 1), 8)
         num_warps = 8
         grid = lambda meta: (NUM_PRGMS, )
         rms_kernel[grid](y, x, g, rsigma, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, ZERO_CENTERED_GAMMA,
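With the slimmer wrapper, a call site only supplies tensors plus the gamma flag; everything else is derived internally. A minimal sketch following the test's buffer setup (shapes illustrative):

```python
M, N = 256, 1024
x = torch.randn(M, N, device='cuda', dtype=torch.float16, requires_grad=True)
g = torch.ones((1, N), device='cuda', dtype=torch.float16, requires_grad=True)
y = torch.zeros_like(x)
rsigma = torch.empty((M, ), device='cuda', dtype=torch.float32)
dx = torch.empty_like(x, requires_grad=False)
dg = torch.empty_like(g, requires_grad=False)
dg_tmp = torch.empty(dg_tmp_rows(x), N, device='cuda', dtype=torch.float32) if N > 1 else None

y_out = rmsnorm(x, g, y, rsigma, dx, dg, dg_tmp, False)  # ZERO_CENTERED_GAMMA=False
```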
@@ -330,17 +352,19 @@ def backward(ctx, grad_output):
         blk_size = ctx.blk_size
         USE_BLOCKED = ctx.USE_BLOCKED
         NUM_PRGMS = ctx.NUM_PRGMS
+        need_reduction = n_rows > 1
 
         grid_bwd = lambda meta: (NUM_PRGMS, )
-        rms_bwd_kernel[grid_bwd](grad_output, x, g, rsigma, dx, dg_tmp, x.stride(0), grad_output.stride(0), n_rows,
-                                 n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED, NUM_PRGMS, num_warps=ctx.num_warps)
+        rms_bwd_kernel[grid_bwd](grad_output, x, g, rsigma, dx, dg_tmp if need_reduction else dg, x.stride(0),
+                                 grad_output.stride(0), n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED,
+                                 NUM_PRGMS, num_warps=ctx.num_warps)
 
-        # grid_reduce = lambda meta: (triton.cdiv(n_cols, blk_size), )
-        grid_reduce = lambda meta: [triton.cdiv(n_cols, meta['BLOCK_SIZE_N'])]
-        _rmsnorm_bwd_dg_reduce[grid_reduce](dg_tmp, dg, dg_tmp.stride(0), n_rows, n_cols, BLOCK_SIZE_M=128,
-                                            BLOCK_SIZE_N=64)
+        if need_reduction:
+            grid_reduce = lambda meta: [triton.cdiv(n_cols, meta['BLOCK_SIZE_N'])]
+            _rmsnorm_bwd_dg_reduce[grid_reduce](dg_tmp, dg, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1],
+                                                BLOCK_SIZE_M=128, BLOCK_SIZE_N=64)
 
-        return dx, dg, None, None, None, None, None, None, None, None, None, None, None
+        return dx, dg, None, None, None, None, None, None, None
 
 
 rmsnorm = RMSNorm.apply
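Two notes on this hunk. First, the shorter `return` simply matches the slimmer signature: `backward` must return one value per `forward` argument, and forward now takes nine (x, g, y, rsigma, dx, dg, dg_tmp, ZERO_CENTERED_GAMMA, epsilon), hence `dx, dg` plus seven `None`s. Second, the M = 1 short-circuit is safe because with a single row there is exactly one partial, which already equals ∂G. A plain-Python sketch of the dispatch (illustrative, not the kernel code):

```python
# With n_rows == 1, NUM_PRGMS = min(n_rows, get_num_sms()) == 1, so the
# backward kernel produces a single partial row: it already *is* dG, and
# can be written straight into dg.
need_reduction = n_rows > 1
kernel_dg_out = dg_tmp if need_reduction else dg  # 1st kernel writes here
if need_reduction:
    # what _rmsnorm_bwd_dg_reduce computes, tile by tile:
    dg[:] = dg_tmp.sum(dim=0).to(dg.dtype)
```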
@@ -351,8 +375,8 @@ def torch_rmsnorm_fwd(x, g, ZERO_CENTERED_GAMMA, out_dtype=torch.float16, epsilo
     # cast to float32 as the triton kernel
     x_f32 = x.float()
     g_f32 = g.float()
-    rms = torch.sqrt(torch.sum(x_f32 * x_f32, dim=-1) * 1 / N)
-    rsigma = 1.0 / rms
+    mean_sq_x = torch.mean(x_f32 * x_f32, dim=-1)
+    rsigma = torch.rsqrt(mean_sq_x + epsilon)
     if (ZERO_CENTERED_GAMMA):
         g_f32 = g_f32 + 1
     rms_norm_f32 = x_f32 * rsigma.unsqueeze(1) * g_f32
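This rewrite is more than cosmetic: the old reference computed `rsigma = 1.0 / torch.sqrt(...)` without `epsilon`, while the Triton kernel receives `epsilon` and includes it. The new form folds the term in, so reference and kernel should now agree, in particular on small-magnitude rows:

```python
# RMSNorm reference as now written (epsilon inside the root, as in the kernel):
#   rsigma = 1 / sqrt(mean(x**2, dim=-1) + epsilon)
#   y      = x * rsigma[:, None] * g      # g + 1 when ZERO_CENTERED_GAMMA
rsigma = torch.rsqrt(torch.mean(x_f32 * x_f32, dim=-1) + epsilon)
```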
@@ -363,11 +387,14 @@ def torch_rmsnorm_fwd(x, g, ZERO_CENTERED_GAMMA, out_dtype=torch.float16, epsilo
 arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}
 
 
-#@pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"])
-#@pytest.mark.parametrize("out_dtype_str", ["fp32", "fp16", "bf16"])
+# FIXME: Some `fp32` test cases are failing in backward pass.
+# There are some fails related to `dx`, but the majority is related to `dg`.
+# @pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"])
+# @pytest.mark.parametrize("out_dtype_str", ["fp32", "fp16", "bf16"])
 @pytest.mark.parametrize("in_dtype_str", ["fp16", "bf16"])
 @pytest.mark.parametrize("out_dtype_str", ["fp16", "bf16"])
 @pytest.mark.parametrize('ZERO_CENTERED_GAMMA', [True, False])
+# yapf: disable
 @pytest.mark.parametrize('M, N', [
     (1, 4),
     (2, 10),
@@ -376,7 +403,20 @@ def torch_rmsnorm_fwd(x, g, ZERO_CENTERED_GAMMA, out_dtype=torch.float16, epsilo
     (1, 31744),
     (8192, 65536),
     (873, 1245),
+    # Shapes suggested by TE team:
+    (4096, 5120),
+    (8192, 8192),
+    # TE UT shapes:
+    (2048, 4096),
+    (768, 2048),
+    (256, 1024),
+    (128, 768),
+    (64, 512),
+    (173, 409),
+    (71, 3571),
+    (29, 17389),
 ])
+# yapf: enable
 def test_rmsnorm(M, N, ZERO_CENTERED_GAMMA, in_dtype_str, out_dtype_str):
     in_dtype = arg_to_torch_dtype[in_dtype_str]
     out_dtype = arg_to_torch_dtype[out_dtype_str]
@@ -389,16 +429,9 @@ def test_rmsnorm(M, N, ZERO_CENTERED_GAMMA, in_dtype_str, out_dtype_str):
 
     dx = torch.empty_like(x, dtype=in_dtype, requires_grad=False)
     dg = torch.empty_like(g, dtype=in_dtype, requires_grad=False)
-    dg_tmp = torch.zeros(M, N, device='cuda', dtype=torch.float32, requires_grad=False)
+    dg_tmp = torch.empty(dg_tmp_rows(x), N, device='cuda', dtype=torch.float32, requires_grad=False) if N > 1 else None
 
-    n_rows, n_cols = x.shape
-    MAX_FUSED_SIZE = 65536 // x.element_size()
-    blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
-    USE_BLOCKED = n_cols > blk_size
-    NUM_PRGMS = min(n_rows, get_num_sms())
-
-    y_triton = rmsnorm(x, g, y, rsigma, dx, dg, dg_tmp, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED,
-                       NUM_PRGMS)
+    y_triton = rmsnorm(x, g, y, rsigma, dx, dg, dg_tmp, ZERO_CENTERED_GAMMA)
 
     y_torch, rsigma_torch = torch_rmsnorm_fwd(x, g, ZERO_CENTERED_GAMMA, out_dtype)
 
@@ -438,11 +471,11 @@ def test_rmsnorm(M, N, ZERO_CENTERED_GAMMA, in_dtype_str, out_dtype_str):
 
     dx_b = torch.empty_like(x_triton, dtype=in_dtype, requires_grad=False)
     dg_b = torch.empty_like(g_triton, dtype=in_dtype, requires_grad=False)
-    dg_tmp_b = torch.zeros(M, N, device=x_triton.device, dtype=torch.float32, requires_grad=False)
+    dg_tmp_b = torch.empty(dg_tmp_rows(x_triton), N, device=x_triton.device, dtype=torch.float32,
+                           requires_grad=False) if N > 1 else None
 
     # Run Triton forward pass to build the graph for backward.
-    y_triton = rmsnorm(x_triton, g_triton, y_triton_buf, rsigma_triton, dx_b, dg_b, dg_tmp_b, n_rows, n_cols,
-                       ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED, NUM_PRGMS)
+    y_triton = rmsnorm(x_triton, g_triton, y_triton_buf, rsigma_triton, dx_b, dg_b, dg_tmp_b, ZERO_CENTERED_GAMMA)
     y_triton.backward(grad_output, retain_graph=True)
     grad_x_triton = x_triton.grad.to(out_dtype)
     grad_g_triton = g_triton.grad.to(out_dtype)
@@ -526,22 +559,16 @@ def benchmark(M, N, provider, model=None):
     rsigma = torch.empty((M, ), device='cuda', dtype=torch.float32)
     dx = torch.empty(M, N, device='cuda', dtype=dtype, requires_grad=False)
     dg = torch.empty((1, N), device='cuda', dtype=dtype, requires_grad=False)
-    dg_tmp = torch.zeros(M, N, device='cuda', dtype=torch.float32, requires_grad=False)
-    n_rows, n_cols = x.shape
-    # MAX_FUSED_SIZE = 65536 // x.element_size()
-    # blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
-    blk_size = 1024
-    USE_BLOCKED = n_cols > blk_size
-    NUM_PRGMS = min(n_rows, get_num_sms())
+    dg_tmp = torch.empty(dg_tmp_rows(x), N, device='cuda', dtype=torch.float32,
+                         requires_grad=False) if N > 1 else None
     stream = torch.cuda.Stream()
     torch.cuda.set_stream(stream)
     g = torch.ones((1, N), device='cuda')
     ZERO_CENTERED_GAMMA = False
 
     def rms_fwd():
         if provider == 'triton':
-            return rmsnorm(x, g, y, rsigma, dx, dg, dg_tmp, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size,
-                           USE_BLOCKED, NUM_PRGMS)
+            return rmsnorm(x, g, y, rsigma, dx, dg, dg_tmp, ZERO_CENTERED_GAMMA)
         if provider == 'torch':
             return torch_rmsnorm_fwd(x, g, ZERO_CENTERED_GAMMA)
 
@@ -555,20 +582,19 @@ def rms_fwd():
         y_ = torch.zeros_like(x_, dtype=dtype)
         rsigma_ = torch.empty((M, ), device='cuda', dtype=torch.float32)
         dx_ = torch.empty_like(x_, dtype=dtype)
-        dg_tmp_ = torch.empty_like(x_, dtype=torch.float32)
         dg_ = torch.empty_like(g_, dtype=dtype)
+        dg_tmp_ = torch.empty(dg_tmp_rows(x_), N, device='cuda', dtype=torch.float32) if N > 1 else None
         grad_out = torch.randn_like(y_)
 
-        y_out = rmsnorm(x_, g_, y_, rsigma_, dx_, dg_, dg_tmp_, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size,
-                        USE_BLOCKED, NUM_PRGMS)
+        y_out = rmsnorm(x_, g_, y_, rsigma_, dx_, dg_, dg_tmp_, ZERO_CENTERED_GAMMA)
 
         ms = triton.testing.do_bench(lambda: y_out.backward(grad_out, retain_graph=True), grad_to_none=[x_, g_])
     else:
         raise ValueError(f"mode {mode} is not supported!")
 
     global verbose
     if verbose:
-        print(f'SIZE: {N} Best tuning config: ({rms_kernel.best_config})')
+        print(f'SIZE: {N} Best forward tuning config: ({rms_kernel.best_config})')
         print(f'time: {ms}')
     gbps = lambda ms_val: 2 * x.nelement() * x.element_size() * 1e-9 / (ms_val * 1e-3)
     return gbps(ms)
@@ -599,8 +625,8 @@ def parse_args():
     parser.add_argument('-Ns', "--N_step", default="1024", type=int)
     parser.add_argument('-Ne', "--N_end", default="32768", type=int)
 
-    parser.add_argument('-d', "--dtype", default="fp16")
-    parser.add_argument('-nb', "--no_benchmark", default=False, type=bool)
+    parser.add_argument('-d', "--dtype", type=str, choices=list(arg_to_torch_dtype.keys()), default="fp16")
+    parser.add_argument('-nb', "--no_benchmark", action="store_true", default=False)
     parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
     parser.add_argument("--mode", type=str, choices=["fwd", "bwd"], default="fwd",
                         help="Benchmark mode: forward only, backward only, or both.")
@@ -611,21 +637,33 @@ def parse_args():
 def main():
     args = parse_args()
     global verbose
+
     if args.no_benchmark:
-        x = torch.randn(args.M_start, args.N_start, device='cuda', dtype=args.dtype)
-        y = torch.zeros_like(x, device='cuda')
-        rsigma = torch.empty((args.M_start, ), device='cuda', dtype=torch.float32)
-        dx = torch.empty(args.M_start, args.N_start, device='cuda', dtype=args.dtype, requires_grad=False)
-        dg = torch.empty((1, args.N_start), device='cuda', dtype=args.dtype, requires_grad=False)
-        dg_tmp = torch.zeros(args.M_start, args.N_start, device='cuda', dtype=torch.float32, requires_grad=False)
-        n_rows, n_cols = x.shape
-        MAX_FUSED_SIZE = 65536 // x.element_size()
-        blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
-        USE_BLOCKED = n_cols > blk_size
-        NUM_PRGMS = min(n_rows, get_num_sms())
-        g = torch.ones((1, args.N_start), device='cuda', dtype=args.dtype)
-        ZERO_CENTERED_GAMMA = True
-        rmsnorm(x, y, g, rsigma, dx, dg, dg_tmp, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size, USE_BLOCKED, NUM_PRGMS)
+        in_dtype_str = out_dtype_str = args.dtype
+        M, N = args.M_start, args.N_start
+        ZERO_CENTERED_GAMMA = False
+
+        # Run kernel as done in test:
+        in_dtype = arg_to_torch_dtype[in_dtype_str]
+        out_dtype = arg_to_torch_dtype[out_dtype_str]
+        torch.manual_seed(0)
+
+        x = torch.randn(M, N, device='cuda', dtype=in_dtype, requires_grad=True)
+        g = torch.ones((1, N), device='cuda', dtype=in_dtype, requires_grad=True)
+        y = torch.zeros_like(x, device='cuda', dtype=out_dtype)
+        rsigma = torch.empty((M, ), device=x.device, dtype=torch.float32)
+
+        dx = torch.empty_like(x, dtype=in_dtype, requires_grad=False)
+        dg = torch.empty_like(g, dtype=in_dtype, requires_grad=False)
+        dg_tmp = torch.empty(dg_tmp_rows(x), N, device='cuda', dtype=torch.float32,
+                             requires_grad=False) if N > 1 else None
+
+        y_triton = rmsnorm(x, g, y, rsigma, dx, dg, dg_tmp, ZERO_CENTERED_GAMMA)
+
+        if args.mode == "bwd":
+            grad_output = torch.randn_like(y_triton)
+            y_triton.backward(grad_output, retain_graph=True)
+
     else:
         verbose = args.v
         run_benchmark(args)
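With the launcher fixed, the standalone path mirrors the unit test: allocate buffers via the shared helpers, run the forward, and optionally the backward. Note the old call also passed `y` and `g` swapped relative to the wrapper's `(x, g, y, ...)` order and used the stale long signature, which is consistent with the reported segfault. Assuming the long-form flags implied by `args.M_start`/`args.N_start` (they may differ), an invocation would look like `python rmsnorm.py --no_benchmark --M_start 128 --N_start 768 --dtype fp16 --mode bwd`.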
