@helion.kernel
def layer_norm_fwd(
    x: torch.Tensor,
    normalized_shape: list[int] | tuple[int, ...],
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    """
    Performs 1D layer normalization on the input tensor using Helion.

    Args:
        x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
        normalized_shape (list[int] | tuple[int, ...]): List or tuple containing the
            dimension to normalize over (must be length 1 and equal to dim).
        weight (torch.Tensor): Learnable scale parameter of shape [dim].
        bias (torch.Tensor): Learnable bias parameter of shape [dim].
        eps (float, optional): Small value added to variance for numerical stability.
            Default is 1e-5.

    Returns:
        torch.Tensor: Normalized output of shape [batch_size, dim], in x's dtype.
    """
    m, n = x.size()
    # NOTE: messages report n (the compared value), not m.
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
    assert len(normalized_shape) == 1, (
        "Helion layer norm only supports 1D layer norm currently"
    )
    assert normalized_shape[0] == n, (
        f"normalized shape mismatch {normalized_shape[0]} != {n}"
    )
    # Output keeps the input dtype (rather than hard-coding float16) so the
    # kernel also works for bf16/fp32 inputs.
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m in hl.tile(m):
        # Accumulate in fp32 for numerical accuracy regardless of input dtype.
        acc = x[tile_m, :].to(torch.float32)
        # Compute mean and variance separately (two-pass) for better numerical
        # stability than a fused var_mean with correction=0.
        mean = torch.mean(acc, dim=-1, keepdim=True)
        var = torch.mean((acc - mean) ** 2, dim=-1, keepdim=True)
        normalized = (acc - mean) * torch.rsqrt(var + eps)
        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
        # Cast back from the fp32 accumulator to the output dtype on store.
        out[tile_m, :] = acc.to(x.dtype)
    return out