Skip to content

Commit 5840834

Browse files
authored
Do not create a new variable for tile assignments since tiles are immutable (#334)
1 parent 6db2105 commit 5840834

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

helion/_compiler/device_ir.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,9 @@ def visit_Assign(self, node: ast.Assign) -> None:
748748
value = self.visit(node.value)
749749
# For simple variable assignments like `a = b`, we need to create a new
750750
# variable to avoid phi node issues when the source variable gets mutated
751-
if isinstance(node.value, ast.Name) and isinstance(value, torch.Tensor):
751+
if isinstance(node.value, ast.Name) and (
752+
isinstance(value, torch.Tensor) and not isinstance(value, Tile)
753+
):
752754
value = _new_var(value)
753755
self._assign(target, value)
754756
return None

test/test_misc.expected

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,29 @@ def kernel(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass, *, _launcher=_de
4343
_launcher(_kernel_kernel, (triton.cdiv(a0.size(0), _BLOCK_SIZE_0),), a0, o0, o1, a0.size(0), a0.stride(0), o0.stride(0), o1.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
4444
return [o0, o1]
4545

46+
--- assertExpectedJournal(TestMisc.test_propagate_tile)
47+
from __future__ import annotations
48+
49+
import torch
50+
import triton
51+
import triton.language as tl
52+
from helion.runtime import default_launcher as _default_launcher
53+
54+
@triton.jit
55+
def _copy_kernel_kernel(a, out, a_size_0, a_stride_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
56+
pid_0 = tl.program_id(0)
57+
offset_0 = pid_0 * _BLOCK_SIZE_0
58+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
59+
mask_0 = indices_0 < a_size_0
60+
load = tl.load(a + indices_0 * a_stride_0, mask_0, other=0)
61+
tl.store(out + indices_0 * out_stride_0, load, mask_0)
62+
63+
def copy_kernel(a: torch.Tensor, *, _launcher=_default_launcher):
64+
out = torch.empty_like(a)
65+
_BLOCK_SIZE_0 = 4
66+
_launcher(_copy_kernel_kernel, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, out, a.size(0), a.stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
67+
return out
68+
4669
--- assertExpectedJournal(TestMisc.test_scalar_tensor_item_method)
4770
from __future__ import annotations
4871

test/test_misc.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,22 @@ def tuple_unpack_kernel(inp_tuple) -> torch.Tensor:
390390

391391
self.assertExpectedJournal(code)
392392

393+
def test_propagate_tile(self):
394+
@helion.kernel
395+
def copy_kernel(a: torch.Tensor) -> torch.Tensor:
396+
out = torch.empty_like(a)
397+
398+
for tile in hl.tile(a.size(0), block_size=4):
399+
t1 = tile
400+
t2 = tile
401+
out[t2] = a[t1]
402+
return out
403+
404+
args = (torch.randn(16, device=DEVICE, dtype=torch.bfloat16),)
405+
code, result = code_and_output(copy_kernel, args)
406+
torch.testing.assert_close(result, args[0])
407+
self.assertExpectedJournal(code)
408+
393409

394410
if __name__ == "__main__":
395411
unittest.main()

0 commit comments

Comments
 (0)