@@ -153,6 +153,75 @@ def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], o
153
153
_launcher(_reduce_kernel_kernel, (triton.cdiv(n, _BLOCK_SIZE_0),), x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), _m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
154
154
return out
155
155
156
+ --- assertExpectedJournal(TestReductions.test_reduction_loops_integer_values)
157
+ from __future__ import annotations
158
+
159
+ import torch
160
+ import triton
161
+ import triton.language as tl
162
+ from torch._inductor.runtime.triton_compat import libdevice
163
+ from helion.runtime import default_launcher as _default_launcher
164
+
165
+ @triton.jit
166
+ def _layer_norm_reduction_kernel(bias, x, weight, out, bias_size_0, bias_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, eps, _BLOCK_SIZE_0: tl.constexpr, _REDUCTION_BLOCK_1: tl.constexpr):
167
+ pid_0 = tl.program_id(0)
168
+ offset_0 = pid_0 * _BLOCK_SIZE_0
169
+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
170
+ mask_0 = indices_0 < m
171
+ var_mean_extra_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
172
+ for roffset_1 in tl.range(0, bias_size_0, _REDUCTION_BLOCK_1):
173
+ rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
174
+ mask_1 = rindex_1 < bias_size_0
175
+ load = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
176
+ v_0 = load.to(tl.float32)
177
+ v_1 = var_mean_extra_acc + v_0
178
+ var_mean_extra_acc = v_1
179
+ var_mean_extra = tl.reshape(tl.sum(var_mean_extra_acc, 1), [_BLOCK_SIZE_0, 1])
180
+ v_2 = var_mean_extra / bias_size_0.to(tl.float32)
181
+ _mask_to_1 = tl.where(tl.broadcast_to(mask_0[:, None], [_BLOCK_SIZE_0, 1]), v_2, 0)
182
+ var_mean_extra_2_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
183
+ for roffset_1 in tl.range(0, bias_size_0, _REDUCTION_BLOCK_1):
184
+ rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
185
+ mask_1 = rindex_1 < bias_size_0
186
+ _mask_to_1_copy = _mask_to_1
187
+ load_1 = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
188
+ v_3 = load_1.to(tl.float32)
189
+ v_4 = v_3 - _mask_to_1_copy
190
+ v_5 = v_4 * v_4
191
+ v_6 = var_mean_extra_2_acc + v_5
192
+ var_mean_extra_2_acc = v_6
193
+ var_mean_extra_2 = tl.reshape(tl.sum(var_mean_extra_2_acc, 1), [_BLOCK_SIZE_0, 1])
194
+ v_7 = var_mean_extra_2 / bias_size_0.to(tl.float32)
195
+ v_8 = v_7 + eps
196
+ v_9 = libdevice.rsqrt(v_8)
197
+ for roffset_1 in tl.range(0, bias_size_0, _REDUCTION_BLOCK_1):
198
+ rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
199
+ mask_1 = rindex_1 < bias_size_0
200
+ v_2_copy = v_2
201
+ v_9_copy = v_9
202
+ load_2 = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
203
+ v_10 = load_2.to(tl.float32)
204
+ v_11 = v_10 - v_2_copy
205
+ v_12 = v_11 * v_9_copy
206
+ load_3 = tl.load(weight + rindex_1 * weight_stride_0, mask_1, other=0)
207
+ v_13 = load_3.to(tl.float32)
208
+ v_14 = v_13[None, :]
209
+ v_15 = v_12 * v_14
210
+ load_4 = tl.load(bias + rindex_1 * bias_stride_0, mask_1, other=0)
211
+ v_16 = load_4.to(tl.float32)
212
+ v_17 = v_16[None, :]
213
+ v_18 = v_15 + v_17
214
+ v_19 = v_18.to(tl.float16)
215
+ tl.store(out + (indices_0[:, None] * out_stride_0 + rindex_1[None, :] * out_stride_1), v_19, mask_0[:, None] & mask_1[None, :])
216
+
217
+ def layer_norm_reduction(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
218
+ m, n = x.size()
219
+ out = torch.empty([m, n], dtype=torch.float16, device=x.device)
220
+ _BLOCK_SIZE_0 = 32
221
+ _REDUCTION_BLOCK_1 = 4
222
+ _launcher(_layer_norm_reduction_kernel, (triton.cdiv(m, _BLOCK_SIZE_0),), bias, x, weight, out, bias.size(0), bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, eps, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
223
+ return out
224
+
156
225
--- assertExpectedJournal(TestReductions.test_sum)
157
226
from __future__ import annotations
158
227
0 commit comments