Skip to content

Commit a252cca

Browse files
authored
Fix scalar tensor broadcasting in type propagation (#425)
1 parent bc62cf2 commit a252cca

File tree

3 files changed

+2
-25
lines changed

3 files changed

+2
-25
lines changed

helion/_compiler/type_propagation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,8 @@ def propagate_setitem(
477477
lhs_rank = len(lhs_shape)
478478
if isinstance(value, TensorType):
479479
rhs_rank = value.fake_value.ndim
480-
if lhs_rank != rhs_rank:
480+
# Allow scalar tensors (rank 0) to be assigned to any rank (broadcasts)
481+
if rhs_rank != 0 and lhs_rank != rhs_rank:
481482
raise exc.RankMismatch(
482483
lhs_rank,
483484
rhs_rank,

test/test_errors.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -60,24 +60,6 @@ def fn(x: torch.Tensor) -> torch.Tensor:
6060
(torch.randn(4, 8, 16, device=DEVICE),),
6161
)
6262

63-
def test_rank_mismatch_assignment(self):
64-
"""Test that RankMismatch shows tensor shapes in assignment errors."""
65-
66-
@helion.kernel()
67-
def fn(x: torch.Tensor) -> torch.Tensor:
68-
batch, seq_len = x.size()
69-
out = x.new_empty(batch, seq_len)
70-
for tile_batch, tile_seq in hl.tile([batch, seq_len]):
71-
scalar_val = x[tile_batch, 0].sum() # Creates 0D tensor
72-
out[tile_batch, tile_seq] = scalar_val # 0D -> 2D assignment
73-
return out
74-
75-
with self.assertRaisesRegex(
76-
helion.exc.RankMismatch,
77-
r"Expected ndim=2, but got ndim=0.*You have too few indices",
78-
):
79-
code_and_output(fn, (torch.randn(4, 8, device=DEVICE),))
80-
8163
def test_rank_mismatch_indexing(self):
8264
"""Test that RankMismatch shows tensor shapes in indexing errors."""
8365

test/test_indexing.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -608,9 +608,6 @@ def kernel(
608608
torch.testing.assert_close(results[4], expected_symint)
609609
torch.testing.assert_close(results[5], expected_symint)
610610

611-
@skipIfNormalMode(
612-
"RankMismatch: Expected ndim=1, but got ndim=0 - LHS/RHS shape mismatch in type_propagation.py"
613-
)
614611
def test_1d_slice_from_indexed_value(self):
615612
"""buf[:] = zeros[i] - Assign slice from indexed value"""
616613

@@ -866,9 +863,6 @@ def kernel(buf: torch.Tensor, zeros: torch.Tensor) -> torch.Tensor:
866863
expected = torch.zeros([N], device=DEVICE)
867864
torch.testing.assert_close(result, expected)
868865

869-
@skipIfNormalMode(
870-
"RankMismatch: Expected ndim=1, but got ndim=0 - broadcasting shape mismatch"
871-
)
872866
def test_broadcast(self):
873867
"""Test both setter from scalar and getter for [:, i]"""
874868

0 commit comments

Comments
 (0)