Commit d593642

Fix issue with integer in rolled reduction (#354)
Fixes #345
1 parent 975959b commit d593642
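
In short: when a rolled (looped) reduction is configured with a plain integer in reduction_loops, that integer reaches the _get_symnode codegen handler as a regular Python int rather than a torch.SymInt, and codegen previously tripped its symbolic-type assert. A minimal sketch of the pre-fix failure, reusing the assert visible in the diff below:

    import torch

    val = 4  # a reduction_loops entry supplied as a plain Python int
    # Before this commit, codegen went straight to this check and raised:
    assert isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)), val
    # AssertionError: 4

The fix adds an early return that emits such integers as literal expressions in the generated code.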

3 files changed: +127 -0 lines changed

helion/language/_tracing_ops.py

Lines changed: 5 additions & 0 deletions
@@ -42,6 +42,11 @@ def _get_symnode(debug_name: str) -> int:
 @_decorators.codegen(_get_symnode)
 def _(state: CodegenState) -> ast.AST:
     val = state.fx_node.meta["val"]  # pyright: ignore[reportOptionalMemberAccess]
+
+    # Handle the case where val is a regular integer (e.g., from reduction_loops config)
+    if isinstance(val, int):
+        return expr_from_string(str(val))
+
     assert isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)), val
     if (block_idx := CompileEnvironment.current().get_block_id(val)) is not None:  # pyright: ignore[reportArgumentType]
         block_size_var = state.device_function.block_size_var(block_idx)
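
What the new branch does: for a plain integer, expr_from_string(str(val)) turns the number into an expression node that is spliced into the generated code, instead of resolving a block-size variable for a symbolic value. A rough stand-in, assuming expr_from_string parses a Python expression string into an ast.AST (its actual implementation is not part of this diff):

    import ast

    def expr_from_string_sketch(source: str) -> ast.AST:
        # Assumed behavior: parse an expression string into an AST expression node.
        return ast.parse(source, mode="eval").body

    node = expr_from_string_sketch(str(4))
    print(ast.dump(node))  # Constant(value=4)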

test/test_reductions.expected

Lines changed: 69 additions & 0 deletions
@@ -153,6 +153,75 @@ def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], o
     _launcher(_reduce_kernel_kernel, (triton.cdiv(n, _BLOCK_SIZE_0),), x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), _m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestReductions.test_reduction_loops_integer_values)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _layer_norm_reduction_kernel(bias, x, weight, out, bias_size_0, bias_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, eps, _BLOCK_SIZE_0: tl.constexpr, _REDUCTION_BLOCK_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    var_mean_extra_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
+    for roffset_1 in tl.range(0, bias_size_0, _REDUCTION_BLOCK_1):
+        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
+        mask_1 = rindex_1 < bias_size_0
+        load = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        v_0 = load.to(tl.float32)
+        v_1 = var_mean_extra_acc + v_0
+        var_mean_extra_acc = v_1
+    var_mean_extra = tl.reshape(tl.sum(var_mean_extra_acc, 1), [_BLOCK_SIZE_0, 1])
+    v_2 = var_mean_extra / bias_size_0.to(tl.float32)
+    _mask_to_1 = tl.where(tl.broadcast_to(mask_0[:, None], [_BLOCK_SIZE_0, 1]), v_2, 0)
+    var_mean_extra_2_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
+    for roffset_1 in tl.range(0, bias_size_0, _REDUCTION_BLOCK_1):
+        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
+        mask_1 = rindex_1 < bias_size_0
+        _mask_to_1_copy = _mask_to_1
+        load_1 = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        v_3 = load_1.to(tl.float32)
+        v_4 = v_3 - _mask_to_1_copy
+        v_5 = v_4 * v_4
+        v_6 = var_mean_extra_2_acc + v_5
+        var_mean_extra_2_acc = v_6
+    var_mean_extra_2 = tl.reshape(tl.sum(var_mean_extra_2_acc, 1), [_BLOCK_SIZE_0, 1])
+    v_7 = var_mean_extra_2 / bias_size_0.to(tl.float32)
+    v_8 = v_7 + eps
+    v_9 = libdevice.rsqrt(v_8)
+    for roffset_1 in tl.range(0, bias_size_0, _REDUCTION_BLOCK_1):
+        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
+        mask_1 = rindex_1 < bias_size_0
+        v_2_copy = v_2
+        v_9_copy = v_9
+        load_2 = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        v_10 = load_2.to(tl.float32)
+        v_11 = v_10 - v_2_copy
+        v_12 = v_11 * v_9_copy
+        load_3 = tl.load(weight + rindex_1 * weight_stride_0, mask_1, other=0)
+        v_13 = load_3.to(tl.float32)
+        v_14 = v_13[None, :]
+        v_15 = v_12 * v_14
+        load_4 = tl.load(bias + rindex_1 * bias_stride_0, mask_1, other=0)
+        v_16 = load_4.to(tl.float32)
+        v_17 = v_16[None, :]
+        v_18 = v_15 + v_17
+        v_19 = v_18.to(tl.float16)
+        tl.store(out + (indices_0[:, None] * out_stride_0 + rindex_1[None, :] * out_stride_1), v_19, mask_0[:, None] & mask_1[None, :])
+
+def layer_norm_reduction(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+    m, n = x.size()
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _REDUCTION_BLOCK_1 = 4
+    _launcher(_layer_norm_reduction_kernel, (triton.cdiv(m, _BLOCK_SIZE_0),), bias, x, weight, out, bias.size(0), bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, eps, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestReductions.test_sum)
 from __future__ import annotations
 
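
In the generated host wrapper above, the integer configuration value surfaces as _REDUCTION_BLOCK_1 = 4 and is forwarded to the Triton kernel as the tl.constexpr step of its reduction loops. A hypothetical call into that wrapper, assuming a CUDA device and the same shapes the test below uses:

    import torch

    # Hypothetical usage of the generated layer_norm_reduction wrapper shown above.
    x = torch.randn([32, 64], device="cuda", dtype=torch.float16)
    weight = torch.randn([64], device="cuda", dtype=torch.float16)
    bias = torch.randn([64], device="cuda", dtype=torch.float16)
    out = layer_norm_reduction(x, weight, bias, eps=1e-4)  # float16, shape [32, 64]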

test/test_reductions.py

Lines changed: 53 additions & 0 deletions
@@ -135,6 +135,59 @@ def test_argmin_argmax_looped(self):
         torch.testing.assert_close(output, args[1](args[0], dim=-1))
         self.assertExpectedJournal(code)
 
+    def test_reduction_loops_integer_values(self):
+        """Test that reduction_loops with integer values works (issue #345 fix)."""
+
+        @helion.kernel(use_default_config=True)
+        def layer_norm_reduction(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            bias: torch.Tensor,
+            eps: float = 1e-5,
+        ) -> torch.Tensor:
+            m, n = x.size()
+            out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+
+            for tile_m in hl.tile(m):
+                acc = x[tile_m, :].to(torch.float32)
+                var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
+                normalized = (acc - mean) * torch.rsqrt(var + eps)
+                result = normalized * (weight[:].to(torch.float32)) + (
+                    bias[:].to(torch.float32)
+                )
+                out[tile_m, :] = result
+            return out
+
+        x = torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
+        weight = torch.randn([64], device=DEVICE, dtype=torch.float16)
+        bias = torch.randn([64], device=DEVICE, dtype=torch.float16)
+        eps = 1e-4
+
+        args = (x, weight, bias, eps)
+
+        # Test various reduction_loops configurations that previously failed
+        for reduction_loop_value in [2, 4, 8]:
+            with self.subTest(reduction_loop=reduction_loop_value):
+                code, output = code_and_output(
+                    layer_norm_reduction,
+                    args,
+                    block_size=32,
+                    reduction_loop=reduction_loop_value,
+                )
+
+                # Compute expected result using PyTorch's layer_norm
+                expected = torch.nn.functional.layer_norm(
+                    x.float(), [64], weight.float(), bias.float(), eps
+                ).half()
+
+                torch.testing.assert_close(output, expected, rtol=1e-2, atol=1e-2)
+
+        # Only check the generated code for one configuration to avoid redundant expected outputs
+        code, _ = code_and_output(
+            layer_norm_reduction, args, block_size=32, reduction_loop=4
+        )
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()
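
Outside the test helper, the same settings would be pinned on the kernel itself; a hedged sketch, assuming the helion.Config(block_sizes=..., reduction_loops=...) spelling (the test reaches the equivalent configuration through the code_and_output helper instead):

    import helion

    # Hypothetical explicit config; helion.Config and its exact field names are
    # assumptions, not shown in this diff.
    config = helion.Config(block_sizes=[32], reduction_loops=[4])

    @helion.kernel(config=config)
    def layer_norm_reduction(x, weight, bias, eps=1e-5):
        ...  # same body as in the test above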
