Skip to content

Commit d71b5c6

Browse files
authored
Improve loop end bound optimization for nested tiling (#192)
1 parent 248ece6 commit d71b5c6

10 files changed

+275
-60
lines changed

helion/_compiler/compile_environment.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
256256
def size_hint(self, n: int | torch.SymInt) -> int:
257257
if isinstance(n, torch.SymInt):
258258
expr = n._sympy_()
259-
if any(s.name.startswith("u") for s in expr.free_symbols):
259+
if _has_unbacked(expr):
260260
# If the size is a symbolic expression with unbacked symbols, then the shape environment
261261
# hint will be wrong since we assign a default value to unbacked symbols. Return a default hint.
262262
return 8192
@@ -489,3 +489,8 @@ def _to_sympy(x: int | torch.SymInt) -> sympy.Expr:
489489
if isinstance(x, torch.SymInt):
490490
return x._sympy_()
491491
return sympy.sympify(x)
492+
493+
494+
def _has_unbacked(expr: sympy.Expr) -> bool:
495+
# pyre-ignore[16]
496+
return any(n.name.startswith("u") for n in expr.free_symbols)

helion/_compiler/device_function.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import math
88
import threading
99
from typing import TYPE_CHECKING
10+
from typing import NamedTuple
1011
from typing import Protocol
1112
from typing import TypeVar
1213
from typing import cast
@@ -47,6 +48,13 @@ class _TLS(Protocol):
4748
tls: _TLS = cast("_TLS", threading.local())
4849

4950

51+
class VarInfo(NamedTuple):
52+
"""Information about a variable derived from a sympy expression."""
53+
54+
name: str
55+
fx_node: torch.fx.Node
56+
57+
5058
@dataclasses.dataclass
5159
class Argument:
5260
name: str # in the device function
@@ -152,7 +160,7 @@ def __init__(self, name: str, config: Config) -> None:
152160
self._variable_renames: dict[str, list[str]] = {}
153161
self.dce_vars: list[str] = []
154162
self.block_size_var_cache: dict[tuple[int, ...], str] = {}
155-
self.expr_to_var_name: dict[sympy.Expr, str] = {}
163+
self.expr_to_var_info: dict[sympy.Expr, VarInfo] = {}
156164

157165
from .indexing_strategy import IndexingStrategy
158166
from .tile_dispatch import TileStrategyDispatch
@@ -179,17 +187,17 @@ def sympy_expr(self, expr: sympy.Expr) -> str:
179187
expr = CompileEnvironment.current().shape_env.simplify(expr)
180188
if not expr.free_symbols:
181189
return texpr(expr)
182-
if expr in self.expr_to_var_name:
183-
return self.expr_to_var_name[expr]
190+
if expr in self.expr_to_var_info:
191+
return self.expr_to_var_info[expr].name
184192
expr_to_origin = HostFunction.current().expr_to_origin
185193
if expr in expr_to_origin:
186194
return self._lift_sympy_arg(expr)
187195
replacements = {}
188196
for sym in sorted(expr.free_symbols, key=lambda x: x.name):
189197
assert isinstance(sym, sympy.Symbol)
190-
if sym in self.expr_to_var_name:
198+
if sym in self.expr_to_var_info:
191199
replacements[sym] = sympy.Symbol(
192-
self.expr_to_var_name[sym], integer=True
200+
self.expr_to_var_info[sym].name, integer=True
193201
)
194202
else:
195203
assert sym in expr_to_origin, f"no origin found for {sym.name}"

helion/_compiler/generate_ast.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,15 @@ def set_statements(self, new_statements: list[ast.AST] | None) -> Iterator[None]
8080
if new_statements is None:
8181
yield
8282
else:
83-
expr_to_var_name = self.device_function.expr_to_var_name
83+
expr_to_var_info = self.device_function.expr_to_var_info
8484
# We don't want to reuse vars assigned in a nested scope, so copy it
85-
self.device_function.expr_to_var_name = expr_to_var_name.copy()
85+
self.device_function.expr_to_var_info = expr_to_var_info.copy()
8686
self.statements_stack.append(new_statements)
8787
try:
8888
yield
8989
finally:
9090
self.statements_stack.pop()
91-
self.device_function.expr_to_var_name = expr_to_var_name
91+
self.device_function.expr_to_var_info = expr_to_var_info
9292

9393
@contextlib.contextmanager
9494
def set_on_device(self) -> Iterator[None]:

helion/_compiler/indexing_strategy.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,9 @@ def is_supported(
436436
if extra_mask is not None:
437437
# TODO(jansel): support block_ptr with extra_mask
438438
return False
439+
input_sizes = collections.deque(fake_tensor.size())
439440
for k in index:
441+
input_size = 1 if k is None else input_sizes.popleft()
440442
if isinstance(k, torch.SymInt):
441443
symbol = k._sympy_()
442444
origin = None
@@ -455,14 +457,13 @@ def is_supported(
455457
In this case, the block masking will be incorrect. So we check if the
456458
masking is needed and bail if it is.
457459
"""
458-
end = loop_state.end_bounds[block_index]
459-
if (
460-
not CompileEnvironment.current()
461-
.block_sizes[block_index]
462-
.size_matches(end)
460+
if not loop_state.block_id_to_info[block_index].is_end_matching(
461+
input_size
463462
):
464463
assert state.fx_node is not None
465464
if "masked_value" in state.fx_node.meta:
465+
# TODO(jansel): in this case we should be able to lower to block_ptr+tl.where
466+
# see test/test_loops.py::TestLoops::test_data_dependent_bounds2
466467
return False
467468
if isinstance(k, torch.Tensor):
468469
# indirect loads don't work with block_ptr

helion/_compiler/inductor_lowering.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from .ast_extension import expr_from_string
4747
from .ast_extension import statement_from_string
4848
from .compile_environment import CompileEnvironment
49+
from .device_function import VarInfo
4950
from .node_masking import apply_masking
5051
from .node_masking import cached_masked_value
5152
from .node_masking import getitem_masked_value
@@ -940,11 +941,13 @@ def run_node(self, n: Node) -> object:
940941
# Keep track of what variable symints are stored in to support DeviceFunction.sympy_expr()
941942
expr = CompileEnvironment.current().shape_env.simplify(expr)
942943
if isinstance(result, ast.Name):
943-
self.cg.device_function.expr_to_var_name[expr] = result.id
944+
self.cg.device_function.expr_to_var_info[expr] = VarInfo(
945+
result.id, n
946+
)
944947
else:
945948
assert isinstance(result, ast.Constant)
946-
self.cg.device_function.expr_to_var_name[expr] = repr(
947-
result.value
949+
self.cg.device_function.expr_to_var_info[expr] = VarInfo(
950+
repr(result.value), n
948951
)
949952
return result
950953
if not isinstance(result, (ast.Name, ast.Constant)):

helion/_compiler/reduction_strategy.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,14 @@ def codegen_preamble(self, state: CodegenState) -> None:
177177
f"{mask_var} = {index_var} < {self.fn.sympy_expr(numel)}"
178178
)
179179
# Extract end_var_name from the numel expression
180-
end_var_name = {self.block_index: self.fn.sympy_expr(numel)}
180+
from .tile_strategy import LoopDimInfo
181+
182+
end_var_name = self.fn.sympy_expr(numel)
183+
block_id_to_info = {
184+
self.block_index: LoopDimInfo(end_var_name=end_var_name, end_expr=numel)
185+
}
181186
state.codegen.set_active_loops(
182-
PersistentReductionState(self, end_var_name=end_var_name)
187+
PersistentReductionState(self, block_id_to_info=block_id_to_info)
183188
)
184189

185190
def codegen_reduction(
@@ -258,13 +263,17 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
258263
type_comment=None,
259264
)
260265
# Extract end_var_name from the actual numel expression used in the range()
261-
end_var_name = {block_index: state.sympy_expr(numel)}
266+
from .tile_strategy import LoopDimInfo
267+
268+
end_var_name = state.sympy_expr(numel)
269+
block_id_to_info = {
270+
block_index: LoopDimInfo(end_var_name=end_var_name, end_expr=numel)
271+
}
262272
return DeviceLoopState(
263273
self,
264274
for_node=for_node,
265275
inner_statements=body,
266-
end_bounds={block_index: numel},
267-
end_var_name=end_var_name,
276+
block_id_to_info=block_id_to_info,
268277
)
269278

270279
def codegen_reduction(

0 commit comments

Comments
 (0)