@@ -883,6 +883,59 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_
     _launcher(_jagged_mean_kernel_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out

+--- assertExpectedJournal(TestExamples.test_layernorm)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _layer_norm_fwd_kernel(bias, x, weight, out, bias_size_0, bias_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    mask_1 = indices_1 < bias_size_0
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = load.to(tl.float32)
+    var_mean_extra = tl.reshape(tl.sum(v_0, 1), [_BLOCK_SIZE_0, 1])
+    v_1 = var_mean_extra / bias_size_0.to(tl.float32)
+    _mask_to_1 = tl.where(tl.broadcast_to(mask_0[:, None], [_BLOCK_SIZE_0, 1]), v_1, 0)
+    v_2 = v_0 - _mask_to_1
+    v_3 = v_2 * v_2
+    var_mean_extra_2 = tl.reshape(tl.sum(v_3, 1), [_BLOCK_SIZE_0, 1])
+    v_4 = var_mean_extra_2 / bias_size_0.to(tl.float32)
+    v_5 = v_0 - v_1
+    v_6 = v_4 + eps
+    v_7 = libdevice.rsqrt(v_6)
+    v_8 = v_5 * v_7
+    load_1 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0)
+    v_9 = load_1.to(tl.float32)
+    v_10 = v_9[None, :]
+    v_11 = v_8 * v_10
+    load_2 = tl.load(bias + indices_1 * bias_stride_0, mask_1, other=0)
+    v_12 = load_2.to(tl.float32)
+    v_13 = v_12[None, :]
+    v_14 = v_11 + v_13
+    v_15 = v_14.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_15, mask_0[:, None] & mask_1[None, :])
+
+def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+    m, n = x.size()
+    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
+    assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {n}'
+    assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
+    assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = triton.next_power_of_2(bias.size(0))
+    _launcher(_layer_norm_fwd_kernel, (triton.cdiv(m, _BLOCK_SIZE_0),), bias, x, weight, out, bias.size(0), bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
--- assertExpectedJournal(TestExamples.test_matmul)
from __future__ import annotations
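For orientation, the generated `layer_norm_fwd` wrapper above implements standard row-wise layer norm, `y = (x - mean(x)) * rsqrt(var(x) + eps) * weight + bias`, accumulating the reduction in fp32 and storing the result as fp16. A minimal smoke test might look like the sketch below; it assumes a CUDA device and that `layer_norm_fwd` from the journal entry is importable, and the shapes and tolerances are purely illustrative:

```python
import torch

# Illustrative inputs; any (m, n) works as long as weight/bias have length n.
m, n = 64, 128
x = torch.randn(m, n, device="cuda", dtype=torch.float16)
weight = torch.randn(n, device="cuda", dtype=torch.float16)
bias = torch.randn(n, device="cuda", dtype=torch.float16)

# Call the generated wrapper from the journal entry above (assumed importable).
out = layer_norm_fwd(x, [n], weight, bias)

# Check against the eager reference; the kernel accumulates in fp32 but
# stores fp16, so a loose tolerance is appropriate.
ref = torch.nn.functional.layer_norm(x, [n], weight, bias, eps=1e-05)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```

Note that `_RDIM_SIZE_1` is rounded up to the next power of two via `triton.next_power_of_2`, and the kernel masks the feature dimension with `mask_1`, so `n` itself does not need to be a power of two.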