
Commit db41224

Support reshape with block_size expressions (#495)
1 parent 01c831e commit db41224

6 files changed, +147 -22 lines changed
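For orientation, the pattern this commit enables is reshaping a tile-local tensor with expressions built from tile.block_size inside a Helion kernel. The sketch below is adapted from the new test added in this commit; the `import helion.language as hl` spelling is an assumption based on Helion's usual examples, and the kernel is illustrative rather than a tuned implementation.

import torch
import helion
import helion.language as hl  # assumed import path for the `hl` namespace used in the test

@helion.kernel(static_shapes=True)
def matmul_with_block_size_reshape(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    k2, n = y.size()
    out = torch.zeros([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        # Reshape driven by block_size expressions, the case this commit adds support for.
        flat = acc.reshape(-1, tile_m.block_size * tile_n.block_size)
        out[tile_m, tile_n] = flat.reshape(tile_m.block_size, tile_n.block_size)
    return out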

helion/_compiler/compile_environment.py

Lines changed: 5 additions & 1 deletion
@@ -74,12 +74,16 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
         self._symint_cache: dict[object, torch.SymInt] = {}
 
     def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
+        from .device_function import contains_only_block_size_symbols
+
         for size in sizes:
             if isinstance(size, torch.SymInt):
                 block_idx = self.get_block_id(size)
                 if block_idx is None:
                     value = self.shape_env.replace(size._sympy_())
-                    if value.free_symbols:
+                    if value.free_symbols and not contains_only_block_size_symbols(
+                        value
+                    ):
                         raise exc.ShapeSpecializingAllocation
         self.kernel_tensor_sizes[(*map(_to_sympy, sizes),)] += 1
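The relaxed guard above only admits sizes whose free symbols are all block-size symbols. A minimal standalone sketch of that predicate follows, using plain sympy with made-up symbol names; the real check resolves each symbol's origin through HostFunction.expr_to_origin rather than a fixed set.

import sympy

# Hypothetical symbols standing in for Helion's internal ones.
block_m, block_n = sympy.symbols("_BLOCK_SIZE_0 _BLOCK_SIZE_1", integer=True)
other = sympy.Symbol("s_other", integer=True)  # some non-block-size symbol

KNOWN_BLOCK_SYMBOLS = {block_m, block_n}

def only_block_size_symbols(expr: sympy.Expr) -> bool:
    # Mirrors contains_only_block_size_symbols, but against a fixed symbol set.
    return all(sym in KNOWN_BLOCK_SYMBOLS for sym in expr.free_symbols)

print(only_block_size_symbols(block_m * block_n))  # True: allocation is now allowed
print(only_block_size_symbols(block_m * other))    # False: still raises ShapeSpecializingAllocation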

helion/_compiler/device_function.py

Lines changed: 63 additions & 0 deletions
@@ -58,6 +58,40 @@ class VarInfo(NamedTuple):
     fx_node: torch.fx.Node
 
 
+def find_block_size_symbols(
+    expr: sympy.Expr,
+) -> tuple[dict[sympy.Symbol, int], set[sympy.Symbol]]:
+    """
+    Find block size symbols in a sympy expression.
+
+    Returns:
+        tuple of (block_size_mapping, non_block_size_symbols) where:
+        - block_size_mapping: dict mapping block size symbols to their block_id
+        - non_block_size_symbols: set of symbols that are NOT block sizes
+    """
+    if not isinstance(expr, sympy.Expr):
+        return {}, set()
+
+    hf = HostFunction.current()
+    block_sizes = {}
+    non_block_size_symbols = set()
+
+    for symbol in expr.free_symbols:
+        origin_info = hf.expr_to_origin.get(symbol)  # pyright: ignore[reportArgumentType]
+        if origin_info is None or not isinstance(origin_info.origin, BlockSizeOrigin):
+            non_block_size_symbols.add(symbol)
+        else:
+            block_sizes[symbol] = origin_info.origin.block_id
+
+    return block_sizes, non_block_size_symbols
+
+
+def contains_only_block_size_symbols(expr: sympy.Expr) -> bool:
+    """Check if expression contains only block size symbols (no other variables)."""
+    _, non_block = find_block_size_symbols(expr)
+    return len(non_block) == 0
+
+
 @dataclasses.dataclass
 class Argument:
     name: str  # in the device function

@@ -209,6 +243,35 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
     def block_size_var(self, block_id: int) -> str | None:
         return self.block_size_var_cache.get((block_id,))
 
+    def try_map_block_symbols_to_vars(self, expr: sympy.Expr) -> sympy.Expr | None:
+        """Try to map all block size symbols in expression to their variable names.
+
+        Returns:
+            - The expression with symbols replaced if ALL symbols are block sizes and have variables
+            - None if the expression contains non-block symbols or unmapped block symbols
+        """
+        block_mapping, non_block_symbols = find_block_size_symbols(expr)
+
+        # Can't map if there are non-block symbols
+        if non_block_symbols:
+            return None
+
+        # No symbols to map - return as-is
+        if not block_mapping:
+            return expr
+
+        # Try to map all block symbols to their variables
+        var_map = {}
+        for symbol, block_id in block_mapping.items():
+            block_var = self.block_size_var(block_id)
+            if not block_var:
+                # Can't map this block symbol - fail
+                return None
+            var_map[symbol] = sympy.Symbol(block_var, integer=True)
+
+        # Successfully mapped all symbols
+        return expr.xreplace(var_map)
+
     def merge_variable_names(self, a: str, b: str) -> None:
         name_group = [
             *self._variable_renames.get(a, [a]),
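The substitution step in try_map_block_symbols_to_vars is an ordinary sympy xreplace. Below is a self-contained sketch with hypothetical names; in the real code the target variable names come from block_size_var() and the BlockSizeOrigin table.

import sympy

s0, s1 = sympy.symbols("s0 s1", integer=True)  # stand-ins for host-side size symbols
var_map = {
    s0: sympy.Symbol("_BLOCK_SIZE_0", integer=True),
    s1: sympy.Symbol("_BLOCK_SIZE_1", integer=True),
}

expr = s0 * s1
print(expr.xreplace(var_map))  # _BLOCK_SIZE_0*_BLOCK_SIZE_1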

helion/_compiler/tile_dispatch.py

Lines changed: 25 additions & 6 deletions
@@ -4,8 +4,12 @@
 import operator
 from typing import TYPE_CHECKING
 
+import sympy
+import torch
+
 from .compile_environment import CompileEnvironment
 from .device_function import DeviceFunction
+from .device_function import texpr
 from .device_ir import ForLoopGraphInfo
 from .device_ir import ReductionLoopGraphInfo
 from .host_function import HostFunction

@@ -21,9 +25,6 @@
 if TYPE_CHECKING:
     from collections.abc import Sequence
 
-    import sympy
-    import torch
-
     from .. import Config
     from .inductor_lowering import CodegenState

@@ -120,9 +121,9 @@ def _compact_shape(self, shapes: ShapeLike) -> list[CompactedShape]:
         for idx, shape in enumerate(shapes):
             block_idx = CompileEnvironment.current().get_block_id(shape)
             if block_idx is None:
-                compacted_shapes.append(
-                    CompactedShape(self.strategies[0].fn.literal_expr(shape), [idx], [])
-                )
+                # Check if this is a symbolic expression with block sizes
+                shape_str = self._get_shape_string(shape)
+                compacted_shapes.append(CompactedShape(shape_str, [idx], []))
             else:
                 block_size = DeviceFunction.current().block_size_var(block_idx)
                 if block_size is None:

@@ -132,6 +133,24 @@ def _compact_shape(self, shapes: ShapeLike) -> list[CompactedShape]:
         compacted_shapes = strategy.compact_shape(compacted_shapes)
         return compacted_shapes
 
+    def _get_shape_string(self, shape: SymIntLike) -> str:
+        """Get string representation of a shape"""
+        # Extract sympy expression
+        if isinstance(shape, torch.SymInt):
+            expr = shape._sympy_()
+        elif isinstance(shape, sympy.Expr):
+            expr = shape
+        else:
+            return self.strategies[0].fn.literal_expr(shape)
+
+        # Try to map block symbols to their variable names
+        mapped_expr = DeviceFunction.current().try_map_block_symbols_to_vars(expr)
+        if mapped_expr is not None:
+            return texpr(mapped_expr)
+
+        # Fallback: use literal expression if mapping failed
+        return self.strategies[0].fn.literal_expr(shape)
+
     def shape_str(self, shape: ShapeLike) -> str:
         compacted_shapes = self._compact_shape(shape)
         result = [s.size_str for s in compacted_shapes]
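_get_shape_string therefore has two outcomes: a printed expression over block-size variables, or the literal fallback. The sketch below is a rough approximation with sympy's default printer standing in for texpr and a stubbed symbol-to-variable map; every name in it is hypothetical.

import sympy

def shape_string(expr: sympy.Expr, var_map: dict[sympy.Symbol, sympy.Symbol]) -> str:
    # Succeed only when every free symbol maps to a block-size variable,
    # mirroring try_map_block_symbols_to_vars; otherwise fall back.
    if all(sym in var_map for sym in expr.free_symbols):
        return str(expr.xreplace(var_map))  # the real code prints via texpr()
    return f"<literal {expr}>"  # placeholder for literal_expr()

s0, s1, u0 = sympy.symbols("s0 s1 u0", integer=True)
var_map = {s0: sympy.Symbol("_BLOCK_SIZE_0"), s1: sympy.Symbol("_BLOCK_SIZE_1")}
print(shape_string(s0 * s1, var_map))  # _BLOCK_SIZE_0*_BLOCK_SIZE_1
print(shape_string(s0 * u0, var_map))  # falls back to the literal form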

helion/_utils.py

Lines changed: 2 additions & 2 deletions
@@ -30,7 +30,7 @@ def convert_size_arg(size: object) -> object:
     """Convert a size argument that may contain RefTile objects.
 
     Handles:
-    - Single RefTile -> int
+    - Single RefTile -> int (block_size)
     - List/tuple containing RefTiles -> list with converted sizes
     - Other values -> unchanged
     """

@@ -40,7 +40,7 @@ def convert_size_arg(size: object) -> object:
     if isinstance(size, (list, tuple)):
        return [convert_size_arg(item) for item in size]
     if isinstance(size, RefTile):
-        return size._slice.stop - size._slice.start
+        return size._block_size
     return size
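convert_size_arg is a small recursive normalizer. Below is a self-contained approximation of the same pattern, with a hypothetical _FakeTile standing in for RefTile (only the _block_size attribute is mimicked).

class _FakeTile:
    """Hypothetical stand-in for RefTile; carries only a block size."""

    def __init__(self, block_size: int) -> None:
        self._block_size = block_size


def convert_size(size: object) -> object:
    # Same recursion as convert_size_arg: lists/tuples are walked item by item,
    # tiles collapse to their block size, everything else passes through.
    if isinstance(size, (list, tuple)):
        return [convert_size(item) for item in size]
    if isinstance(size, _FakeTile):
        return size._block_size
    return size


print(convert_size([_FakeTile(16), 4, (_FakeTile(8), 2)]))  # [16, 4, [8, 2]]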

helion/runtime/ref_mode.py

Lines changed: 16 additions & 13 deletions
@@ -128,6 +128,9 @@ def __init__(self) -> None:
             torch.Tensor.view: lambda args, kwargs: self._handle_size_arg_method(
                 args, kwargs, "view"
             ),
+            torch.Tensor.reshape: lambda args, kwargs: self._handle_size_arg_method(
+                args, kwargs, "reshape"
+            ),
             torch.reshape: lambda args, kwargs: self._handle_size_arg_method(
                 args, kwargs, "reshape"
             ),

@@ -217,20 +220,20 @@ def _handle_size_arg_method(
         tensor = cast("torch.Tensor", args[0])
 
         if method_name == "reshape":
-            # torch.reshape expects shape as a single tuple/list argument
-            # It can be passed as torch.reshape(tensor, shape) or torch.reshape(tensor, shape=shape)
-            shape = args[1] if len(args) > 1 else kwargs.get("shape")
-            if shape is not None:
-                shape = convert_size_arg(shape)
-                if len(args) > 1:
-                    return torch.reshape(
-                        tensor,
-                        shape,  # type: ignore[arg-type]
-                        *args[2:],
-                        **kwargs,
-                    )
+            # reshape can take shape as multiple positional args or as a single tuple/list
+            # e.g., tensor.reshape(2, 3) or tensor.reshape((2, 3))
+            if "shape" in kwargs:
+                # Handle kwargs case: tensor.reshape(shape=(2, 3))
+                shape = convert_size_arg(kwargs["shape"])
+                kwargs = dict(kwargs)  # Make a copy to avoid modifying the original
                 kwargs["shape"] = shape
-            return torch.reshape(tensor, **kwargs)  # type: ignore[arg-type]
+                return torch.reshape(tensor, **kwargs)  # type: ignore[arg-type]
+            # Handle positional args case
+            sizes = args[1:]
+            new_sizes = convert_size_arg(sizes)
+            method = getattr(tensor, method_name)
+            assert isinstance(new_sizes, list)
+            return method(*new_sizes, **kwargs)
 
         # view/expand take sizes as positional args
         sizes = args[1:]
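The dispatcher above has to normalize reshape's different spellings before substituting block sizes; the equivalence it relies on holds in plain eager PyTorch:

import torch

t = torch.arange(6.0)
a = t.reshape(2, 3)           # sizes as separate positional arguments
b = t.reshape((2, 3))         # sizes as a single tuple
c = torch.reshape(t, (2, 3))  # functional form, routed through the same hook
assert torch.equal(a, b) and torch.equal(b, c)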

test/test_views.py

Lines changed: 36 additions & 0 deletions
@@ -156,6 +156,42 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         _code, result = code_and_output(fn, args)
         torch.testing.assert_close(result, args[0] + args[1])
 
+    def test_reshape_input_types(self):
+        @helion.kernel(static_shapes=True)
+        def reshape_reduction_dim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.size()
+            k2, n = y.size()
+            assert k == k2, f"size mismatch {k} != {k2}"
+
+            out = torch.zeros(
+                [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
+            )
+
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+
+                # Test different reshape input types
+                reshaped_acc = acc.reshape(-1, tile_m.block_size * tile_n.block_size)
+                reshaped_acc = reshaped_acc.reshape(
+                    tile_m.block_size, tile_n.block_size
+                )
+                reshaped_acc = reshaped_acc.flatten(0)
+                reshaped_acc = reshaped_acc.reshape(tile_m, tile_n)
+                reshaped_acc = reshaped_acc.reshape(
+                    tile_m.block_size * 2 // 2, tile_n.block_size + 1 - 1
+                )
+                out[tile_m, tile_n] = reshaped_acc
+
+            return out
+
+        x = torch.randn(8, 16, device=DEVICE)
+        y = torch.randn(16, 32, device=DEVICE)
+        _code, result = code_and_output(reshape_reduction_dim, (x, y))
+        expected = torch.matmul(x, y)
+        torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
+
 
 if __name__ == "__main__":
     unittest.main()
