Commit 27b2774

Remerge LayerNorm (#348) (#373)
1 parent 3288b24 commit 27b2774

4 files changed: +138 −0 lines changed

4 files changed

+138
-0
lines changed

benchmarks/run.py

Lines changed: 5 additions & 0 deletions
@@ -75,6 +75,11 @@
         "examples.fp8_attention",
         "fp8_attention_tritonbench",
     ),
+    "layer_norm": (
+        "tritonbench.operators.layer_norm.operator",
+        "examples.layer_norm",
+        "layer_norm_fwd",
+    ),
 }

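Each entry in this benchmark mapping pairs a TritonBench operator module with a Helion example module and the kernel function to benchmark from it. The sketch below is purely illustrative (resolve_mapping and its return shape are hypothetical, not code from benchmarks/run.py); it only shows how such an (operator_module, example_module, kernel_name) triple could be resolved with importlib:

import importlib

# Hypothetical helper: illustrates resolving one mapping entry.
# The actual logic in benchmarks/run.py may differ.
def resolve_mapping(entry: tuple[str, str, str]):
    operator_module, example_module, kernel_name = entry
    example = importlib.import_module(example_module)  # e.g. examples.layer_norm
    kernel = getattr(example, kernel_name)              # e.g. layer_norm_fwd
    return operator_module, kernel

op_module, kernel = resolve_mapping(
    ("tritonbench.operators.layer_norm.operator", "examples.layer_norm", "layer_norm_fwd")
)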
examples/layer_norm.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl


@helion.kernel
def layer_norm_fwd(
    x: torch.Tensor,
    normalized_shape: list[int],
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    m, n = x.size()
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
    assert len(normalized_shape) == 1, (
        "Helion layer norm only supports 1D layer norm currently"
    )
    assert normalized_shape[0] == n, (
        f"normalized shape mismatch {normalized_shape[0]} != {n}"
    )

    out = torch.empty([m, n], dtype=torch.float16, device=x.device)

    for tile_m in hl.tile(m):
        acc = x[tile_m, :].to(torch.float32)

        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)

        normalized = (acc - mean) * torch.rsqrt(var + eps)
        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))

        out[tile_m, :] = acc
    return out


def main() -> None:
    batch_size = 32
    dim = 64
    device = "cuda"

    x = torch.randn([batch_size, dim], device=device, dtype=torch.float16)
    weight = torch.randn([dim], device=device, dtype=torch.float16)
    bias = torch.randn([dim], device=device, dtype=torch.float16)
    eps = 1e-4

    run_example(
        layer_norm_fwd,
        torch.nn.functional.layer_norm,
        (x, [dim], weight, bias, eps),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-3,
        atol=1e-3,
    )


if __name__ == "__main__":
    main()

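For reference, the per-row computation inside the tile loop is the standard layer norm, evaluated in fp32 with a biased variance (correction=0) and cast back to fp16 on store:

\mu_i = \frac{1}{n} \sum_{j=1}^{n} x_{ij}, \qquad
\sigma_i^2 = \frac{1}{n} \sum_{j=1}^{n} \left(x_{ij} - \mu_i\right)^2, \qquad
y_{ij} = \frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} \, w_j + b_j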
test/test_examples.expected

Lines changed: 53 additions & 0 deletions
@@ -883,6 +883,59 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_
     _launcher(_jagged_mean_kernel_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_layernorm)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _layer_norm_fwd_kernel(bias, x, weight, out, bias_size_0, bias_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    mask_1 = indices_1 < bias_size_0
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = load.to(tl.float32)
+    var_mean_extra = tl.reshape(tl.sum(v_0, 1), [_BLOCK_SIZE_0, 1])
+    v_1 = var_mean_extra / bias_size_0.to(tl.float32)
+    _mask_to_1 = tl.where(tl.broadcast_to(mask_0[:, None], [_BLOCK_SIZE_0, 1]), v_1, 0)
+    v_2 = v_0 - _mask_to_1
+    v_3 = v_2 * v_2
+    var_mean_extra_2 = tl.reshape(tl.sum(v_3, 1), [_BLOCK_SIZE_0, 1])
+    v_4 = var_mean_extra_2 / bias_size_0.to(tl.float32)
+    v_5 = v_0 - v_1
+    v_6 = v_4 + eps
+    v_7 = libdevice.rsqrt(v_6)
+    v_8 = v_5 * v_7
+    load_1 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0)
+    v_9 = load_1.to(tl.float32)
+    v_10 = v_9[None, :]
+    v_11 = v_8 * v_10
+    load_2 = tl.load(bias + indices_1 * bias_stride_0, mask_1, other=0)
+    v_12 = load_2.to(tl.float32)
+    v_13 = v_12[None, :]
+    v_14 = v_11 + v_13
+    v_15 = v_14.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_15, mask_0[:, None] & mask_1[None, :])
+
+def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+    m, n = x.size()
+    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
+    assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {n}'
+    assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
+    assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = triton.next_power_of_2(bias.size(0))
+    _launcher(_layer_norm_fwd_kernel, (triton.cdiv(m, _BLOCK_SIZE_0),), bias, x, weight, out, bias.size(0), bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_matmul)
 from __future__ import annotations
 

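Note that the generated kernel computes the variance in two reductions, a row sum for the mean followed by a sum of squared deviations, rather than a single fused var_mean. A small standalone check (plain PyTorch, independent of Helion or Triton) confirms this two-pass formulation matches torch.var_mean with correction=0, as used in the Helion source:

import torch

x = torch.randn(32, 64, dtype=torch.float32)

# Two-pass formulation mirroring the generated Triton code above.
mean = x.sum(dim=-1, keepdim=True) / x.size(-1)
var = ((x - mean) * (x - mean)).sum(dim=-1, keepdim=True) / x.size(-1)

# Fused reference with biased variance (correction=0).
var_ref, mean_ref = torch.var_mean(x, dim=-1, keepdim=True, correction=0)

torch.testing.assert_close(mean, mean_ref)
torch.testing.assert_close(var, var_ref)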
test/test_examples.py

Lines changed: 16 additions & 0 deletions
@@ -601,6 +601,22 @@ def test_fp8_attention(self):
             )
         )
 
+    def test_layernorm(self):
+        x = torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
+        weight = torch.randn([64], device=DEVICE, dtype=torch.float16)
+        bias = torch.randn([64], device=DEVICE, dtype=torch.float16)
+
+        args = (x, [64], weight, bias)
+
+        self.assertExpectedJournal(
+            check_example(
+                "layer_norm",
+                args,
+                torch.nn.functional.layer_norm(*args),
+                fn_name="layer_norm_fwd",
+            )
+        )
+
 
 if __name__ == "__main__":
     unittest.main()

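The example's run_example call pits the Helion kernel against the torch.nn.functional.layer_norm baseline at rtol=atol=1e-3. A minimal standalone version of that comparison, a sketch that assumes a CUDA device, the examples package on the import path, and that run_example boils down to an assert_close-style check (it may also benchmark), could look like:

import torch

from examples.layer_norm import layer_norm_fwd

device = "cuda"
x = torch.randn([32, 64], device=device, dtype=torch.float16)
weight = torch.randn([64], device=device, dtype=torch.float16)
bias = torch.randn([64], device=device, dtype=torch.float16)

# Helion kernel vs. the PyTorch baseline, using the tolerances from the example.
out = layer_norm_fwd(x, [64], weight, bias)
ref = torch.nn.functional.layer_norm(x, [64], weight, bias)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)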