Skip to content

Commit 2c1ed69

Browse files
committed
Fix static slice indexing with explicit start/stop bounds
stack-info: PR: #440, branch: yf225/stack/56
1 parent 128393a commit 2c1ed69

File tree

4 files changed

+40
-17
lines changed

4 files changed

+40
-17
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 31 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -517,7 +517,6 @@ def compute_shape(
517517
size = input_size.popleft()
518518
# Handle slices with steps
519519
slice_size = compute_slice_size(k, size)
520-
521520
if slice_size != 1:
522521
rdim = env.allocate_reduction_dimension(slice_size)
523522
output_size.append(rdim.var)
@@ -633,16 +632,24 @@ def create(
633632
else:
634633
index_values.append(f"{start}{expand}")
635634
else:
636-
# Full slice or slice without step
637-
if size != 1:
638-
rdim = env.allocate_reduction_dimension(size)
635+
# Handle slices with start/stop but no step
636+
start = k.start if k.start is not None else 0
637+
stop = k.stop if k.stop is not None else size
638+
slice_size = stop - start
639+
640+
if slice_size != 1:
641+
rdim = env.allocate_reduction_dimension(slice_size)
639642
block_idx = rdim.block_id
640643
index_var = state.codegen.index_var(block_idx)
641-
index_values.append(f"({index_var}){expand}")
644+
# Generate index: start + index_var
645+
if start != 0:
646+
index_values.append(f"({start} + ({index_var})){expand}")
647+
else:
648+
index_values.append(f"({index_var}){expand}")
642649
if mask := state.codegen.mask_var(block_idx):
643650
mask_values.setdefault(f"({mask}){expand}")
644651
else:
645-
index_values.append(f"tl.zeros([1], {dtype}){expand}")
652+
index_values.append(f"{start}{expand}")
646653
output_idx += 1
647654
elif isinstance(k, torch.Tensor) and k.ndim == 1:
648655
expand = tile_strategy.expand_str(output_size, output_idx)
@@ -941,8 +948,24 @@ def create(
941948
res.offsets.append(state.codegen.offset_var(rdim.block_id))
942949
res.block_shape.append(rdim.var)
943950
else:
944-
res.offsets.append("0")
945-
res.block_shape.append(1)
951+
# Handle slices with start/stop but no step
952+
start = k.start if k.start is not None else 0
953+
stop = k.stop if k.stop is not None else size
954+
slice_size = stop - start
955+
956+
if slice_size != 1:
957+
env = CompileEnvironment.current()
958+
rdim = env.allocate_reduction_dimension(slice_size)
959+
offset = state.codegen.offset_var(rdim.block_id)
960+
# Add start offset if needed
961+
if start != 0:
962+
res.offsets.append(f"({start} + {offset})")
963+
else:
964+
res.offsets.append(offset)
965+
res.block_shape.append(rdim.var)
966+
else:
967+
res.offsets.append(str(start))
968+
res.block_shape.append(1)
946969
else:
947970
raise exc.InvalidIndexingType(k)
948971
res.validate()

helion/_compiler/type_propagation.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -467,7 +467,6 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
467467

468468
# For slices with steps, we need to calculate the output size differently
469469
output_size = compute_slice_size(slice_obj, size)
470-
471470
if self.origin.is_device():
472471
output_sizes.append(output_size)
473472
elif output_size != 1:
@@ -510,8 +509,9 @@ def propagate_setitem(
510509
lhs_rank = len(lhs_shape)
511510
if isinstance(value, TensorType):
512511
rhs_rank = value.fake_value.ndim
513-
# Allow scalar tensors (rank 0) to be assigned to any rank (broadcasts)
514-
if rhs_rank != 0 and lhs_rank != rhs_rank:
512+
rhs_numel = value.fake_value.numel()
513+
# Allow scalar tensors (rank 0) or single-element tensors to be assigned to any rank (broadcasts)
514+
if rhs_rank != 0 and rhs_numel != 1 and lhs_rank != rhs_rank:
515515
raise exc.RankMismatch(
516516
lhs_rank,
517517
rhs_rank,

helion/_compiler/utils.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,5 +25,8 @@ def compute_slice_size(
2525
stop = slice_obj.stop if slice_obj.stop is not None else original_size
2626
step = slice_obj.step
2727
return (stop - start + step - 1) // step
28-
# Full slice or slice without step
29-
return original_size
28+
else:
29+
# Calculate slice size based on start/stop
30+
start = slice_obj.start if slice_obj.start is not None else 0
31+
stop = slice_obj.stop if slice_obj.stop is not None else original_size
32+
return stop - start

test/test_indexing.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -877,7 +877,6 @@ def kernel(
877877
torch.testing.assert_close(src_result, expected_src)
878878
torch.testing.assert_close(dst_result, expected_dst)
879879

880-
@skipIfNormalMode("InternalError: Unexpected type <class 'slice'>")
881880
def test_range_slice(self):
882881
"""Test both setter from scalar and getter for [10:20]"""
883882

@@ -903,9 +902,7 @@ def kernel(
903902
torch.testing.assert_close(src_result, expected_src)
904903
torch.testing.assert_close(dst_result, expected_dst)
905904

906-
@skipIfNormalMode(
907-
"InternalError: AssertionError in type_propagation.py - slice indexing error"
908-
)
905+
@skipIfNormalMode("Dynamic slices (i:i+1) are not supported - FX cannot trace symbolic slice indices")
909906
def test_range_slice_dynamic(self):
910907
"""Test both [i:i+1] = scalar and [i] = [i:i+1] patterns"""
911908

0 commit comments

Comments
 (0)