Skip to content

Commit 6f07847

Browse files
bobrenjc93 authored and
pytorchmergebot committed
Bail on checking internal overlap when dealing with unbacked symints (pytorch#145385)
Pull Request resolved: pytorch#145385 Approved by: https://github.com/ezyang
1 parent e1407f5 commit 6f07847

File tree

5 files changed

+43
-7
lines changed

5 files changed

+43
-7
lines changed

aten/src/ATen/MemoryOverlap.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,22 @@ MemOverlap has_internal_overlap(const TensorBase& tensor) {
1212
MemOverlap has_internal_overlap(TensorImpl* t) {
1313
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t->layout() == kStrided);
1414

15+
auto sizes = t->sym_sizes();
16+
auto strides = t->sym_strides();
17+
18+
// When we have unbacked symint strides, is_non_overlapping_and_dense
19+
// often results in guard on data dependent errors. For now
20+
// let us bail early if there are unbacked symint strides.
21+
for (const auto i : c10::irange(strides.size())) {
22+
if (!strides[i].has_hint()) {
23+
return MemOverlap::TooHard;
24+
}
25+
}
26+
1527
if (t->is_non_overlapping_and_dense()) {
1628
return MemOverlap::No;
1729
}
1830

19-
auto strides = t->sym_strides();
20-
auto sizes = t->sym_sizes();
2131
for (const auto i : c10::irange(strides.size())) {
2232
// NB: The size oblivious test is written very carefully here. When
2333
// unbacked SymInts are involved, we should try to conservatively report

test/dynamo/test_misc.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7614,6 +7614,26 @@ def fn(x, y, z):
76147614
opt = torch.compile(fn, fullgraph=True)
76157615
opt(*inputs)
76167616

@torch._dynamo.config.patch(capture_scalar_outputs=True)
@torch._dynamo.config.patch(assume_static_by_default=True)
def test_symint_copy_into_unbacked_slice(self):
    # Regression test for pytorch#145385: copying into a static slice of a
    # tensor whose size contains an unbacked symint (u0) used to raise a
    # guard-on-data-dependent error from the internal-overlap check.
    @torch.compile()
    def fn(a, x):
        # u0 is an unbacked symint extracted from tensor data.
        u0 = torch.tensor(x[0].to(torch.int64).item()).item()
        B, H, T, D = a.shape
        # Padding whose third dimension is the unbacked u0.
        a_padding = torch.zeros((B, H, u0, D), dtype=torch.float64)
        b = torch.cat([a, a_padding], dim=2)
        c = torch.randn(B, H, 152, D)
        # Copy into a slice of b; b's strides involve u0, so the overlap
        # check must bail rather than guard on the unbacked value.
        b[:, :, :152, :] = c
        return b

    x = torch.tensor([0])
    torch._dynamo.decorators.mark_unbacked(x, 0)
    a = torch.zeros((1, 16, 152, 96))

    # Previously would crash with guard on data dependent error
    fn(a, x)
76177637
@torch._dynamo.config.patch(capture_scalar_outputs=True)
76187638
def test_symint_fold_nontrivial_product_modulo(self):
76197639
@torch.compile(fullgraph=True)

test/test_dynamic_shapes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,7 @@ def test_debug_has_internal_overlap_unbacked(self):
10661066
self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
10671067
self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
10681068
self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0)
1069-
self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 0)
1069+
self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 2)
10701070
Max = torch.sym_max
10711071
self.assertEqual(
10721072
cf(
@@ -1076,7 +1076,7 @@ def test_debug_has_internal_overlap_unbacked(self):
10761076
device="meta",
10771077
)
10781078
),
1079-
0,
1079+
2,
10801080
)
10811081

10821082
# Wobbling these to zero is OK too

torch/_inductor/lowering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1009,7 +1009,7 @@ def squeeze(x, dim=None):
10091009
for d, s in enumerate(x.get_size()):
10101010
if not (
10111011
d in dims
1012-
and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1, size_oblivious=True))
1012+
and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True)
10131013
):
10141014
new_shape.append(s)
10151015

torch/_inductor/sizevars.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,15 @@ def guarded_order(self, seq):
455455
# as this will ensure that you actually have a sympy'ified expression,
456456
# and will prevent you from incorrectly writing evaluate_expr(a == b)
457457
# which does the wrong thing if a or b is a sympy expression
def evaluate_expr(
    self,
    left: Union[Expr, sympy.logic.boolalg.Boolean],
    size_oblivious: bool = False,
) -> bool:
    # Evaluate an already-sympyfied boolean expression against the shape
    # environment.  When size_oblivious=True the evaluation is forwarded in
    # size-oblivious mode (presumably so unbacked sizes do not trigger
    # guard-on-data-dependent errors -- see ShapeEnv.evaluate_expr).
    assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left)
    return self.shape_env.evaluate_expr(
        sympy.sympify(left), size_oblivious=size_oblivious
    )
461467

462468
def evaluate_min(self, left: Expr, right: Expr) -> Expr:
463469
"""return the smaller of left and right, and guard on that choice"""

0 commit comments

Comments
 (0)