|
31 | 31 | ShapeLike = Sequence[SymIntLike]
|
32 | 32 |
|
33 | 33 |
|
| 34 | +def _normalize_negative_index( |
| 35 | + k: int, |
| 36 | + dim_idx: int, |
| 37 | + fake_value: torch.Tensor, |
| 38 | + state: CodegenState, |
| 39 | +) -> str: |
| 40 | + """Normalize negative indices to positive ones. |
| 41 | +
|
| 42 | + Args: |
| 43 | + k: The negative index value |
| 44 | + dim_idx: The dimension index |
| 45 | + fake_value: The fake tensor to get dimension size from |
| 46 | + state: The codegen state |
| 47 | +
|
| 48 | + Returns: |
| 49 | + String representation of the normalized index |
| 50 | + """ |
| 51 | + assert k < 0, "This function should only be called for negative indices" |
| 52 | + |
| 53 | + dim_size = fake_value.size(dim_idx) |
| 54 | + # Handle both concrete and symbolic dimension sizes |
| 55 | + if isinstance(dim_size, int): |
| 56 | + normalized_k = k + dim_size |
| 57 | + return repr(normalized_k) |
| 58 | + # For symbolic dimensions, we need to generate the proper expression |
| 59 | + # The state.codegen is a GenerateAST instance which has device_function |
| 60 | + sympy_expr = dim_size._sympy_() + k |
| 61 | + return f"({state.codegen.device_function.user_sympy_expr(sympy_expr)})" |
| 62 | + |
| 63 | + |
34 | 64 | class IndexingStrategy:
|
35 | 65 | def codegen_load(
|
36 | 66 | self,
|
@@ -555,7 +585,14 @@ def create(
|
555 | 585 | index_values.append(f"tl.zeros([1], {dtype}){expand}")
|
556 | 586 | output_idx += 1
|
557 | 587 | elif isinstance(k, int):
|
558 |
| - index_values.append(repr(k)) |
| 588 | + # Normalize negative indices |
| 589 | + if k < 0: |
| 590 | + dim_idx = len(index_values) |
| 591 | + index_values.append( |
| 592 | + _normalize_negative_index(k, dim_idx, fake_value, state) |
| 593 | + ) |
| 594 | + else: |
| 595 | + index_values.append(repr(k)) |
559 | 596 | elif isinstance(k, torch.SymInt):
|
560 | 597 | symbol = k._sympy_()
|
561 | 598 | origin = None
|
@@ -843,7 +880,14 @@ def create(
|
843 | 880 | res.offsets.append("0")
|
844 | 881 | res.block_shape.append(1)
|
845 | 882 | elif isinstance(k, int):
|
846 |
| - res.offsets.append(repr(k)) |
| 883 | + # Normalize negative indices |
| 884 | + if k < 0: |
| 885 | + dim_idx = len(res.offsets) |
| 886 | + res.offsets.append( |
| 887 | + _normalize_negative_index(k, dim_idx, fake_value, state) |
| 888 | + ) |
| 889 | + else: |
| 890 | + res.offsets.append(repr(k)) |
847 | 891 | res.block_shape.append(1)
|
848 | 892 | elif isinstance(k, torch.SymInt):
|
849 | 893 | symbol = k._sympy_()
|
|
0 commit comments