Skip to content

Commit bc62cf2

Browse files
authored
Fix scalar value assignment to tensor slices (#424)
1 parent 11c9b1d commit bc62cf2

File tree

4 files changed

+75
-31
lines changed

4 files changed

+75
-31
lines changed

helion/_compiler/roll_reduction.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,12 @@ def should_go_in_inner_graph(self, node: torch.fx.Node) -> bool:
101101
return True
102102

103103
if node.target is store:
104-
_, _, stored_node, _ = node.args
105-
assert isinstance(stored_node, torch.fx.Node)
106-
val = stored_node.meta["val"]
104+
_, _, stored_value, _ = node.args
105+
if isinstance(stored_value, torch.fx.Node):
106+
val = stored_value.meta["val"]
107+
else:
108+
# For non-Node values (scalars), they don't have metadata
109+
val = stored_value
107110
else:
108111
val = node.meta["val"]
109112

helion/language/memory_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _(
5353
) -> tuple[
5454
torch.Tensor | tuple,
5555
list[object],
56-
torch.Tensor | torch.SymInt | float,
56+
torch.Tensor | torch.SymInt | float | int,
5757
torch.Tensor | None,
5858
]:
5959
from .tile_proxy import Tile

helion/language/tile_proxy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def __torch_function__(
6969
raise exc.IncorrectTileUsage(func)
7070
tensor, index, value = args
7171
assert isinstance(tensor, torch.Tensor)
72-
assert isinstance(value, torch.Tensor)
72+
# Allow scalars, SymInts, and tensors as values
73+
assert isinstance(value, (torch.Tensor, torch.SymInt, float, int))
7374
return store(tensor, cls._prepare_index(index), value)
7475
if (
7576
func is torch.Tensor.__index__

test/test_indexing.py

Lines changed: 66 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -484,9 +484,6 @@ def kernel(
484484
torch.testing.assert_close(src_result, expected_src)
485485
torch.testing.assert_close(dst_result, expected_dst)
486486

487-
@skipIfNormalMode(
488-
"AssertionError in roll_reduction.py:104 - stored_node is not a torch.fx.Node"
489-
)
490487
def test_2d_full_slice(self):
491488
"""Test both setter from scalar and getter for [:,:]"""
492489

@@ -537,33 +534,79 @@ def kernel(
537534
torch.testing.assert_close(src_result, expected_src)
538535
torch.testing.assert_close(dst_result, expected_dst)
539536

540-
@skipIfNormalMode(
541-
"AssertionError in roll_reduction.py:104 - stored_node is not a torch.fx.Node"
542-
)
543537
def test_1d_full_slice(self):
544-
"""Test both setter from scalar and getter for [:]"""
538+
"""Test both setter from scalar and getter for [:] with multiple scalar types"""
545539

546-
@helion.kernel(use_default_config=True)
540+
@helion.kernel(config={"block_size": 128})
547541
def kernel(
548-
src: torch.Tensor, dst: torch.Tensor
549-
) -> tuple[torch.Tensor, torch.Tensor]:
550-
N = src.shape[0]
551-
for _ in hl.grid(N):
552-
dst[:] = 1.0 # Test setter with scalar
553-
src[:] = dst[:] # Test getter from dst and setter to src
554-
return src, dst
542+
src_float: torch.Tensor,
543+
dst_float: torch.Tensor,
544+
src_int: torch.Tensor,
545+
dst_int: torch.Tensor,
546+
src_symint: torch.Tensor,
547+
dst_symint: torch.Tensor,
548+
) -> tuple[
549+
torch.Tensor,
550+
torch.Tensor,
551+
torch.Tensor,
552+
torch.Tensor,
553+
torch.Tensor,
554+
torch.Tensor,
555+
]:
556+
N = src_float.shape[0]
557+
for tile in hl.tile(N):
558+
# Test float scalar
559+
dst_float[:] = 1.0
560+
src_float[:] = dst_float[:]
561+
562+
# Test int scalar
563+
dst_int[:] = 99
564+
src_int[:] = dst_int[:]
565+
566+
# Test SymInt scalar
567+
dst_symint[:] = tile.block_size
568+
src_symint[:] = dst_symint[:]
569+
570+
return (
571+
src_float,
572+
dst_float,
573+
src_int,
574+
dst_int,
575+
src_symint,
576+
dst_symint,
577+
)
555578

556579
N = 128
557-
src = torch.zeros([N], device=DEVICE)
558-
dst = torch.zeros([N], device=DEVICE)
580+
src_float = torch.zeros([N], device=DEVICE)
581+
dst_float = torch.zeros([N], device=DEVICE)
582+
src_int = torch.zeros([N], device=DEVICE)
583+
dst_int = torch.zeros([N], device=DEVICE)
584+
src_symint = torch.zeros([N], device=DEVICE)
585+
dst_symint = torch.zeros([N], device=DEVICE)
586+
587+
results = kernel(
588+
src_float,
589+
dst_float,
590+
src_int,
591+
dst_int,
592+
src_symint,
593+
dst_symint,
594+
)
559595

560-
src_result, dst_result = kernel(src, dst)
596+
# Check float results
597+
expected_float = torch.ones([N], device=DEVICE)
598+
torch.testing.assert_close(results[0], expected_float)
599+
torch.testing.assert_close(results[1], expected_float)
561600

562-
# Both should be ones after the kernel
563-
expected_src = torch.ones([N], device=DEVICE)
564-
expected_dst = torch.ones([N], device=DEVICE)
565-
torch.testing.assert_close(src_result, expected_src)
566-
torch.testing.assert_close(dst_result, expected_dst)
601+
# Check int results
602+
expected_int = torch.full([N], 99.0, device=DEVICE)
603+
torch.testing.assert_close(results[2], expected_int)
604+
torch.testing.assert_close(results[3], expected_int)
605+
606+
# Check SymInt results
607+
expected_symint = torch.full([N], 128.0, device=DEVICE)
608+
torch.testing.assert_close(results[4], expected_symint)
609+
torch.testing.assert_close(results[5], expected_symint)
567610

568611
@skipIfNormalMode(
569612
"RankMismatch: Expected ndim=1, but got ndim=0 - LHS/RHS shape mismatch in type_propagation.py"
@@ -624,9 +667,6 @@ def kernel(buf: torch.Tensor, zeros: torch.Tensor) -> torch.Tensor:
624667
expected = torch.zeros([N], device=DEVICE)
625668
torch.testing.assert_close(result, expected)
626669

627-
@skipIfNormalMode(
628-
"AssertionError in roll_reduction.py:104 - stored_node is not a torch.fx.Node"
629-
)
630670
def test_mixed_slice_index(self):
631671
"""Test both setter from scalar and getter for [i,:]"""
632672

0 commit comments

Comments
 (0)