Commit 4718678

Fix strided slice support for static slices (e.g., buf[::2]) (#426)
1 parent a252cca commit 4718678
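
For context: this commit teaches Helion's indexing paths to accept slices with a static step (e.g. buf[::2] or buf[1::3]) instead of only full slices. A minimal sketch of the kind of in-kernel indexing this enables, assuming the usual helion.kernel / hl.tile API; the kernel below is illustrative only and is not taken from this commit's tests:

    import torch
    import helion
    import helion.language as hl


    @helion.kernel()
    def gather_even_columns(x: torch.Tensor) -> torch.Tensor:
        m, n = x.size()
        # Output holds every other column; its width matches
        # compute_slice_size(slice(None, None, 2), n) == (n + 1) // 2.
        out = torch.empty([m, (n + 1) // 2], dtype=x.dtype, device=x.device)
        for tile_m in hl.tile(m):
            # Strided slice in device code; previously only full slices
            # (e.g. x[tile_m, :]) were handled by the indexing strategies.
            out[tile_m, :] = x[tile_m, ::2]
        return out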

File tree (4 files changed, +89 -19 lines)

  helion/_compiler/indexing_strategy.py
  helion/_compiler/type_propagation.py
  helion/_compiler/utils.py
  test/test_indexing.py

helion/_compiler/indexing_strategy.py

Lines changed: 50 additions & 14 deletions
@@ -17,6 +17,7 @@
 from .device_function import DeviceFunction
 from .host_function import HostFunction
 from .tile_strategy import DeviceLoopState
+from .utils import compute_slice_size
 from .variable_origin import BlockSizeOrigin

 if TYPE_CHECKING:
@@ -227,7 +228,10 @@ def valid_block_size(
             if k is None:
                 continue
             size, stride = size_stride.popleft()
-            if str(k) == "slice(None, None, None)":
+            if isinstance(k, slice):
+                # Slices with steps are not supported in tensor descriptor mode
+                if k.step is not None and k.step != 1:
+                    return False
                 block_size = env.allocate_reduction_dimension(size).from_config(config)
                 if not valid_block_size(block_size, stride, i):
                     return False
@@ -476,10 +480,13 @@ def compute_shape(
                     output_size.append(k)
                 else:
                     output_size.append(1)
-            elif isinstance(k, slice) and str(k) == "slice(None, None, None)":
+            elif isinstance(k, slice):
                 size = input_size.popleft()
-                if size != 1:
-                    rdim = env.allocate_reduction_dimension(size)
+                # Handle slices with steps
+                slice_size = compute_slice_size(k, size)
+
+                if slice_size != 1:
+                    rdim = env.allocate_reduction_dimension(slice_size)
                     output_size.append(rdim.var)
                 else:
                     output_size.append(1)
@@ -531,18 +538,40 @@ def create(
                 # When the index is a scalar (no BlockSizeOrigin), the corresponding dim is eliminated.
                 val = state.device_function.literal_expr(k)
                 index_values.append(f"({val})")
-            elif isinstance(k, slice) and str(k) == "slice(None, None, None)":
+            elif isinstance(k, slice):
                 expand = tile_strategy.expand_str(output_size, output_idx)
                 size = fake_value.size(len(index_values))
-                if size != 1:
-                    rdim = env.allocate_reduction_dimension(size)
-                    block_idx = rdim.block_id
-                    index_var = state.codegen.index_var(block_idx)
-                    index_values.append(f"({index_var}){expand}")
-                    if mask := state.codegen.mask_var(block_idx):
-                        mask_values.setdefault(f"({mask}){expand}")
+
+                # Handle slices with steps
+                if k.step is not None and k.step != 1:
+                    # For strided slices, we need to generate: start + index * step
+                    start = k.start if k.start is not None else 0
+                    step = k.step
+                    slice_size = compute_slice_size(k, size)
+
+                    if slice_size != 1:
+                        rdim = env.allocate_reduction_dimension(slice_size)
+                        block_idx = rdim.block_id
+                        index_var = state.codegen.index_var(block_idx)
+                        # Generate strided index: start + index * step
+                        index_values.append(
+                            f"({start} + ({index_var}) * {step}){expand}"
+                        )
+                        if mask := state.codegen.mask_var(block_idx):
+                            mask_values.setdefault(f"({mask}){expand}")
+                    else:
+                        index_values.append(f"{start}{expand}")
                 else:
-                    index_values.append(f"tl.zeros([1], {dtype}){expand}")
+                    # Full slice or slice without step
+                    if size != 1:
+                        rdim = env.allocate_reduction_dimension(size)
+                        block_idx = rdim.block_id
+                        index_var = state.codegen.index_var(block_idx)
+                        index_values.append(f"({index_var}){expand}")
+                        if mask := state.codegen.mask_var(block_idx):
+                            mask_values.setdefault(f"({mask}){expand}")
+                    else:
+                        index_values.append(f"tl.zeros([1], {dtype}){expand}")
                 output_idx += 1
             elif isinstance(k, torch.Tensor) and k.ndim == 1:
                 expand = tile_strategy.expand_str(output_size, output_idx)
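
The emitted pointer math for a strided slice simply rescales the reduction-dimension index. A host-side sketch of the same arithmetic (plain Python with hypothetical names, no Triton involved), showing that start + i * step enumerates exactly the elements the slice selects:

    def strided_indices(start, stop, step, length):
        # Mirrors the emitted expression: start + i * step for each block index i,
        # with the iteration count taken from the same ceil-division as compute_slice_size.
        start = 0 if start is None else start
        stop = length if stop is None else stop
        slice_size = (stop - start + step - 1) // step
        return [start + i * step for i in range(slice_size)]


    buf = list(range(10))
    assert [buf[j] for j in strided_indices(1, None, 3, len(buf))] == buf[1::3]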
@@ -772,8 +801,15 @@ def create(
                 else:
                     res.offsets.append(state.device_function.literal_expr(k))
                     res.block_shape.append(1)
-            elif isinstance(k, slice) and str(k) == "slice(None, None, None)":
+            elif isinstance(k, slice):
                 size = fake_value.size(len(res.offsets))
+                # Handle slices with steps
+                if k.step is not None and k.step != 1:
+                    # Slices with steps are not supported in block_ptr mode
+                    raise exc.InvalidIndexingType(
+                        f"Strided slices not supported in block_ptr mode: {k}"
+                    )
+                # Full slice or slice without step
                 if size != 1:
                     env = CompileEnvironment.current()
                     rdim = env.allocate_reduction_dimension(size)

helion/_compiler/type_propagation.py

Lines changed: 10 additions & 4 deletions
@@ -42,6 +42,7 @@
 from .host_function import SymbolOrigin
 from .output_header import library_imports
 from .source_location import current_location
+from .utils import compute_slice_size
 from .variable_origin import ArgumentOrigin
 from .variable_origin import AttributeOrigin
 from .variable_origin import BuiltinOrigin
@@ -437,14 +438,19 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
             elif isinstance(k, SymIntType):
                 inputs_consumed += 1
             elif isinstance(k, SliceType):
-                assert str(k.proxy()) == "slice(None, None, None)"
+                # Handle slices - including those with steps
+                slice_obj = k.proxy()
                 size = self.fake_value.size(inputs_consumed)
                 inputs_consumed += 1
+
+                # For slices with steps, we need to calculate the output size differently
+                output_size = compute_slice_size(slice_obj, size)
+
                 if self.origin.is_device():
-                    output_sizes.append(size)
-                elif size != 1:
+                    output_sizes.append(output_size)
+                elif output_size != 1:
                     rdim = CompileEnvironment.current().allocate_reduction_dimension(
-                        size
+                        output_size
                     )
                     output_sizes.append(rdim.var)
                 else:
helion/_compiler/utils.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch
+
+
+def compute_slice_size(
+    slice_obj: slice, original_size: int | torch.SymInt
+) -> int | torch.SymInt:
+    """
+    Compute the size of a slice operation.
+
+    Args:
+        slice_obj: The slice object with start, stop, and step attributes
+        original_size: The size of the dimension being sliced
+
+    Returns:
+        The size of the resulting sliced dimension
+    """
+    if slice_obj.step is not None and slice_obj.step != 1:
+        # Calculate size based on step
+        start = slice_obj.start if slice_obj.start is not None else 0
+        stop = slice_obj.stop if slice_obj.stop is not None else original_size
+        step = slice_obj.step
+        return (stop - start + step - 1) // step
+    # Full slice or slice without step
+    return original_size
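
As a sanity check on the ceil-division formula, compute_slice_size should agree with the length Python itself assigns to the same slice, at least for non-negative, in-range starts and stops (the static-slice cases this commit targets). A quick illustration, assuming the helper is importable as helion._compiler.utils.compute_slice_size per the diff above:

    from helion._compiler.utils import compute_slice_size

    for n in (7, 8, 10):
        for s in (slice(None, None, 2), slice(1, None, 3), slice(2, 9, 4)):
            assert compute_slice_size(s, n) == len(range(n)[s])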

test/test_indexing.py

Lines changed: 0 additions & 1 deletion
@@ -689,7 +689,6 @@ def kernel(
         torch.testing.assert_close(src_result, expected_src)
         torch.testing.assert_close(dst_result, expected_dst)

-    @skipIfNormalMode("InternalError: AssertionError")
     def test_strided_slice(self):
         """Test both setter from scalar and getter for strided slices [::2] and [1::3]"""
