Commit f0c5176

[RFC] Add static loop unrolling (#216)
1 parent 6fe440c commit f0c5176

File tree

4 files changed (+211, -1 lines)
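
Before the per-file diffs, here is the transformation in a nutshell. The following before/after pair is illustrative only (it is not taken from the commit); it shows the source-level rewrite the pass performs when a loop iterates over a static list or tuple literal:

# Illustrative only, not from the diff: the source-level effect of the pass.
def before(x: int) -> int:
    for i in (1, 2, 3):  # iterator is a static tuple literal
        x = x + i
    return x

# After unrolling: the loop variable becomes explicit assignments and the
# body is duplicated once per element, with no loop left behind.
def after(x: int) -> int:
    i = 1
    x = x + i
    i = 2
    x = x + i
    i = 3
    x = x + i
    return x

assert before(0) == after(0) == 6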

helion/_compiler/host_function.py

Lines changed: 2 additions & 0 deletions
@@ -100,8 +100,10 @@ def __init__(
         HostFunction.validate_ast(root)
 
         from .device_ir import lower_to_device_ir
+        from .static_loop_unroller import unroll_static_loops
         from .type_propagation import propagate_types
 
+        unroll_static_loops(self)
         propagate_types(self, fake_args)
         env.finalize_config_spec()
         self.device_ir = lower_to_device_ir(self)
helion/_compiler/static_loop_unroller.py

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import ast
+from typing import TYPE_CHECKING
+from typing import NoReturn
+
+from .ast_extension import create
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from .host_function import HostFunction
+
+
+class CannotUnrollLoop(Exception):
+    pass
+
+
+class StaticLoopUnroller(ast.NodeTransformer):
+    """
+    A compiler optimization pass that unrolls static for loops.
+
+    TODO(oulgen): This pass is primitive; it does not handle for.orelse, break, continue, etc.
+    """
+
+    def visit_For(self, node: ast.For) -> ast.AST | list[ast.AST]:
+        # Generic visit to handle nested loops
+        node = self.generic_visit(node)  # pyre-ignore[9]
+
+        # Check if this is a static loop that can be unrolled
+        if static_values := self._extract_static_values(node.iter):
+            return self._unroll_loop(node, static_values)
+
+        return node
+
+    def visit_Break(self, node: ast.Break) -> NoReturn:
+        raise CannotUnrollLoop
+
+    def visit_Continue(self, node: ast.Continue) -> NoReturn:
+        raise CannotUnrollLoop
+
+    def _extract_static_values(self, iter_node: ast.expr) -> list[ast.expr] | None:
+        """
+        Check if the iterator is static, and if so, extract those values.
+        """
+        if isinstance(iter_node, (ast.List, ast.Tuple)):
+            return iter_node.elts
+        return None
+
+    def _unroll_loop(
+        self, loop_node: ast.For, static_values: Sequence[ast.AST]
+    ) -> ast.AST | list[ast.AST]:
+        unrolled_statements = []
+
+        for value in static_values:
+            assignment = create(
+                ast.Assign,
+                targets=[loop_node.target],
+                value=value,
+            )
+            unrolled_statements.append(assignment)
+
+            # TODO(oulgen): Should we deepcopy these to avoid reference issues?
+            unrolled_statements.extend(loop_node.body)
+
+        if loop_node.orelse:
+            raise CannotUnrollLoop
+        return unrolled_statements
+
+
+def unroll_static_loops(func: HostFunction) -> None:
+    new_body = []
+    for stmt in func.body:
+        try:
+            unrolled_stmts = StaticLoopUnroller().visit(stmt)
+        except CannotUnrollLoop:
+            new_body.append(stmt)
+        else:
+            assert isinstance(unrolled_stmts, ast.stmt)
+            new_body.append(unrolled_stmts)
+    func.body = new_body
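
The committed pass relies on Helion internals (the `create` AST helper and `HostFunction`), so it cannot run in isolation. Below is a minimal standalone approximation of the same technique using only the standard library; the name `ToyUnroller` is invented for illustration, and unlike the committed code it deep-copies the loop body, sidestepping the shared-reference question raised in the TODO:

from __future__ import annotations

import ast
import copy


class ToyUnroller(ast.NodeTransformer):
    """Standalone approximation of StaticLoopUnroller; not Helion's API."""

    def visit_For(self, node: ast.For) -> ast.AST | list[ast.stmt]:
        node = self.generic_visit(node)  # handle nested loops first
        if node.orelse or not isinstance(node.iter, (ast.List, ast.Tuple)):
            return node  # not statically unrollable; leave the loop alone
        unrolled: list[ast.stmt] = []
        for value in node.iter.elts:
            # Emit `target = <element>`, then a private copy of the body.
            unrolled.append(
                ast.Assign(targets=[copy.deepcopy(node.target)], value=value)
            )
            unrolled.extend(copy.deepcopy(node.body))
        return unrolled


tree = ast.parse("def f(x):\n    for i in [1, 2, 3]:\n        x += i\n    return x")
tree = ast.fix_missing_locations(ToyUnroller().visit(tree))
print(ast.unparse(tree))
# def f(x):
#     i = 1
#     x += i
#     i = 2
#     x += i
#     i = 3
#     x += i
#     return x

The real pass's driver adds one safeguard this sketch omits: raising CannotUnrollLoop from visit_Break and visit_Continue and catching it per top-level statement lets unroll_static_loops fall back to the original statement wholesale, rather than emit incorrect straight-line code for loops with early exits.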

test/test_errors.py

Lines changed: 7 additions & 1 deletion
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import unittest
+
 from expecttest import TestCase
 import torch
 
@@ -118,9 +120,13 @@ def fn(x: torch.Tensor) -> torch.Tensor:
             batch = x.size(0)
             out = x.new_empty(batch)
             for tile_batch in hl.tile(batch):
-                for i in [1, 2, 3]:
+                for i in {1: None, 2: None, 3: None}:
                     out[tile_batch] = x[tile_batch] + i
             return out
 
         with self.assertRaises(helion.exc.InvalidDeviceForLoop):
             code_and_output(fn, (torch.randn(8, device=DEVICE),))
+
+
+if __name__ == "__main__":
+    unittest.main()
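
The fixture swap from `[1, 2, 3]` to a dict literal looks odd without context: the unroller now rewrites loops over literal lists and tuples before type propagation runs, so the old fixture would presumably be flattened away and never reach the `helion.exc.InvalidDeviceForLoop` check. A dict iterator does not match `_extract_static_values`, which keeps the error path exercised. A standalone sketch of the classification rule (assumed equivalent to the committed check):

import ast

samples = [
    "for i in [1, 2, 3]: pass",
    "for i in (a, b, c): pass",
    "for i in {1: None, 2: None, 3: None}: pass",
    "for i in range(3): pass",
]
for src in samples:
    iter_node = ast.parse(src).body[0].iter
    static = isinstance(iter_node, (ast.List, ast.Tuple))
    print(f"{src:48} -> {'unrolled' if static else 'left as a loop'}")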

test/test_loops.py

Lines changed: 121 additions & 0 deletions
@@ -1378,6 +1378,127 @@ def _chebyshev_kernel_make_precompiler(x: torch.Tensor, w: torch.Tensor):
     return make_precompiler(_chebyshev_kernel_kernel)(x, w, out, out.stride(0), out.stride(1), w.stride(0), w.stride(1), x.stride(0), x.stride(1), B, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
     )
 
+    def test_loop_unroll1(self):
+        @helion.kernel()
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile]
+                for i in [1, 2, 3]:
+                    out[tile] += i
+            return out
+
+        x = torch.randn(4, device=DEVICE)
+        code, output = code_and_output(fn, (x,))
+        torch.testing.assert_close(output, x + 6)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _fn_kernel(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    tl.store(out + indices_0 * out_stride_0, load, mask_0)
+    load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_0 = 1.0
+    v_1 = load_1 + v_0
+    tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
+    load_2 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_2 = 2.0
+    v_3 = load_2 + v_2
+    tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
+    load_3 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_4 = 3.0
+    v_5 = load_3 + v_4
+    tl.store(out + indices_0 * out_stride_0, v_5, mask_0)
+
+def fn(x: torch.Tensor):
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_0 = 4
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+def _fn_make_precompiler(x: torch.Tensor):
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
+        )
+
+    def test_loop_unroll2(self):
+        @helion.kernel()
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x)
+            a = 1
+            b = 2
+            c = 3
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile]
+                for i in (a, b, c):
+                    out[tile] += i
+            return out
+
+        x = torch.randn(4, device=DEVICE)
+        code, output = code_and_output(fn, (x,))
+        torch.testing.assert_close(output, x + 6)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _fn_kernel(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    tl.store(out + indices_0 * out_stride_0, load, mask_0)
+    load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_0 = 1.0
+    v_1 = load_1 + v_0
+    tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
+    load_2 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_2 = 2.0
+    v_3 = load_2 + v_2
+    tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
+    load_3 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_4 = 3.0
+    v_5 = load_3 + v_4
+    tl.store(out + indices_0 * out_stride_0, v_5, mask_0)
+
+def fn(x: torch.Tensor):
+    out = torch.zeros_like(x)
+    a = 1
+    b = 2
+    c = 3
+    _BLOCK_SIZE_0 = 4
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+def _fn_make_precompiler(x: torch.Tensor):
+    out = torch.zeros_like(x)
+    a = 1
+    b = 2
+    c = 3
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
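
Note that both tests pin byte-identical Triton output: three straight-line load/add/store rounds with the constants 1.0, 2.0, 3.0 folded in and no loop emitted, whether the values appear as literals or as the host variables `a`, `b`, `c` (presumably resolved to constants during later type propagation, since the unroller only requires the iterator itself to be a literal tuple). An eager-mode restatement of what the kernels compute and the tests assert:

import torch

x = torch.randn(4)
out = torch.zeros_like(x)
out[:] = x             # out[tile] = x[tile]
for i in (1, 2, 3):    # unrolled by the compiler; run eagerly here
    out += i           # out[tile] += i
torch.testing.assert_close(out, x + 6)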
