diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py
index bb7491ce..c41f2d3f 100644
--- a/helion/_compiler/indexing_strategy.py
+++ b/helion/_compiler/indexing_strategy.py
@@ -18,6 +18,7 @@
 from .host_function import HostFunction
 from .tile_strategy import DeviceLoopState
 from .utils import compute_slice_size
+from .utils import get_slice_start
 from .variable_origin import BlockSizeOrigin
 
 if TYPE_CHECKING:
@@ -126,6 +127,30 @@ def _handle_remaining_index_dimensions(
     return output_idx
 
 
+def _generate_slice_index(
+    start: int | torch.SymInt,
+    index_var: str,
+    expand: str,
+    step: int | None = None,
+) -> str:
+    """Generate slice index expression with optional step."""
+    if step is not None:
+        # Strided index: start + index * step
+        return f"({start} + ({index_var}) * {step}){expand}"
+    if start != 0:
+        # Index with offset: start + index
+        return f"({start} + ({index_var})){expand}"
+    # Simple index
+    return f"({index_var}){expand}"
+
+
+def _generate_offset_expr(start: int | torch.SymInt, offset: str) -> str:
+    """Generate offset expression with optional start."""
+    if start != 0:
+        return f"({start} + {offset})"
+    return offset
+
+
 class IndexingStrategy:
     def codegen_load(
         self,
@@ -627,7 +652,6 @@ def compute_shape(
             size = input_size.popleft()
             # Handle slices with steps
             slice_size = compute_slice_size(k, size)
-
             if slice_size != 1:
                 rdim = env.allocate_reduction_dimension(slice_size)
                 output_size.append(rdim.var)
@@ -719,25 +743,29 @@ def create(
                         rdim = env.allocate_reduction_dimension(slice_size)
                         block_idx = rdim.block_id
                         index_var = state.codegen.index_var(block_idx)
-                        # Generate strided index: start + index * step
                         index_values.append(
-                            f"({start} + ({index_var}) * {step}){expand}"
+                            _generate_slice_index(start, index_var, expand, step)
                         )
                         if mask := state.codegen.mask_var(block_idx):
                             mask_values.setdefault(f"({mask}){expand}")
                     else:
                         index_values.append(f"{start}{expand}")
                 else:
-                    # Full slice or slice without step
-                    if size != 1:
-                        rdim = env.allocate_reduction_dimension(size)
+                    # Handle slices with start/stop but no step
+                    start = get_slice_start(k)
+                    slice_size = compute_slice_size(k, size)
+
+                    if slice_size != 1:
+                        rdim = env.allocate_reduction_dimension(slice_size)
                         block_idx = rdim.block_id
                         index_var = state.codegen.index_var(block_idx)
-                        index_values.append(f"({index_var}){expand}")
+                        index_values.append(
+                            _generate_slice_index(start, index_var, expand)
+                        )
                         if mask := state.codegen.mask_var(block_idx):
                             mask_values.setdefault(f"({mask}){expand}")
                     else:
-                        index_values.append(f"tl.zeros([1], {dtype}){expand}")
+                        index_values.append(f"{start}{expand}")
                 output_idx += 1
             elif isinstance(k, torch.Tensor) and k.ndim == 1:
                 expand = tile_strategy.expand_str(output_size, output_idx)
@@ -1025,8 +1053,19 @@ def create(
                 res.offsets.append(state.codegen.offset_var(rdim.block_id))
                 res.block_shape.append(rdim.var)
             else:
-                res.offsets.append("0")
-                res.block_shape.append(1)
+                # Handle slices with start/stop but no step
+                start = get_slice_start(k)
+                slice_size = compute_slice_size(k, size)
+
+                if slice_size != 1:
+                    env = CompileEnvironment.current()
+                    rdim = env.allocate_reduction_dimension(slice_size)
+                    offset = state.codegen.offset_var(rdim.block_id)
+                    res.offsets.append(_generate_offset_expr(start, offset))
+                    res.block_shape.append(rdim.var)
+                else:
+                    res.offsets.append(str(start))
+                    res.block_shape.append(1)
         else:
             raise exc.InvalidIndexingType(k)
     res.validate()
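Note: a minimal standalone sketch of what the new _generate_slice_index helper emits, reimplemented outside the repo for illustration. The index variable name "rindex_0" and the expand string "[None, :]" are hypothetical placeholder inputs, not values produced by this patch:

    def _generate_slice_index(start, index_var, expand, step=None):
        if step is not None:
            return f"({start} + ({index_var}) * {step}){expand}"  # strided: start + i * step
        if start != 0:
            return f"({start} + ({index_var})){expand}"  # offset: start + i
        return f"({index_var}){expand}"  # plain index

    # Hypothetical inputs, shown only to exercise the three branches:
    assert _generate_slice_index(0, "rindex_0", "[None, :]") == "(rindex_0)[None, :]"
    assert _generate_slice_index(10, "rindex_0", "[None, :]") == "(10 + (rindex_0))[None, :]"
    assert _generate_slice_index(2, "rindex_0", "[None, :]", step=3) == "(2 + (rindex_0) * 3)[None, :]"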
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index 9355fefb..96d673b8 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -466,7 +466,6 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
 
                 # For slices with steps, we need to calculate the output size differently
                 output_size = compute_slice_size(slice_obj, size)
-
                 if self.origin.is_device():
                     output_sizes.append(output_size)
                 elif output_size != 1:
@@ -515,8 +514,9 @@ def propagate_setitem(
         lhs_rank = len(lhs_shape)
         if isinstance(value, TensorType):
             rhs_rank = value.fake_value.ndim
-            # Allow scalar tensors (rank 0) to be assigned to any rank (broadcasts)
-            if rhs_rank != 0 and lhs_rank != rhs_rank:
+            rhs_numel = value.fake_value.numel()
+            # Allow scalar tensors (rank 0) or single-element tensors to be assigned to any rank (broadcasts)
+            if rhs_rank != 0 and rhs_numel != 1 and lhs_rank != rhs_rank:
                 raise exc.RankMismatch(
                     lhs_rank,
                     rhs_rank,
diff --git a/helion/_compiler/utils.py b/helion/_compiler/utils.py
index f5260cd7..a7145ec3 100644
--- a/helion/_compiler/utils.py
+++ b/helion/_compiler/utils.py
@@ -25,5 +25,12 @@ def compute_slice_size(
         stop = slice_obj.stop if slice_obj.stop is not None else original_size
         step = slice_obj.step
         return (stop - start + step - 1) // step
-    # Full slice or slice without step
-    return original_size
+    # Calculate slice size based on start/stop
+    start = slice_obj.start if slice_obj.start is not None else 0
+    stop = slice_obj.stop if slice_obj.stop is not None else original_size
+    return stop - start
+
+
+def get_slice_start(slice_obj: slice) -> int:
+    """Get the start index of a slice, defaulting to 0."""
+    return slice_obj.start if slice_obj.start is not None else 0
diff --git a/test/test_indexing.py b/test/test_indexing.py
index f63cde16..1a48f570 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -877,7 +877,6 @@ def kernel(
 
         torch.testing.assert_close(src_result, expected_src)
         torch.testing.assert_close(dst_result, expected_dst)
 
-    @skipIfNormalMode("InternalError: Unexpected type ")
     def test_range_slice(self):
         """Test both setter from scalar and getter for [10:20]"""
@@ -904,7 +903,7 @@ def kernel(
         torch.testing.assert_close(dst_result, expected_dst)
 
     @skipIfNormalMode(
-        "InternalError: AssertionError in type_propagation.py - slice indexing error"
+        "Dynamic slices (i:i+1) are not supported - FX cannot trace symbolic slice indices"
    )
     def test_range_slice_dynamic(self):
         """Test both [i:i+1] = scalar and [i] = [i:i+1] patterns"""
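As a sanity check on the size arithmetic above, a standalone sketch with the function body copied from the utils.py change; the asserts assume in-bounds, non-negative slice fields, which is what the device-indexing paths construct:

    def compute_slice_size(slice_obj, original_size):
        if slice_obj.step is not None:
            start = slice_obj.start if slice_obj.start is not None else 0
            stop = slice_obj.stop if slice_obj.stop is not None else original_size
            step = slice_obj.step
            return (stop - start + step - 1) // step  # ceil((stop - start) / step)
        start = slice_obj.start if slice_obj.start is not None else 0
        stop = slice_obj.stop if slice_obj.stop is not None else original_size
        return stop - start

    # Matches len(range(start, stop, step)) for these well-formed slices:
    assert compute_slice_size(slice(10, 20), 64) == 10      # x[10:20] -> 10 elements
    assert compute_slice_size(slice(2, 10, 3), 64) == 3     # x[2:10:3] -> [2, 5, 8]
    assert compute_slice_size(slice(None, None), 64) == 64  # full slice x[:]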