
Commit c10975d

Revert "Avoid DDE in narrow with unbacked start (pytorch#166361)"
This reverts commit c761999. Reverted pytorch#166361 on behalf of https://github.com/pytorch-auto-revert: the change was reverted automatically by pytorch's autorevert; to avoid this behaviour, add the tag autorevert: disable ([comment](pytorch#166361 (comment)))
1 parent 68e31e2 commit c10975d

File tree: 8 files changed (+18, -180 lines changed)


aten/src/ATen/native/TensorShape.cpp

Lines changed: 5 additions & 33 deletions
@@ -1,6 +1,5 @@
 #include <ATen/core/ATen_fwd.h>
 #include <c10/core/ScalarType.h>
-#include <c10/core/SymInt.h>
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/AccumulateType.h>
 #include <ATen/Dispatch.h>
@@ -1711,46 +1710,19 @@ Tensor narrow_symint(
       "], but got ",
       start,
       ")")
-  // Bounds check without converting start:
-  // - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
-  // length <= 0
-  // - If start >= 0: need start + length <= cur_size
-  auto end = start + length;
+  if (start < 0) {
+    start = start + cur_size;
+  }
   TORCH_SYM_CHECK(
-      (start.sym_lt(0).sym_and((end).sym_le(0)))
-          .sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
+      start.sym_le(cur_size - length),
       "start (",
       start,
       ") + length (",
       length,
       ") exceeds dimension size (",
       cur_size,
       ").");
-
-  if (TORCH_GUARD_OR_FALSE(start.sym_ge(0).sym_or(end.sym_ne(0)))) {
-    return at::slice_symint(self, dim, start, end, 1);
-  } else if (TORCH_GUARD_OR_FALSE(start.sym_lt(0))) {
-    // Avoid the complex symbolic expressions path for non-unbacked.
-    return at::slice_symint(self, dim, start + cur_size, end + cur_size, 1);
-  } else {
-    // Cannot statically determine the condition due to unbacked.
-    // This is an interesting situation; when start is negative and
-    // start + length == 0, slice and narrow do different things.
-    // i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
-    // pass curr_size instead of 0. Otherwise, they would do the same thing.
-    // This says at runtime: if start < 0 and end == 0, then pass curr_size
-    // instead of 0.
-
-    auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
-    auto result =
-        at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
-
-    // Ensure slice allocated unbacked size is specialized to length.
-    SymInt new_size = result.sym_size(dim);
-    TORCH_SYM_CHECK(new_size.sym_eq(length), "")
-
-    return result;
-  }
+  return at::slice_symint(self, dim, start, start + length, 1);
 }
 
 // This overload exists purely for XLA, because they wanted to pass in
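Note: the removed else-branch above documents the one case where narrow and Python slicing disagree. A quick eager-mode illustration (my own example, not part of the diff): when start is negative and start + length == 0, a slice ending at 0 is empty, while narrow wraps the negative start by cur_size, which is why the reverted code substituted cur_size for the slice end at runtime.

import torch

x = torch.arange(6)
print(x.narrow(0, -2, 2))  # last two elements: tensor([4, 5])
print(x[-2:0])             # empty: a literal end of 0 is not "end of the dimension"
print(x[-2:6])             # end == cur_size matches what narrow computes: tensor([4, 5])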

c10/core/SymBool.cpp

Lines changed: 0 additions & 14 deletions
@@ -1,5 +1,4 @@
 #include <c10/core/SymBool.h>
-#include <c10/core/SymInt.h>
 #include <c10/core/SymNodeImpl.h>
 
 namespace c10 {
@@ -112,17 +111,4 @@ bool SymBool::has_hint() const {
   return toSymNodeImpl()->has_hint();
 }
 
-SymInt SymBool::toSymInt() const {
-  // If concrete bool, return concrete SymInt
-  if (auto ma = maybe_as_bool()) {
-    return SymInt(*ma ? 1 : 0);
-  }
-
-  // Symbolic case: use sym_ite to convert bool to int (0 or 1)
-  auto node = toSymNodeImpl();
-  auto one_node = node->wrap_int(1);
-  auto zero_node = node->wrap_int(0);
-  return SymInt(node->sym_ite(one_node, zero_node));
-}
-
 } // namespace c10

c10/core/SymBool.h

Lines changed: 0 additions & 6 deletions
@@ -12,8 +12,6 @@
 
 namespace c10 {
 
-class SymInt;
-
 class C10_API SymBool {
  public:
   /*implicit*/ SymBool(bool b) : data_(b) {}
@@ -82,10 +80,6 @@ class C10_API SymBool {
     return toSymNodeImplUnowned()->constant_bool();
   }
 
-  // Convert SymBool to SymInt (0 or 1)
-  // This is the C++ equivalent of Python's cast_symbool_to_symint_guardless
-  SymInt toSymInt() const;
-
   bool is_heap_allocated() const {
     return ptr_;
   }
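The toSymInt() declaration removed here (with its definition in SymBool.cpp above) turned a boolean into 0/1 without guarding on its value; the deleted comment names cast_symbool_to_symint_guardless as the Python counterpart. A hedged Python-side sketch of the same idea using torch.sym_ite; the helper name below is mine, not a PyTorch API:

import torch

def symbool_to_int_guardless(cond):
    # Equivalent to `1 if cond else 0`, but sym_ite keeps a symbolic bool
    # symbolic instead of forcing a guard on its concrete value.
    return torch.sym_ite(cond, 1, 0)

print(symbool_to_int_guardless(True), symbool_to_int_guardless(False))  # 1 0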

test/export/test_export.py

Lines changed: 11 additions & 20 deletions
@@ -6093,19 +6093,26 @@ def forward(self, x, y, fixes):
         retry_export(
             cf_implicitsize(),
             (torch.tensor(2), torch.randn(10)),
-            fixes=[],
+            fixes=[
+                # Could not guard on data-dependent expression u0 < 0
+                "torch._check(i >= 0)",
+            ],
         )
 
         class cf_stacklist(torch.nn.Module):
             def forward(self, xs, y, fixes):
                 i = y.item()
                 eval(fixes)
+                # instead of xs[i]
                 return torch.stack(xs, 0).narrow(0, i, 1).squeeze()
 
         retry_export(
             cf_stacklist(),
             ([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
-            fixes=[],
+            fixes=[
+                # Could not guard on data-dependent expression u0 < 0
+                "torch._check(i >= 0)",
+            ],
         )
 
         class cf_tensorsplit(torch.nn.Module):
@@ -6159,12 +6166,7 @@ def test_no_suggested_fixes_for_data_dependent_errors(self):
         class cf_stacklist(torch.nn.Module):
             def forward(self, xs, y):
                 # y.item() is not a local, so we can't suggest a fix
-                if y.item() < 0:
-                    return (
-                        torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze()
-                    )
-                else:
-                    return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
+                return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
 
         with self.assertRaisesRegex(
             error_type,
@@ -6194,18 +6196,7 @@ class cf_stacklist_udd(torch.nn.Module):
             def forward(self, xs, y):
                 box = Box(y.item())
                 # box.content is not a local, so we can't suggest a fix
-                if box.content < 0:
-                    return (
-                        torch.stack(xs, 0)
-                        .narrow(0, box.content + xs.size(), 1)
-                        .squeeze()
-                    )
-                else:
-                    return (
-                        torch.stack(xs, 0)
-                        .narrow(0, box.content + xs.size(), 1)
-                        .squeeze()
-                    )
+                return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze()
 
         with self.assertRaisesRegex(
             error_type,
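The fixes added above spell out the workaround these tests now rely on. A minimal standalone sketch (my own repro, not taken from the PR): with the DDE-avoidance in narrow reverted, an unbacked index of unknown sign can fail export with "Could not guard on data-dependent expression u0 < 0", and torch._check(i >= 0) is the suggested fix.

import torch

class StackList(torch.nn.Module):
    def forward(self, xs, y):
        i = y.item()          # unbacked SymInt
        torch._check(i >= 0)  # without this, export may hit the u0 < 0 guard
        return torch.stack(xs, 0).narrow(0, i, 1).squeeze()

ep = torch.export.export(
    StackList(), ([torch.ones(5) * k for k in range(10)], torch.tensor(2))
)
print(ep.module()([torch.ones(5) * k for k in range(10)], torch.tensor(2)).shape)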

test/test_dynamic_shapes.py

Lines changed: 0 additions & 51 deletions
@@ -4401,57 +4401,6 @@ def func(x, y):
 
         self.assertEqual(compiled(a, b), func(a, b))
 
-    @fresh_cache()
-    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    def test_narrow_unbacked_start(self):
-        def func(x, start, length):
-            # unbacked start
-            u0 = start.item()
-            return torch.narrow(x, 0, u0, length)
-
-        compiled_func = torch.compile(func, fullgraph=True, backend="inductor")
-
-        x = torch.tensor([1, 2, 3, 4, 5, 6])
-
-        # Test cases: (start, length)
-        test_cases = [
-            # Negative starts
-            (-2, 2),  # Start from second-to-last element
-            (-1, 1),  # Start from last element
-            (-3, 3),  # Start from third-to-last element
-            (-6, 2),  # Start from beginning (negative)
-            (-4, 1),  # Start from fourth-to-last element
-            # Positive starts
-            (0, 2),  # Start from beginning
-            (1, 3),  # Start from second element
-            (2, 2),  # Start from third element
-            (4, 2),  # Start near end
-            # Edge cases
-            (0, 6),  # Full tensor
-            (0, 1),  # Single element from start
-            (5, 1),  # Single element from end
-        ]
-
-        for start_val, length in test_cases:
-            with self.subTest(start=start_val, length=length):
-                start = torch.tensor([start_val])
-
-                # Test with compiled function
-                result_compiled = compiled_func(x, start, length)
-
-                # Test with eager function (expected behavior)
-                result_eager = func(x, start, length)
-
-                # Compare results
-                self.assertEqual(result_compiled, result_eager)
-
-    @fresh_cache()
-    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    @torch._inductor.config.patch("cpp_wrapper", True)
-    def test_narrow_unbacked_start_cpp_wrapper(self):
-        """Test narrow with unbacked start with cpp_wrapper"""
-        self.test_narrow_unbacked_start()
-
 
 instantiate_parametrized_tests(TestUnbacked)
 
torch/_inductor/codegen/wrapper.py

Lines changed: 1 addition & 2 deletions
@@ -2058,8 +2058,7 @@ def clamp_index(x):
             neg = self.codegen_sizevar(
                 sympy.Max(0, sympy.Min(x + node.size, node.size))
             )
-            x_cond = self.codegen_sizevar(x)
-            return f"{pos} if {x_cond} >= 0 else {neg}"
+            return f"{pos} if {x} >= 0 else {neg}"
 
         def codegen_with_step(start_var, end_var, step):
            if step == 1:
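For reference, a plain-Python restatement of the clamp this wrapper codegen emits for slice bounds (a sketch only; the pos branch sits above the hunk shown here, so its exact form is my assumption based on the usual index-clamping rule):

def clamp_index(x: int, size: int) -> int:
    pos = max(0, min(x, size))         # branch used when x >= 0
    neg = max(0, min(x + size, size))  # x < 0: wrap by size, then clamp to [0, size]
    return pos if x >= 0 else neg

print(clamp_index(-2, 6), clamp_index(8, 6), clamp_index(-9, 6))  # 4 6 0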

torch/fx/experimental/symbolic_shapes.py

Lines changed: 1 addition & 18 deletions
@@ -547,7 +547,6 @@ def rebind_unbacked(
     assert shape_env is not None
     for raw_u0, path in bindings.items():
         u1 = pytree.key_get(result, path)
-
         # Sometimes, things were previously unbacked bindings become constants.
         # There are two situations this can happen.
         #
@@ -603,23 +602,7 @@ def rebind_unbacked(
         if u1.node.hint is not None:
             continue
 
-        # unbacked symbols bindings might be replaced to other backed or
-        # unbacked replacements.
-        #
-        # Example:
-        # u = x.item()
-        # torch._check(u == 5)
-        #
-        # The safest approach is to retrieve raw_u1 from u1.node._expr
-        # and perform the rebinding on the original unbacked symbol,
-        # even if it's no longer directly referenced.
-        #
-        # In other words, we should always rebind the original symbol
-        # before any replacements are applied.
-        # u0 -> u0 == s1
-        raw_u1 = u1.node._expr
-
-        # TODO Do we still need this logic below?
+        raw_u1 = u1.node.expr
         # Simplify SymBool binding
         if (
             isinstance(raw_u1, sympy.Piecewise)
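The comment removed above sketches how an unbacked symbol can be replaced once a torch._check pins it. A hedged, runnable version of that example (my own repro, reusing the capture_scalar_outputs patch the deleted test also used):

import torch

@torch._dynamo.config.patch("capture_scalar_outputs", True)
def run():
    @torch.compile(fullgraph=True)
    def f(x):
        u = x.item()          # creates an unbacked symbol (u0)
        torch._check(u == 5)  # after this, u0 may be replaced by the constant 5
        return torch.zeros(u)
    return f(torch.tensor(5))

print(run().shape)  # torch.Size([5])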

torch/utils/_sympy/printers.py

Lines changed: 0 additions & 36 deletions
@@ -306,24 +306,6 @@ def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
             raise TypeError("ndigits must be an instance of sympy.Integer")
         return f"round({self._print(number)}, {ndigits})"
 
-    def _print_Piecewise(self, expr: sympy.Expr) -> str:
-        # Convert Piecewise(expr_cond_pairs) to nested ternary expressions
-        # Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
-        # becomes: e1 if c1 else (e2 if c2 else (... else eN))
-        result = None
-        for expr_i, cond_i in reversed(expr.args):
-            expr_str = self._print(expr_i)
-            if cond_i == True:  # noqa: E712
-                # This is the default case
-                result = expr_str
-            else:
-                cond_str = self._print(cond_i)
-                if result is None:
-                    result = expr_str
-                else:
-                    result = f"({expr_str} if {cond_str} else {result})"
-        return result if result else "0"
-
 
 class CppPrinter(ExprPrinter):
     def _print_Integer(self, expr: sympy.Expr) -> str:
@@ -345,24 +327,6 @@ def _print_Where(self, expr: sympy.Expr) -> str:
         )
         return f"{c} ? {p} : {q}"
 
-    def _print_Piecewise(self, expr: sympy.Expr) -> str:
-        # Convert Piecewise(expr_cond_pairs) to nested ternary operators
-        # Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
-        # becomes: c1 ? e1 : (c2 ? e2 : (... : eN))
-        result = None
-        for expr_i, cond_i in reversed(expr.args):
-            expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5)
-            if cond_i == True:  # noqa: E712
-                # This is the default case
-                result = expr_str
-            else:
-                cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5)
-                if result is None:
-                    result = expr_str
-                else:
-                    result = f"{cond_str} ? {expr_str} : {result}"
-        return f"({result})" if result else "0"
-
     def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
         x, div, mod = expr.args
         x = self.doprint(x)
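Both deleted printers flattened sympy.Piecewise into nested conditionals, differing only in surface syntax. A small sympy illustration (my own example) of the mapping they performed:

import sympy

x = sympy.Symbol("x")
pw = sympy.Piecewise((x + 1, x < 0), (2 * x, True))
# The deleted Python printer rendered this roughly as:  (x + 1 if x < 0 else 2*x)
# The deleted C++ printer rendered it roughly as:       ((x < 0) ? (x + 1) : (2*x))
print(pw)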
