Commit 9554e7e

Add support for torch.arange (pytorch#215)
1 parent f0c5176 commit 9554e7e

6 files changed: +204 -3 lines changed


helion/_compiler/compile_environment.py

Lines changed: 25 additions & 0 deletions
@@ -71,6 +71,7 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
         )
         self.specialized_vars: set[sympy.Symbol] = set()
         self.loop_dependency_checker = LoopDependencyChecker()
+        self._symint_cache: dict[object, torch.SymInt] = {}
 
     def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
         for size in sizes:
@@ -174,6 +175,30 @@ def create_unbacked_symint(self, hint: int = 8192) -> torch.SymInt:
         self.shape_env.var_to_val[sym._sympy_()] = sympy.sympify(hint)
         return sym
 
+    def cached_create_unbacked_symint(
+        self, key: Sequence[object], hint: int = 8192
+    ) -> torch.SymInt:
+        """Create an unbacked symint with caching based on a key.
+
+        This ensures that the same key always returns the same unbacked
+        symint, which is crucial to allow simplification of expressions
+        for things like tile_begin.
+
+        Args:
+            key: The cache key (should be sequence of hashables and unique for the desired symint)
+            hint: Hint value for the symint
+
+        Returns:
+            A consistent unbacked symint for the given key
+        """
+        # pyre-ignore[16]
+        key = tuple([x._sympy_() if hasattr(x, "_sympy_") else x for x in key])
+        result = self._symint_cache.get(key)
+        if result is None:
+            result = self.create_unbacked_symint(hint)
+            self._symint_cache[key] = result
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if isinstance(obj, torch.Tensor):
             return self._to_fake_tensor(obj, origin.to_source())
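
The cache's contract is easiest to see in isolation. A minimal sketch of the intended behavior, assuming an active compile environment; `tile` is a placeholder for whatever hashable object the caller keys on:

    # Illustrative only: assumes a CompileEnvironment is current and `tile` is
    # some hashable key component (e.g. a tile's SymInt).
    env = CompileEnvironment.current()

    # Same key -> the same cached unbacked SymInt object is returned again, so
    # sympy expressions built from repeated tile_begin lookups can simplify
    # instead of accumulating fresh unknowns.
    a = env.cached_create_unbacked_symint(("tile_begin", tile))
    b = env.cached_create_unbacked_symint(("tile_begin", tile))
    assert a is b

    # A different key still gets its own fresh unbacked SymInt.
    c = env.cached_create_unbacked_symint(("tile_end", tile))
    assert c is not a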

helion/_compiler/inductor_lowering.py

Lines changed: 40 additions & 0 deletions
@@ -871,6 +871,20 @@ def __init__(self, graph: torch.fx.Graph, cg: GenerateAST) -> None:
         super().__init__(_LazyGraphModule({}, graph), garbage_collect_values=False)
         self.cg = cg
 
+    def to_ast(self, value: object) -> ast.AST:
+        """
+        Convert a value to an AST expression.
+        """
+        if isinstance(value, torch.fx.Node):
+            result = self.env[value]
+            assert isinstance(result, ast.AST)
+            return result
+        if isinstance(value, (int, float, bool)):
+            return create(ast.Constant, value=value)
+        if isinstance(value, ast.AST):
+            return value
+        raise TypeError(f"Unsupported value type for AST conversion: {type(value)}")
+
     def _collect_multi_outputs(
         self, node: Node, last_node_result: object
     ) -> tuple[object, ...]:
@@ -1018,3 +1032,29 @@ def add_statement(self, statement: ast.AST | str) -> None:
 
     def sympy_expr(self, expr: sympy.Expr) -> str:
         return self.codegen.device_function.sympy_expr(expr)
+
+
+# pyre-fixme[56]
+@register_lowering(torch.ops.prims.iota.default)
+def codegen_iota(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
+    """Generate tl.arange for torch.ops.prims.iota.default operations."""
+    start = node.kwargs.get("start", 0)
+    step = node.kwargs.get("step", 1)
+    dtype = (
+        node.kwargs.get("dtype") or CompileEnvironment.current().settings.index_dtype
+    )
+    assert isinstance(dtype, torch.dtype)
+    (length_arg,) = node.args  # expecting a single argument for length
+    expr = "tl.arange(0, length)"
+    if step != 1:
+        expr = f"step * {expr}"
+    if start != 0:
+        expr = f"start + {expr}"
+    if dtype != torch.int32:
+        expr = f"({expr}).to({triton_type(dtype)})"
+    return expr_from_string(
+        expr,
+        start=ctx.to_ast(start),
+        step=ctx.to_ast(step),
+        length=ctx.to_ast(length_arg),
+    )
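
The lowering only assembles a template string and splices the operands in as AST nodes via to_ast. A stand-alone sketch of that template logic (the helper name build_iota_template is illustrative; the real code resolves the cast through triton_type and builds the AST with expr_from_string):

    import torch

    def build_iota_template(start, step, dtype):
        # Mirrors the string-building in codegen_iota above.
        expr = "tl.arange(0, length)"
        if step != 1:
            expr = f"step * {expr}"
        if start != 0:
            expr = f"start + {expr}"
        if dtype != torch.int32:
            expr = f"({expr}).to(tl.int64)"  # stand-in for triton_type(dtype)
        return expr

    # The three-argument test later in this commit (non-zero start, step=2,
    # int64 iota) ends up with exactly this shape:
    print(build_iota_template(start="mul", step=2, dtype=torch.int64))
    # -> "(start + step * tl.arange(0, length)).to(tl.int64)"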

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 
 from .constexpr import ConstExpr as constexpr  # noqa: F401
 from .constexpr import specialize as specialize
+from .creation_ops import arange as arange
 from .creation_ops import full as full
 from .creation_ops import zeros as zeros
 from .device_print import device_print as device_print

helion/language/creation_ops.py

Lines changed: 24 additions & 0 deletions
@@ -78,3 +78,27 @@ def _(
     value = node.args[1]
     assert isinstance(value, (int, float, bool))
     return value
+
+
+def arange(
+    *args: int,
+    dtype: torch.dtype | None = None,
+    **kwargs: object,
+) -> torch.Tensor:
+    """
+    Same as `torch.arange()`, but defaults to same device as the current kernel.
+
+    Example usage:
+        hl.arange(tile.block_size)  # [0, 1, ..., tile.block_size - 1]
+        hl.arange(tile.begin, tile.begin + tile.block_size)  # same as tile.index
+    """
+    env = CompileEnvironment.current()
+    if dtype is None:
+        dtype = env.settings.index_dtype
+    return torch.arange(
+        *args,
+        # pyre-ignore[6]
+        **kwargs,
+        dtype=dtype,
+        device=env.device,
+    )
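
A minimal kernel sketch of the new helper, modeled on the tests added below; the kernel name and sizes are illustrative, not part of the library:

    import torch
    import helion
    import helion.language as hl

    @helion.kernel(use_default_config=True)
    def iota_rows(x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros([x.size(0)], dtype=torch.int32, device=x.device)
        for tile in hl.tile(x.size(0)):
            # device and index dtype default to the current kernel's settings
            out[tile] = hl.arange(tile.begin, tile.begin + tile.block_size)
        return out

    # Expected result (per the two-argument test below):
    # out == torch.arange(x.size(0), dtype=torch.int32, device=x.device)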

helion/language/tile_ops.py

Lines changed: 7 additions & 3 deletions
@@ -60,7 +60,9 @@ def tile_begin(tile: Tile) -> int:
 @_decorators.register_fake(tile_begin)
 def _(tile: torch.SymInt) -> torch.SymInt:
     _disable_flatten_get_tile(tile)  # update config spec if needed
-    return CompileEnvironment.current().create_unbacked_symint()
+    return CompileEnvironment.current().cached_create_unbacked_symint(
+        ("tile_begin", tile)
+    )
 
 
 def _disable_flatten_get_tile(tile: object) -> int:
@@ -94,7 +96,9 @@ def tile_end(tile: Tile) -> int:
 @_decorators.register_fake(tile_end)
 def _(tile: torch.SymInt) -> torch.SymInt:
     _disable_flatten_get_tile(tile)  # update config spec if needed
-    return CompileEnvironment.current().create_unbacked_symint()
+    return CompileEnvironment.current().cached_create_unbacked_symint(
+        ("tile_end", tile)
+    )
 
 
 @_decorators.codegen(tile_end)
@@ -148,7 +152,7 @@ def tile_id(tile: Tile) -> int:
 @_decorators.register_fake(tile_id)
 def _(tile: torch.SymInt) -> torch.SymInt:
     assert isinstance(tile, torch.SymInt)
-    return CompileEnvironment.current().create_unbacked_symint()
+    return CompileEnvironment.current().cached_create_unbacked_symint(("tile_id", tile))
 
 
 @_decorators.codegen(tile_id)

test/test_indexing.py

Lines changed: 107 additions & 0 deletions
@@ -392,6 +392,113 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         )
         torch.testing.assert_close(result, expected)
 
+    def test_arange_tile_block_size(self):
+        @helion.kernel(use_default_config=True)
+        def arange_from_block_size(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros([x.size(0)], dtype=torch.int32, device=x.device)
+            for tile in hl.tile(x.size(0)):
+                # Test the exact pattern requested: torch.arange(tile.block_size, device=x.device)
+                out[tile] = torch.arange(tile.block_size, device=x.device)
+            return out
+
+        x = torch.randn([64], device=DEVICE)
+        code, result = code_and_output(
+            arange_from_block_size,
+            (x,),
+            block_size=16,
+        )
+        expected = torch.arange(16, dtype=torch.int32, device=DEVICE).repeat(4)
+        torch.testing.assert_close(result, expected)
+
+    def test_arange_two_args(self):
+        @helion.kernel(use_default_config=True)
+        def arange_two_args(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros([x.size(0)], dtype=torch.int32, device=x.device)
+            for tile in hl.tile(x.size(0)):
+                # Test the exact pattern requested: torch.arange(tile.begin, tile.begin+tile.block_size, device=x.device)
+                out[tile] = torch.arange(
+                    tile.begin, tile.begin + tile.block_size, device=x.device
+                )
+            return out
+
+        x = torch.randn([64], device=DEVICE)
+        code, result = code_and_output(
+            arange_two_args,
+            (x,),
+            block_size=16,
+        )
+        expected = torch.arange(64, dtype=torch.int32, device=DEVICE)
+        torch.testing.assert_close(result, expected)
+
+    def test_arange_three_args_step(self):
+        @helion.kernel(config={"block_size": 8})
+        def arange_three_args_step(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros([x.size(0) // 2], dtype=torch.int32, device=x.device)
+            for tile in hl.tile(x.size(0) // 2):
+                # Test the exact pattern requested: torch.arange(start, end, step=2, device=x.device)
+                start_idx = tile.begin * 2
+                end_idx = (tile.begin + tile.block_size) * 2
+                out[tile] = torch.arange(start_idx, end_idx, step=2, device=x.device)
+            return out
+
+        x = torch.randn([64], device=DEVICE)
+        code, result = code_and_output(
+            arange_three_args_step,
+            (x,),
+        )
+        expected = torch.arange(0, 64, step=2, dtype=torch.int32, device=DEVICE)
+        torch.testing.assert_close(result, expected)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _arange_three_args_step_kernel(out, out_size_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < out_size_0
+    mul = 2 * offset_0
+    iota = (mul + 2 * tl.arange(0, _BLOCK_SIZE_0)).to(tl.int64)
+    v_0 = iota.to(tl.int32)
+    tl.store(out + indices_0 * out_stride_0, v_0, mask_0)
+
+def arange_three_args_step(x: torch.Tensor):
+    out = torch.zeros([x.size(0) // 2], dtype=torch.int32, device=x.device)
+    _BLOCK_SIZE_0 = 8
+    _arange_three_args_step_kernel[triton.cdiv(out.size(0), _BLOCK_SIZE_0),](out, out.size(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+def _arange_three_args_step_make_precompiler(x: torch.Tensor):
+    out = torch.zeros([x.size(0) // 2], dtype=torch.int32, device=x.device)
+    _BLOCK_SIZE_0 = 8
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_arange_three_args_step_kernel)(out, out.size(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
+        )
+
+    def test_arange_hl_alias(self):
+        @helion.kernel(config={"block_size": 8})
+        def arange_three_args_step(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros([x.size(0) // 2], dtype=torch.int32, device=x.device)
+            for tile in hl.tile(x.size(0) // 2):
+                start_idx = tile.begin * 2
+                end_idx = (tile.begin + tile.block_size) * 2
+                out[tile] = hl.arange(start_idx, end_idx, step=2)
+            return out
+
+        x = torch.randn([64], device=DEVICE)
+        code, result = code_and_output(
+            arange_three_args_step,
+            (x,),
+        )
+        expected = torch.arange(0, 64, step=2, dtype=torch.int32, device=DEVICE)
+        torch.testing.assert_close(result, expected)
+
 
 if __name__ == "__main__":
     unittest.main()
