Skip to content

Commit f8e3baf

Browse files
committed
Fix negative indexing and multi-dimensional slicing in Helion
1 parent 9ca2d70 commit f8e3baf

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,21 @@ def create(
549549
index_values.append(f"tl.zeros([1], {dtype}){expand}")
550550
output_idx += 1
551551
elif isinstance(k, int):
552-
index_values.append(repr(k))
552+
# Normalize negative indices
553+
if k < 0:
554+
dim_idx = len(index_values)
555+
dim_size = fake_value.size(dim_idx)
556+
# Handle both concrete and symbolic dimension sizes
557+
if isinstance(dim_size, int):
558+
normalized_k = k + dim_size
559+
index_values.append(repr(normalized_k))
560+
else:
561+
# For symbolic dimensions, we need to generate the proper expression
562+
# The state.codegen is a GenerateAST instance which has device_function
563+
sympy_expr = dim_size._sympy_() + k
564+
index_values.append(f"({state.codegen.device_function.user_sympy_expr(sympy_expr)})")
565+
else:
566+
index_values.append(repr(k))
553567
elif isinstance(k, torch.SymInt):
554568
symbol = k._sympy_()
555569
origin = None
@@ -839,7 +853,21 @@ def create(
839853
res.offsets.append("0")
840854
res.block_shape.append(1)
841855
elif isinstance(k, int):
842-
res.offsets.append(repr(k))
856+
# Normalize negative indices
857+
if k < 0:
858+
dim_idx = len(res.offsets)
859+
dim_size = fake_value.size(dim_idx)
860+
# Handle both concrete and symbolic dimension sizes
861+
if isinstance(dim_size, int):
862+
normalized_k = k + dim_size
863+
res.offsets.append(repr(normalized_k))
864+
else:
865+
# For symbolic dimensions, we need to generate the proper expression
866+
# The state.codegen is a GenerateAST instance which has device_function
867+
sympy_expr = dim_size._sympy_() + k
868+
res.offsets.append(f"({state.codegen.device_function.user_sympy_expr(sympy_expr)})")
869+
else:
870+
res.offsets.append(repr(k))
843871
res.block_shape.append(1)
844872
elif isinstance(k, torch.SymInt):
845873
symbol = k._sympy_()

test/test_indexing.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,6 @@ def kernel(
733733
torch.testing.assert_close(src2_result, expected_src2)
734734
torch.testing.assert_close(dst2_result, expected_dst2)
735735

736-
@skipIfNormalMode("InternalError: Negative indexes")
737736
def test_negative_indexing(self):
738737
"""Test both setter from scalar and getter for [-1]"""
739738

@@ -784,9 +783,6 @@ def kernel(
784783
torch.testing.assert_close(src_result, expected_src)
785784
torch.testing.assert_close(dst_result, expected_dst)
786785

787-
@skipIfNormalMode(
788-
"RankMismatch: Cannot assign a tensor of rank 2 to a buffer of rank 3"
789-
)
790786
def test_multi_dim_slice(self):
791787
"""Test both setter from scalar and getter for [:, :, i]"""
792788

0 commit comments

Comments
 (0)