 from ..runtime.config import Config
 from .device_function import TensorDescriptorArg
 from .inductor_lowering import CodegenState
+from .tile_dispatch import TileStrategyDispatch
 
 SymIntLike = torch.SymInt | int
 ShapeLike = Sequence[SymIntLike]
@@ -61,6 +62,70 @@ def _normalize_negative_index(
     return f"({state.codegen.device_function.user_sympy_expr(sympy_expr)})"
 
 
+def _append_remaining_dimensions(
+    input_size: collections.deque[int | torch.SymInt],
+    output_size: list[int | torch.SymInt],
+    env: CompileEnvironment,
+) -> None:
+    """Append remaining dimensions from input to output for partial indexing.
+
+    Args:
+        input_size: Deque of remaining input dimensions
+        output_size: List to append output dimensions to
+        env: The compile environment
+    """
+    while input_size:
+        size = input_size.popleft()
+        if size != 1:
+            rdim = env.allocate_reduction_dimension(size)
+            output_size.append(rdim.var)
+        else:
+            output_size.append(1)
+
+
+def _handle_remaining_index_dimensions(
+    index_values: list[str],
+    mask_values: dict[str, None],
+    output_size: list[int | torch.SymInt],
+    output_idx: int,
+    fake_value: torch.Tensor,
+    state: CodegenState,
+    tile_strategy: TileStrategyDispatch,
+    env: CompileEnvironment,
+    dtype: str,
+) -> int:
+    """Handle remaining dimensions for partial indexing in SubscriptIndexing.create.
+
+    Args:
+        index_values: List to append index expressions to
+        mask_values: Dict to add mask expressions to
+        output_size: The output shape
+        output_idx: Current output index
+        fake_value: The tensor being indexed
+        state: The codegen state
+        tile_strategy: The tile strategy
+        env: The compile environment
+        dtype: The Triton index type
+
+    Returns:
+        The updated output_idx
+    """
+    while len(index_values) < fake_value.ndim:
+        expand = tile_strategy.expand_str(output_size, output_idx)
+        size = fake_value.size(len(index_values))
+        if size != 1:
+            rdim = env.allocate_reduction_dimension(size)
+            block_idx = rdim.block_id
+            index_var = state.codegen.index_var(block_idx)
+            index_values.append(f"({index_var}){expand}")
+            if mask := state.codegen.mask_var(block_idx):
+                mask_values.setdefault(f"({mask}){expand}")
+        else:
+            index_values.append(f"tl.zeros([1], {dtype}){expand}")
+        output_idx += 1
+    return output_idx
+
+
 class IndexingStrategy:
     def codegen_load(
         self,
@@ -132,6 +197,32 @@ def codegen_store(
     ) -> ast.AST:
         indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
         name = state.device_function.tensor_arg(fake_tensor).name
+
+        # Check if value is a tensor load (Name node with id matching a tensor arg)
+        if isinstance(value, ast.Name) and hasattr(
+            state.device_function, "_tensor_args"
+        ):
+            # Check if this name corresponds to a tensor argument
+            tensor = None
+            for t, tensor_arg in state.device_function._tensor_args.items():
+                if tensor_arg.name == value.id:
+                    tensor = t
+                    break
+
+            if tensor is not None:
+                # Get the shape of the slice we're storing to
+                output_shape = SubscriptIndexing.compute_shape(fake_tensor, subscript)
+                if len(output_shape) == 1 and tensor.ndim == 1:
+                    # Load the entire 1D tensor
+                    value_indexing = SubscriptIndexing.create(
+                        state, tensor, [slice(None)], None
+                    )
+                    value = expr_from_string(
+                        f"tl.load({value.id} + offset, mask)",
+                        offset=value_indexing.index_expr,
+                        mask=value_indexing.mask_expr,
+                    )
+
         return expr_from_string(
             f"tl.store({name} + offset, value, mask)",
             value=value,
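
Note: the new store path special-cases a value that is a bare ast.Name naming a 1-D tensor argument, emitting a tl.load of the whole tensor before the tl.store. The name lookup itself is plain ast matching; a self-contained sketch (the name table below is invented for illustration):

    import ast

    tensor_args = {"x": "x_arg", "out": "out_arg"}  # hypothetical name table

    def names_tensor_arg(node: ast.expr) -> str | None:
        # Mirrors the loop over _tensor_args: a bare Name whose id matches a
        # registered tensor argument is a candidate for the tl.load rewrite.
        if isinstance(node, ast.Name) and node.id in tensor_args:
            return tensor_args[node.id]
        return None

    assert names_tensor_arg(ast.parse("x", mode="eval").body) == "x_arg"
    assert names_tensor_arg(ast.parse("x + 1", mode="eval").body) is None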
@@ -503,7 +594,9 @@ def compute_shape(
     ) -> list[int | torch.SymInt]:
         assert isinstance(tensor, torch.Tensor)
         assert isinstance(index, (list, tuple)), index
-        input_size = collections.deque(tensor.size())
+        input_size: collections.deque[int | torch.SymInt] = collections.deque(
+            tensor.size()
+        )
         output_size = []
         env = CompileEnvironment.current()
         for i, k in enumerate(index):
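
Note: the explicit element type documents that tensor.size() may yield torch.SymInt rather than plain int under symbolic shapes; runtime behavior is unchanged. A quick sketch:

    import collections

    import torch

    sizes: collections.deque[int | torch.SymInt] = collections.deque(
        torch.empty(4, 8, 1).size()
    )
    assert sizes.popleft() == 4
    assert list(sizes) == [8, 1]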
@@ -547,7 +640,8 @@ def compute_shape(
                 output_size.extend(k.size())
             else:
                 raise exc.InvalidIndexingType(k)
-        assert len(input_size) == 0, "invalid subscript"
+        # For partial indexing, append remaining dimensions to output
+        _append_remaining_dimensions(input_size, output_size, env)
         return output_size
 
     @staticmethod
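
Note: dropping the `assert len(input_size) == 0` is the behavioral change here — compute_shape now accepts a subscript that consumes only a leading prefix of the dimensions, matching eager PyTorch semantics:

    import torch

    t = torch.empty(4, 8, 2)
    assert t[0].shape == (8, 2)      # partial subscript: trailing dims survive
    assert t[0, :, 1].shape == (8,)  # full subscript: every dim consumed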
@@ -675,6 +769,20 @@ def create(
                 )
             else:
                 raise exc.InvalidIndexingType(type(k))
+
+        # Handle remaining dimensions for partial indexing
+        output_idx = _handle_remaining_index_dimensions(
+            index_values,
+            mask_values,
+            output_size,
+            output_idx,
+            fake_value,
+            state,
+            tile_strategy,
+            env,
+            dtype,
+        )
+
         assert len(output_size) == output_idx
         assert len(index_values) == fake_value.ndim
         index_expr = []
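
Note: _handle_remaining_index_dimensions records masks with mask_values.setdefault(...), using the dict[str, None] as an insertion-ordered set, so duplicate mask expressions collapse while emission order stays stable. A sketch with invented mask strings:

    mask_values: dict[str, None] = {}
    mask_values.setdefault("(mask_0)[:, None]")
    mask_values.setdefault("(mask_1)[None, :]")
    mask_values.setdefault("(mask_0)[:, None]")  # duplicate is a no-op
    assert list(mask_values) == ["(mask_0)[:, None]", "(mask_1)[None, :]"]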
@@ -800,7 +908,9 @@ def is_supported(
         if extra_mask is not None:
             # TODO(jansel): support block_ptr with extra_mask
             return False
-        input_sizes = collections.deque(fake_tensor.size())
+        input_sizes: collections.deque[int | torch.SymInt] = collections.deque(
+            fake_tensor.size()
+        )
         for n, k in enumerate(index):
             if k is None:
                 input_size = 1
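
Note: in the `k is None` branch, None behaves like np.newaxis — it contributes a size-1 output dimension without consuming an input dimension, which is presumably why input_size is set to 1 rather than popped from input_sizes. Eager PyTorch shows the same semantics:

    import torch

    t = torch.empty(4, 8)
    assert t[None].shape == (1, 4, 8)     # None adds a new size-1 dim
    assert t[:, None].shape == (4, 1, 8)  # no input dim is consumed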