|
17 | 17 | from .device_function import DeviceFunction
|
18 | 18 | from .host_function import HostFunction
|
19 | 19 | from .tile_strategy import DeviceLoopState
|
| 20 | +from .utils import compute_slice_size |
20 | 21 | from .variable_origin import BlockSizeOrigin
|
21 | 22 |
|
22 | 23 | if TYPE_CHECKING:
|
@@ -227,7 +228,10 @@ def valid_block_size(
|
227 | 228 | if k is None:
|
228 | 229 | continue
|
229 | 230 | size, stride = size_stride.popleft()
|
230 |
| - if str(k) == "slice(None, None, None)": |
| 231 | + if isinstance(k, slice): |
| 232 | + # Slices with steps are not supported in tensor descriptor mode |
| 233 | + if k.step is not None and k.step != 1: |
| 234 | + return False |
231 | 235 | block_size = env.allocate_reduction_dimension(size).from_config(config)
|
232 | 236 | if not valid_block_size(block_size, stride, i):
|
233 | 237 | return False
|
@@ -476,10 +480,13 @@ def compute_shape(
|
476 | 480 | output_size.append(k)
|
477 | 481 | else:
|
478 | 482 | output_size.append(1)
|
479 |
| - elif isinstance(k, slice) and str(k) == "slice(None, None, None)": |
| 483 | + elif isinstance(k, slice): |
480 | 484 | size = input_size.popleft()
|
481 |
| - if size != 1: |
482 |
| - rdim = env.allocate_reduction_dimension(size) |
| 485 | + # Handle slices with steps |
| 486 | + slice_size = compute_slice_size(k, size) |
| 487 | + |
| 488 | + if slice_size != 1: |
| 489 | + rdim = env.allocate_reduction_dimension(slice_size) |
483 | 490 | output_size.append(rdim.var)
|
484 | 491 | else:
|
485 | 492 | output_size.append(1)
|
@@ -531,18 +538,40 @@ def create(
|
531 | 538 | # When the index is a scalar (no BlockSizeOrigin), the corresponding dim is eliminated.
|
532 | 539 | val = state.device_function.literal_expr(k)
|
533 | 540 | index_values.append(f"({val})")
|
534 |
| - elif isinstance(k, slice) and str(k) == "slice(None, None, None)": |
| 541 | + elif isinstance(k, slice): |
535 | 542 | expand = tile_strategy.expand_str(output_size, output_idx)
|
536 | 543 | size = fake_value.size(len(index_values))
|
537 |
| - if size != 1: |
538 |
| - rdim = env.allocate_reduction_dimension(size) |
539 |
| - block_idx = rdim.block_id |
540 |
| - index_var = state.codegen.index_var(block_idx) |
541 |
| - index_values.append(f"({index_var}){expand}") |
542 |
| - if mask := state.codegen.mask_var(block_idx): |
543 |
| - mask_values.setdefault(f"({mask}){expand}") |
| 544 | + |
| 545 | + # Handle slices with steps |
| 546 | + if k.step is not None and k.step != 1: |
| 547 | + # For strided slices, we need to generate: start + index * step |
| 548 | + start = k.start if k.start is not None else 0 |
| 549 | + step = k.step |
| 550 | + slice_size = compute_slice_size(k, size) |
| 551 | + |
| 552 | + if slice_size != 1: |
| 553 | + rdim = env.allocate_reduction_dimension(slice_size) |
| 554 | + block_idx = rdim.block_id |
| 555 | + index_var = state.codegen.index_var(block_idx) |
| 556 | + # Generate strided index: start + index * step |
| 557 | + index_values.append( |
| 558 | + f"({start} + ({index_var}) * {step}){expand}" |
| 559 | + ) |
| 560 | + if mask := state.codegen.mask_var(block_idx): |
| 561 | + mask_values.setdefault(f"({mask}){expand}") |
| 562 | + else: |
| 563 | + index_values.append(f"{start}{expand}") |
544 | 564 | else:
|
545 |
| - index_values.append(f"tl.zeros([1], {dtype}){expand}") |
| 565 | + # Full slice or slice without step |
| 566 | + if size != 1: |
| 567 | + rdim = env.allocate_reduction_dimension(size) |
| 568 | + block_idx = rdim.block_id |
| 569 | + index_var = state.codegen.index_var(block_idx) |
| 570 | + index_values.append(f"({index_var}){expand}") |
| 571 | + if mask := state.codegen.mask_var(block_idx): |
| 572 | + mask_values.setdefault(f"({mask}){expand}") |
| 573 | + else: |
| 574 | + index_values.append(f"tl.zeros([1], {dtype}){expand}") |
546 | 575 | output_idx += 1
|
547 | 576 | elif isinstance(k, torch.Tensor) and k.ndim == 1:
|
548 | 577 | expand = tile_strategy.expand_str(output_size, output_idx)
|
@@ -772,8 +801,15 @@ def create(
|
772 | 801 | else:
|
773 | 802 | res.offsets.append(state.device_function.literal_expr(k))
|
774 | 803 | res.block_shape.append(1)
|
775 |
| - elif isinstance(k, slice) and str(k) == "slice(None, None, None)": |
| 804 | + elif isinstance(k, slice): |
776 | 805 | size = fake_value.size(len(res.offsets))
|
| 806 | + # Handle slices with steps |
| 807 | + if k.step is not None and k.step != 1: |
| 808 | + # Slices with steps are not supported in block_ptr mode |
| 809 | + raise exc.InvalidIndexingType( |
| 810 | + f"Strided slices not supported in block_ptr mode: {k}" |
| 811 | + ) |
| 812 | + # Full slice or slice without step |
777 | 813 | if size != 1:
|
778 | 814 | env = CompileEnvironment.current()
|
779 | 815 | rdim = env.allocate_reduction_dimension(size)
|
|
0 commit comments