Skip to content

Commit 4faf678

Browse files
committed
Quack layer_norm + bias integration
1 parent 9e5d80e commit 4faf678

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

examples/layer_norm.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
@helion.kernel
2020
def layer_norm_fwd(
2121
x: torch.Tensor,
22-
nomralized_shape: list[int],
22+
normalized_shape: list[int] | tuple[int, ...],
2323
weight: torch.Tensor,
2424
bias: torch.Tensor,
2525
eps: float = 1e-5,
@@ -28,7 +28,7 @@ def layer_norm_fwd(
2828
Performs 1D layer normalization on the input tensor using Helion.
2929
Args:
3030
x (torch.Tensor): Input tensor of shape [batch_size, dim]. Computation is done in FP32 and the output is returned in the dtype of x.
31-
nomralized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
31+
normalized_shape (list[int] | tuple[int, ...]): List or tuple containing the dimension to normalize over (should be length 1).
3232
weight (torch.Tensor): Learnable scale parameter of shape [dim].
3333
bias (torch.Tensor): Learnable bias parameter of shape [dim].
3434
eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
@@ -38,19 +38,21 @@ def layer_norm_fwd(
3838
m, n = x.size()
3939
assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {m}"
4040
assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
41-
assert len(nomralized_shape) == 1, (
41+
assert len(normalized_shape) == 1, (
4242
"Helion layer norm only supports 1D layer norm currently"
4343
)
44-
assert nomralized_shape[0] == n, (
45-
f"normalized shape mismatch {nomralized_shape[0]} != {n}"
44+
assert normalized_shape[0] == n, (
45+
f"normalized shape mismatch {normalized_shape[0]} != {n}"
4646
)
47-
out = torch.empty([m, n], dtype=torch.float16, device=x.device)
47+
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
4848
for tile_m in hl.tile(m):
4949
acc = x[tile_m, :].to(torch.float32)
50-
var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
50+
# Compute mean and variance separately for better numerical stability
51+
mean = torch.mean(acc, dim=-1, keepdim=True)
52+
var = torch.mean((acc - mean) ** 2, dim=-1, keepdim=True)
5153
normalized = (acc - mean) * torch.rsqrt(var + eps)
5254
acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
53-
out[tile_m, :] = acc
55+
out[tile_m, :] = acc.to(x.dtype)
5456
return out
5557

5658

0 commit comments

Comments
 (0)