Skip to content

Commit 931ea4d

Browse files
authored
Improve error message for unpacking a tile (#125)
1 parent 2e0f346 commit 931ea4d

File tree

3 files changed

+42
-7
lines changed

3 files changed

+42
-7
lines changed

helion/_compiler/type_propagation.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,13 +1600,11 @@ def _assign(self, lhs: ast.AST, rhs: TypeInfo) -> None:
16001600
try:
16011601
elements = rhs.unpack()
16021602
except NotImplementedError:
1603-
elements = [
1604-
UnknownType(
1605-
self.origin(),
1606-
f"Failed to unpack assignment: {rhs!s}",
1607-
)
1608-
for _ in lhs
1609-
]
1603+
if isinstance(rhs, UnknownType):
1604+
raise exc.TypePropagationError(rhs) from None
1605+
if isinstance(rhs, TileIndexType):
1606+
raise exc.FailedToUnpackTile from None
1607+
raise exc.FailedToUnpackTupleAssign(len(lhs), rhs) from None
16101608
used_star = False
16111609
idx = 0
16121610
for elt in lhs:

helion/exc.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,18 @@ class SpecializeArgType(BaseError):
113113
message = "hl.specialize() must be called on a size from an input tensor, got: {}"
114114

115115

116+
class FailedToUnpackTupleAssign(BaseError):
117+
message = "Failed to unpack values in tuple assignment. Expected a sequence of size {0}, got type: {1!s}."
118+
119+
120+
class FailedToUnpackTile(BaseError):
121+
message = (
122+
"Failed to unpack a tile into a tuple assignment. "
123+
"Expected an sequence, but got a single tile. "
124+
"Did you mix up `hl.tile(x)` and `hl.tile([x])`?"
125+
)
126+
127+
116128
class AssignmentMultipleTargets(NotAllowedOnDevice):
117129
message = "Assignment with multiple targets (a=b=1) is not allowed inside the `hl.tile` or `hl.grid` loop."
118130

test/test_errors.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from __future__ import annotations
2+
3+
from expecttest import TestCase
4+
import torch
5+
6+
import helion
7+
from helion._testing import DEVICE
8+
from helion._testing import code_and_output
9+
import helion.language as hl
10+
11+
12+
class TestErrors(TestCase):
13+
maxDiff = 16384
14+
15+
def test_tile_unpacking(self):
16+
@helion.kernel()
17+
def sum_kernel(x: torch.Tensor) -> torch.Tensor:
18+
batch, seq_len, hidden = x.size()
19+
out = x.new_empty(batch, hidden)
20+
for tile_batch, tile_hidden in hl.tile(batch, hidden):
21+
out[tile_batch, tile_hidden] = x[tile_batch, :, tile_hidden].sum(1)
22+
return out
23+
24+
with self.assertRaises(helion.exc.FailedToUnpackTile):
25+
code_and_output(sum_kernel, (torch.randn(2, 3, 4, device=DEVICE),))

0 commit comments

Comments
 (0)