Skip to content

Commit 3f83e89

Browse files
sevenEng authored and pytorchmergebot committed
[inductor] fix issue for example value with unbacked strides (pytorch#163660)
## Issue During autotune, we're not applying size hints atomically for the example inputs used for benchmarking. If there is unbacked symint showing up in inputs' strides, this might lead to CUDA IMA, and this could be reproduced by the added unittest, with stride being `[128 * u0, 128, 1]` and unbacked fallback being 8192, after calling `benchmark_example_value`, we get back a tensor with stride as `[8192, 128, 1]` as opposed to `[128 * 8192, 128, 1]` ## Fix Using the atomic API when trying to apply size hints to input tensor' strides. Pull Request resolved: pytorch#163660 Approved by: https://github.com/ColinPeppler
1 parent d7e3f49 commit 3f83e89

File tree

3 files changed

+43
-9
lines changed

3 files changed

+43
-9
lines changed

test/inductor/test_unbacked_symints.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,28 @@ def fn(x):
653653
expected = fn(*example_inputs)
654654
torch.testing.assert_close(actual, expected)
655655

656+
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
657+
@inductor_config.patch({"max_autotune": True})
658+
@dynamo_config.patch({"capture_scalar_outputs": True})
659+
def test_autotune_with_unbacked_stride(self, device):
660+
def fn(x, y, a):
661+
u0 = a.item()
662+
torch._check(u0 != 1)
663+
unbacked = x.expand(8, u0, *x.shape).clone()
664+
unbacked = torch.permute(unbacked, [0, 2, 1])
665+
y = y.expand(8, *y.shape)
666+
bmm = torch.ops.aten.bmm(unbacked, y)
667+
return bmm
668+
669+
example_inputs = (
670+
torch.randn((32,), dtype=torch.bfloat16, device=device),
671+
torch.randn((128, 64), dtype=torch.bfloat16, device=device),
672+
torch.tensor(128, device=device),
673+
)
674+
actual = torch.compile(fn, fullgraph=True)(*example_inputs)
675+
expected = fn(*example_inputs)
676+
torch.testing.assert_close(actual, expected)
677+
656678

657679
instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
658680

torch/_inductor/select_algorithm.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3622,10 +3622,13 @@ def benchmark_example_value(node, hint_override: Optional[int] = None):
36223622
fallback=config.unbacked_symint_fallback,
36233623
hint_override=hint_override,
36243624
),
3625-
V.graph.sizevars.size_hints(
3626-
node.get_stride(),
3627-
fallback=config.unbacked_symint_fallback,
3628-
hint_override=hint_override,
3625+
tuple(
3626+
V.graph.sizevars.atomically_apply_size_hint(
3627+
stride,
3628+
fallback=config.unbacked_symint_fallback,
3629+
hint_override=hint_override,
3630+
)
3631+
for stride in node.get_stride()
36293632
),
36303633
node.get_device(),
36313634
node.get_dtype(),
@@ -3677,9 +3680,12 @@ def key_of(node):
36773680
node.get_size(),
36783681
fallback=config.unbacked_symint_fallback,
36793682
),
3680-
*sizevars.size_hints(
3681-
node.get_stride(),
3682-
fallback=config.unbacked_symint_fallback,
3683+
*tuple(
3684+
V.graph.sizevars.atomically_apply_size_hint(
3685+
stride,
3686+
fallback=config.unbacked_symint_fallback,
3687+
)
3688+
for stride in node.get_stride()
36833689
),
36843690
sizevars.size_hint(
36853691
node.get_layout().offset,

torch/_inductor/sizevars.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -908,7 +908,11 @@ def _sub_unbacked_exprs(self, expr: Expr) -> Expr:
908908
return expr
909909

910910
def atomically_apply_size_hint(
911-
self, expr: Union[Expr, int], *, fallback: Optional[int] = None
911+
self,
912+
expr: Union[Expr, int],
913+
*,
914+
fallback: Optional[int] = None,
915+
hint_override: Optional[int] = None,
912916
) -> Union[Expr, int]:
913917
if isinstance(expr, (int, sympy.Integer)):
914918
return int(expr)
@@ -925,7 +929,9 @@ def atomically_apply_size_hint(
925929
assert isinstance(expr, Expr), type(expr)
926930
free_symbols = expr.free_symbols
927931
size_dict = {
928-
symbol: V.graph.sizevars.size_hint(symbol, fallback=fallback)
932+
symbol: V.graph.sizevars.size_hint(
933+
symbol, fallback=fallback, hint_override=hint_override
934+
)
929935
for symbol in free_symbols
930936
}
931937
return expr.subs(size_dict)

0 commit comments

Comments (0)