@@ -62,6 +62,128 @@ def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], o
     _launcher(_reduce_kernel_kernel, (n,), x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), _m, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
     return out

+--- assertExpectedJournal(TestReductions.test_fp16_var_mean)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _layer_norm_fwd_repro_kernel(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    x_part = tl.load(x + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
+    v_0 = x_part.to(tl.float32)
+    var_mean_extra = tl.reshape(tl.sum(v_0, 1), [_BLOCK_SIZE_0, 1])
+    v_1 = 64
+    v_2 = var_mean_extra / v_1.to(tl.float32)
+    v_3 = x_part.to(tl.float32)
+    v_4 = v_3 - v_2
+    v_5 = v_4 * v_4
+    var_mean_extra_2 = tl.reshape(tl.sum(v_5, 1), [_BLOCK_SIZE_0, 1])
+    v_6 = 64
+    v_7 = var_mean_extra_2 / v_6.to(tl.float32)
+    v_8 = v_7.to(tl.float16)
+    v_9 = v_2.to(tl.float16)
+    v_10 = x_part - v_9
+    v_11 = v_8.to(tl.float32)
+    v_12 = v_11 + eps
+    v_13 = libdevice.rsqrt(v_12)
+    v_14 = v_10.to(tl.float32)
+    v_15 = v_14 * v_13
+    load_1 = tl.load(weight + indices_1 * 1, None)
+    v_16 = load_1.to(tl.float32)
+    v_17 = v_16[None, :]
+    v_18 = v_15 * v_17
+    load_2 = tl.load(bias + indices_1 * 1, None)
+    v_19 = load_2.to(tl.float32)
+    v_20 = v_19[None, :]
+    v_21 = v_18 + v_20
+    v_22 = v_21.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), v_22, None)
+
+def layer_norm_fwd_repro(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+    m, n = x.size()
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = 64
+    _launcher(_layer_norm_fwd_repro_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestReductions.test_fp16_var_mean)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _layer_norm_fwd_repro_kernel(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.constexpr, _REDUCTION_BLOCK_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    var_mean_extra_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
+    for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1):
+        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
+        x_part = tl.load(x + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), None)
+        v_0 = x_part.to(tl.float32)
+        v_1 = var_mean_extra_acc + v_0
+        var_mean_extra_acc = v_1
+    var_mean_extra = tl.reshape(tl.sum(var_mean_extra_acc, 1), [_BLOCK_SIZE_0, 1])
+    v_2 = 64
+    v_3 = var_mean_extra / v_2.to(tl.float32)
+    var_mean_extra_2_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
+    for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1):
+        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
+        v_3_copy = v_3
+        x_part_1 = tl.load(x + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), None)
+        v_4 = x_part_1.to(tl.float32)
+        v_5 = v_4 - v_3_copy
+        v_6 = v_5 * v_5
+        v_7 = var_mean_extra_2_acc + v_6
+        var_mean_extra_2_acc = v_7
+    var_mean_extra_2 = tl.reshape(tl.sum(var_mean_extra_2_acc, 1), [_BLOCK_SIZE_0, 1])
+    v_8 = 64
+    v_9 = var_mean_extra_2 / v_8.to(tl.float32)
+    v_10 = v_9.to(tl.float16)
+    v_11 = v_3.to(tl.float16)
+    v_12 = v_10.to(tl.float32)
+    v_13 = v_12 + eps
+    v_14 = libdevice.rsqrt(v_13)
+    for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1):
+        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
+        v_11_copy = v_11
+        v_14_copy = v_14
+        x_part_2 = tl.load(x + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), None)
+        v_15 = x_part_2 - v_11_copy
+        v_16 = v_15.to(tl.float32)
+        v_17 = v_16 * v_14_copy
+        load_1 = tl.load(weight + rindex_1 * 1, None)
+        v_18 = load_1.to(tl.float32)
+        v_19 = v_18[None, :]
+        v_20 = v_17 * v_19
+        load_2 = tl.load(bias + rindex_1 * 1, None)
+        v_21 = load_2.to(tl.float32)
+        v_22 = v_21[None, :]
+        v_23 = v_20 + v_22
+        v_24 = v_23.to(tl.float16)
+        tl.store(out + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), v_24, None)
+
+def layer_norm_fwd_repro(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+    m, n = x.size()
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _REDUCTION_BLOCK_1 = 8
+    _launcher(_layer_norm_fwd_repro_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestReductions.test_mean)
 def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], out_dtype=torch.float32):
     # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_reductions.py:48>)