|
18 | 18 | from .host_function import HostFunction
|
19 | 19 | from .tile_strategy import DeviceLoopState
|
20 | 20 | from .utils import compute_slice_size
|
| 21 | +from .utils import get_slice_start |
21 | 22 | from .variable_origin import BlockSizeOrigin
|
22 | 23 |
|
23 | 24 | if TYPE_CHECKING:
|
@@ -126,6 +127,30 @@ def _handle_remaining_index_dimensions(
|
126 | 127 | return output_idx
|
127 | 128 |
|
128 | 129 |
|
| 130 | +def _generate_slice_index( |
| 131 | + start: int | torch.SymInt, |
| 132 | + index_var: str, |
| 133 | + expand: str, |
| 134 | + step: int | None = None, |
| 135 | +) -> str: |
| 136 | + """Generate slice index expression with optional step.""" |
| 137 | + if step is not None: |
| 138 | + # Strided index: start + index * step |
| 139 | + return f"({start} + ({index_var}) * {step}){expand}" |
| 140 | + if start != 0: |
| 141 | + # Index with offset: start + index |
| 142 | + return f"({start} + ({index_var})){expand}" |
| 143 | + # Simple index |
| 144 | + return f"({index_var}){expand}" |
| 145 | + |
| 146 | + |
| 147 | +def _generate_offset_expr(start: int | torch.SymInt, offset: str) -> str: |
| 148 | + """Generate offset expression with optional start.""" |
| 149 | + if start != 0: |
| 150 | + return f"({start} + {offset})" |
| 151 | + return offset |
| 152 | + |
| 153 | + |
129 | 154 | class IndexingStrategy:
|
130 | 155 | def codegen_load(
|
131 | 156 | self,
|
@@ -627,7 +652,6 @@ def compute_shape(
|
627 | 652 | size = input_size.popleft()
|
628 | 653 | # Handle slices with steps
|
629 | 654 | slice_size = compute_slice_size(k, size)
|
630 |
| - |
631 | 655 | if slice_size != 1:
|
632 | 656 | rdim = env.allocate_reduction_dimension(slice_size)
|
633 | 657 | output_size.append(rdim.var)
|
@@ -719,25 +743,29 @@ def create(
|
719 | 743 | rdim = env.allocate_reduction_dimension(slice_size)
|
720 | 744 | block_idx = rdim.block_id
|
721 | 745 | index_var = state.codegen.index_var(block_idx)
|
722 |
| - # Generate strided index: start + index * step |
723 | 746 | index_values.append(
|
724 |
| - f"({start} + ({index_var}) * {step}){expand}" |
| 747 | + _generate_slice_index(start, index_var, expand, step) |
725 | 748 | )
|
726 | 749 | if mask := state.codegen.mask_var(block_idx):
|
727 | 750 | mask_values.setdefault(f"({mask}){expand}")
|
728 | 751 | else:
|
729 | 752 | index_values.append(f"{start}{expand}")
|
730 | 753 | else:
|
731 |
| - # Full slice or slice without step |
732 |
| - if size != 1: |
733 |
| - rdim = env.allocate_reduction_dimension(size) |
| 754 | + # Handle slices with start/stop but no step |
| 755 | + start = get_slice_start(k) |
| 756 | + slice_size = compute_slice_size(k, size) |
| 757 | + |
| 758 | + if slice_size != 1: |
| 759 | + rdim = env.allocate_reduction_dimension(slice_size) |
734 | 760 | block_idx = rdim.block_id
|
735 | 761 | index_var = state.codegen.index_var(block_idx)
|
736 |
| - index_values.append(f"({index_var}){expand}") |
| 762 | + index_values.append( |
| 763 | + _generate_slice_index(start, index_var, expand) |
| 764 | + ) |
737 | 765 | if mask := state.codegen.mask_var(block_idx):
|
738 | 766 | mask_values.setdefault(f"({mask}){expand}")
|
739 | 767 | else:
|
740 |
| - index_values.append(f"tl.zeros([1], {dtype}){expand}") |
| 768 | + index_values.append(f"{start}{expand}") |
741 | 769 | output_idx += 1
|
742 | 770 | elif isinstance(k, torch.Tensor) and k.ndim == 1:
|
743 | 771 | expand = tile_strategy.expand_str(output_size, output_idx)
|
@@ -1025,8 +1053,19 @@ def create(
|
1025 | 1053 | res.offsets.append(state.codegen.offset_var(rdim.block_id))
|
1026 | 1054 | res.block_shape.append(rdim.var)
|
1027 | 1055 | else:
|
1028 |
| - res.offsets.append("0") |
1029 |
| - res.block_shape.append(1) |
| 1056 | + # Handle slices with start/stop but no step |
| 1057 | + start = get_slice_start(k) |
| 1058 | + slice_size = compute_slice_size(k, size) |
| 1059 | + |
| 1060 | + if slice_size != 1: |
| 1061 | + env = CompileEnvironment.current() |
| 1062 | + rdim = env.allocate_reduction_dimension(slice_size) |
| 1063 | + offset = state.codegen.offset_var(rdim.block_id) |
| 1064 | + res.offsets.append(_generate_offset_expr(start, offset)) |
| 1065 | + res.block_shape.append(rdim.var) |
| 1066 | + else: |
| 1067 | + res.offsets.append(str(start)) |
| 1068 | + res.block_shape.append(1) |
1030 | 1069 | else:
|
1031 | 1070 | raise exc.InvalidIndexingType(k)
|
1032 | 1071 | res.validate()
|
|
0 commit comments