Skip to content

Commit c445092

Browse files
committed
Fix negative indexing and multi-dimensional slicing in Helion
stack-info: PR: #438, branch: yf225/stack/54
1 parent ef712a0 commit c445092

File tree

2 files changed

+46
-6
lines changed

2 files changed

+46
-6
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,36 @@
3131
ShapeLike = Sequence[SymIntLike]
3232

3333

34+
def _normalize_negative_index(
35+
k: int,
36+
dim_idx: int,
37+
fake_value: torch.Tensor,
38+
state: CodegenState,
39+
) -> str:
40+
"""Normalize negative indices to positive ones.
41+
42+
Args:
43+
k: The negative index value
44+
dim_idx: The dimension index
45+
fake_value: The fake tensor to get dimension size from
46+
state: The codegen state
47+
48+
Returns:
49+
String representation of the normalized index
50+
"""
51+
assert k < 0, "This function should only be called for negative indices"
52+
53+
dim_size = fake_value.size(dim_idx)
54+
# Handle both concrete and symbolic dimension sizes
55+
if isinstance(dim_size, int):
56+
normalized_k = k + dim_size
57+
return repr(normalized_k)
58+
# For symbolic dimensions, we need to generate the proper expression
59+
# The state.codegen is a GenerateAST instance which has device_function
60+
sympy_expr = dim_size._sympy_() + k
61+
return f"({state.codegen.device_function.user_sympy_expr(sympy_expr)})"
62+
63+
3464
class IndexingStrategy:
3565
def codegen_load(
3666
self,
@@ -555,7 +585,14 @@ def create(
555585
index_values.append(f"tl.zeros([1], {dtype}){expand}")
556586
output_idx += 1
557587
elif isinstance(k, int):
558-
index_values.append(repr(k))
588+
# Normalize negative indices
589+
if k < 0:
590+
dim_idx = len(index_values)
591+
index_values.append(
592+
_normalize_negative_index(k, dim_idx, fake_value, state)
593+
)
594+
else:
595+
index_values.append(repr(k))
559596
elif isinstance(k, torch.SymInt):
560597
symbol = k._sympy_()
561598
origin = None
@@ -843,7 +880,14 @@ def create(
843880
res.offsets.append("0")
844881
res.block_shape.append(1)
845882
elif isinstance(k, int):
846-
res.offsets.append(repr(k))
883+
# Normalize negative indices
884+
if k < 0:
885+
dim_idx = len(res.offsets)
886+
res.offsets.append(
887+
_normalize_negative_index(k, dim_idx, fake_value, state)
888+
)
889+
else:
890+
res.offsets.append(repr(k))
847891
res.block_shape.append(1)
848892
elif isinstance(k, torch.SymInt):
849893
symbol = k._sympy_()

test/test_indexing.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,6 @@ def kernel(
733733
torch.testing.assert_close(src2_result, expected_src2)
734734
torch.testing.assert_close(dst2_result, expected_dst2)
735735

736-
@skipIfNormalMode("InternalError: Negative indexes")
737736
def test_negative_indexing(self):
738737
"""Test both setter from scalar and getter for [-1]"""
739738

@@ -784,9 +783,6 @@ def kernel(
784783
torch.testing.assert_close(src_result, expected_src)
785784
torch.testing.assert_close(dst_result, expected_dst)
786785

787-
@skipIfNormalMode(
788-
"RankMismatch: Cannot assign a tensor of rank 2 to a buffer of rank 3"
789-
)
790786
def test_multi_dim_slice(self):
791787
"""Test both setter from scalar and getter for [:, :, i]"""
792788

0 commit comments

Comments
 (0)