|
31 | 31 | ShapeLike = Sequence[SymIntLike]
|
32 | 32 |
|
33 | 33 |
|
| 34 | +def _normalize_negative_index( |
| 35 | + k: int, |
| 36 | + dim_idx: int, |
| 37 | + fake_value: torch.Tensor, |
| 38 | + state: CodegenState, |
| 39 | +) -> str: |
| 40 | + """Normalize negative indices to positive ones. |
| 41 | +
|
| 42 | + Args: |
| 43 | + k: The negative index value |
| 44 | + dim_idx: The dimension index |
| 45 | + fake_value: The fake tensor to get dimension size from |
| 46 | + state: The codegen state |
| 47 | +
|
| 48 | + Returns: |
| 49 | + String representation of the normalized index |
| 50 | + """ |
| 51 | + assert k < 0, "This function should only be called for negative indices" |
| 52 | + |
| 53 | + dim_size = fake_value.size(dim_idx) |
| 54 | + # Handle both concrete and symbolic dimension sizes |
| 55 | + if isinstance(dim_size, int): |
| 56 | + normalized_k = k + dim_size |
| 57 | + return repr(normalized_k) |
| 58 | + # For symbolic dimensions, we need to generate the proper expression |
| 59 | + # The state.codegen is a GenerateAST instance which has device_function |
| 60 | + sympy_expr = dim_size._sympy_() + k |
| 61 | + return f"({state.codegen.device_function.user_sympy_expr(sympy_expr)})" |
| 62 | + |
| 63 | + |
34 | 64 | class IndexingStrategy:
|
35 | 65 | def codegen_load(
|
36 | 66 | self,
|
@@ -553,7 +583,14 @@ def create(
|
553 | 583 | index_values.append(f"tl.zeros([1], {dtype}){expand}")
|
554 | 584 | output_idx += 1
|
555 | 585 | elif isinstance(k, int):
|
556 |
| - index_values.append(repr(k)) |
| 586 | + # Normalize negative indices |
| 587 | + if k < 0: |
| 588 | + dim_idx = len(index_values) |
| 589 | + index_values.append( |
| 590 | + _normalize_negative_index(k, dim_idx, fake_value, state) |
| 591 | + ) |
| 592 | + else: |
| 593 | + index_values.append(repr(k)) |
557 | 594 | elif isinstance(k, torch.SymInt):
|
558 | 595 | symbol = k._sympy_()
|
559 | 596 | origin = None
|
@@ -839,7 +876,14 @@ def create(
|
839 | 876 | res.offsets.append("0")
|
840 | 877 | res.block_shape.append(1)
|
841 | 878 | elif isinstance(k, int):
|
842 |
| - res.offsets.append(repr(k)) |
| 879 | + # Normalize negative indices |
| 880 | + if k < 0: |
| 881 | + dim_idx = len(res.offsets) |
| 882 | + res.offsets.append( |
| 883 | + _normalize_negative_index(k, dim_idx, fake_value, state) |
| 884 | + ) |
| 885 | + else: |
| 886 | + res.offsets.append(repr(k)) |
843 | 887 | res.block_shape.append(1)
|
844 | 888 | elif isinstance(k, torch.SymInt):
|
845 | 889 | symbol = k._sympy_()
|
|
0 commit comments