
Commit a63ab0b

karthickai authored and pytorchmergebot committed
[Inductor] Fix out-of-bounds indices in repeat_interleave decomposition (pytorch#165368)
When `repeat_interleave` is decomposed into:

```python
cumsum = repeat.cumsum(0)
pos = torch.arange(output_size, device=repeat.device)
indices = torch.searchsorted(cumsum, pos, right=True)
```

the `searchsorted` op with `right=True` returns the insertion point after matching elements. When query values in `pos` are `>= cumsum[-1]`, `searchsorted` returns `len(cumsum)`, which is out of bounds for indexing (valid range: `[0, len(cumsum) - 1]`). These invalid indices trigger CUDA device-side assert errors in downstream indexing operations.

This fix adds clamping to ensure all indices stay within the valid range `[0, repeat.size(0) - 1]`.

Pull Request resolved: pytorch#165368
Approved by: https://github.com/mlazos
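For illustration (not part of the commit), here is a minimal eager-mode sketch of the failure mode; the repeat counts and the over-sized `output_size` are made-up toy values:

```python
import torch

# Toy repeat counts; sum(repeat) == 6.
repeat = torch.tensor([2, 3, 1])
cumsum = repeat.cumsum(0)          # tensor([2, 5, 6])

# With output_size == sum(repeat), every query stays below cumsum[-1]
# and the result is a valid repeat_interleave index tensor.
pos = torch.arange(6)
print(torch.searchsorted(cumsum, pos, right=True))   # tensor([0, 0, 1, 1, 1, 2])

# If output_size exceeds sum(repeat) (as in the regression test, where
# output_size=505450 but sum(repeat)=2560), queries >= cumsum[-1] map to
# len(cumsum), which is out of range for indexing.
pos_bad = torch.arange(8)
raw = torch.searchsorted(cumsum, pos_bad, right=True)
print(raw)                                            # tensor([0, 0, 1, 1, 1, 2, 3, 3])

# The fix: clamp into [0, repeat.size(0) - 1] before the indices are used.
safe = torch.clamp(raw, max=repeat.size(0) - 1)
print(safe)                                           # tensor([0, 0, 1, 1, 1, 2, 2, 2])
```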
1 parent 102b788 commit a63ab0b

2 files changed: +34 −1 lines changed


test/inductor/test_torchinductor.py

Lines changed: 32 additions & 0 deletions
```diff
@@ -14268,6 +14268,38 @@ def fn(a, b):
         self.assertTrue("'enable_fp_fusion': False" in code)
         torch.testing.assert_close(out, fn(a, b), atol=0, rtol=0)

+    @skip_if_cpp_wrapper("skip cpp wrapper")
+    @requires_cuda_and_triton
+    def test_repeat_interleave_decomposition_has_clamp(self):
+        repeat = torch.ones(2560, dtype=torch.int64, device=GPU_TYPE)
+        output_size = 505450
+        data = torch.arange(2560, device=GPU_TYPE)
+
+        if is_dynamic_shape_enabled():
+            raise unittest.SkipTest(
+                "repeat_interleave decomp doesn't support dynamic output size"
+            )
+
+        @torch.compile
+        def fn(repeat, output_size, data):
+            indices = torch.ops.aten.repeat_interleave.Tensor(
+                repeat, output_size=output_size
+            )
+            return data[indices]
+
+        result, code = run_and_get_code(fn, repeat, output_size, data)
+
+        self.assertEqual(result.shape[0], output_size)
+        self.assertTrue(torch.all(result >= 0).item())
+        self.assertTrue(torch.all(result < 2560).item())
+
+        code_str = "\n".join(code)
+        self.assertIn(
+            "triton_helpers.minimum",
+            code_str,
+            "Generated Triton code should use triton_helpers.minimum for clamping",
+        )
+
     # end of class CommonTemplate - add new tests here

```

torch/_inductor/decomposition.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1188,9 +1188,10 @@ def repeat_interleave_Tensor(
     assert repeat.ndim == 1
     cumsum = repeat.cumsum(0)
     pos = torch.arange(output_size, device=repeat.device)
-    return torch.searchsorted(
+    indices = torch.searchsorted(
         cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
     )
+    return torch.clamp(indices, max=repeat.size(0) - 1)


 # intentionally not regiestered
```
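As a sanity check, the patched decomposition can be reproduced in plain eager PyTorch; the helper name `repeat_interleave_decomp` below is made up for this sketch, and its body mirrors the two-line change above:

```python
import torch

def repeat_interleave_decomp(repeat: torch.Tensor, output_size: int) -> torch.Tensor:
    # Mirrors the patched Inductor decomposition (eager-mode sketch).
    assert repeat.ndim == 1
    cumsum = repeat.cumsum(0)
    pos = torch.arange(output_size, device=repeat.device)
    indices = torch.searchsorted(
        cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
    )
    # Clamp so queries past cumsum[-1] cannot index out of bounds.
    return torch.clamp(indices, max=repeat.size(0) - 1)

repeat = torch.tensor([2, 0, 3, 1])
expected = torch.repeat_interleave(repeat)   # tensor([0, 0, 2, 2, 2, 3])
got = repeat_interleave_decomp(repeat, output_size=int(repeat.sum().item()))
assert torch.equal(got, expected)
```

With an `output_size` larger than `repeat.sum()`, the extra positions now map to the last row of the input instead of producing an out-of-range index.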
