Skip to content

Commit 60f9da5

Browse files
committed
[shape_poly] Improve reasoning for >= in presence of == constraints.
Previously, an equality constraint was used only as a normalization rule. This created a problem for constraints of the form "4*b=c", because it would not allow proving that "b <= c" (since the normalization of "4*b" kicks in only if "b" is multiplied by a multiple of 4. Now we add the equality constraints also in the inequality reasoning state.
1 parent cfdac00 commit 60f9da5

File tree

3 files changed

+46
-38
lines changed

3 files changed

+46
-38
lines changed

jax/_src/export/shape_poly.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,11 @@ class _SymbolicConstraint:
9292
# Either e1 == e2 if cmp == Comparator.EQ else e1 >= e2
9393
cmp: Comparator
9494
debug_str: str # The form in which the user expressed it, for error messages
95-
e1: DimSize # This has been normalized w.r.t. previous constraints only
96-
e2: DimSize # This has been normalized w.r.t. previous constraints only
95+
# e1, e2, and diff == e1 - e2, are normalized w.r.t. previous constraints only
96+
e1: DimSize
97+
e2: DimSize
98+
# we pre-compute diff to avoid having the normalization rule kick in later.
99+
diff: DimSize
97100

98101
def __repr__(self):
99102
return f"Constraint({self.debug_str})"
@@ -1061,29 +1064,33 @@ def _parse_and_process_explicit_constraint(self, c_str: str):
10611064
if cmp == Comparator.GEQ and not is_geq:
10621065
e1, e2 = e2, e1
10631066

1064-
diff = e1 - e2
1065-
if (diff_const := _DimExpr._to_constant(diff)) is not None:
1066-
if ((cmp == Comparator.EQ and diff_const != 0) or
1067-
(cmp == Comparator.GEQ and diff_const < 0)):
1068-
raise ValueError(f"Unsatisfiable explicit constraint: {c_str}")
1067+
# Compute e1 - e2 before we add to normalization rules
1068+
constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2,
1069+
diff=e1 - e2)
1070+
self._process_explicit_constraint(constr)
1071+
1072+
def _process_explicit_constraint(self, constr: _SymbolicConstraint):
1073+
if (diff_const := _DimExpr._to_constant(constr.diff)) is not None:
1074+
if ((constr.cmp == Comparator.EQ and diff_const != 0) or
1075+
(constr.cmp == Comparator.GEQ and diff_const < 0)):
1076+
raise ValueError(f"Unsatisfiable explicit constraint: {constr.debug_str}")
10691077
return
10701078

1071-
if cmp == Comparator.EQ:
1072-
if not isinstance(e1, _DimExpr):
1079+
if constr.cmp == Comparator.EQ:
1080+
if not isinstance(constr.e1, _DimExpr):
10731081
raise ValueError("Invalid equality constraint: {e1} == {e2}. "
10741082
"The left-hand-side must be of the form `term * coefficient`.")
1075-
(before, before_k), *rest = e1._sorted_terms
1083+
(before, before_k), *rest = constr.e1._sorted_terms
10761084
if rest:
10771085
raise ValueError("Invalid equality constraint: {e1} == {e2}. "
10781086
"The left-hand-side must be of the form `term * coefficient`.")
10791087

1080-
after = _ensure_poly(e2, "parse_constraint", e1.scope) # type: ignore[name-error,unused-ignore]
1088+
after = _ensure_poly(constr.e2, "parse_constraint", constr.e1.scope) # type: ignore[name-error,unused-ignore]
10811089
if before in self._normalization_rules:
10821090
raise NotImplementedError(
10831091
f"Found multiple equality constraints with the same left-hand-side: {before}")
10841092
self._normalization_rules[before] = (after, before_k)
10851093

1086-
constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2)
10871094
self._explicit_constraints.append(constr)
10881095

10891096
def _check_same_scope(self, other: _DimExpr,
@@ -2120,14 +2127,12 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv):
21202127
for constr in scope._explicit_constraints:
21212128
# We can't just construct constr.e1 - constr.e2 because for an equality
21222129
# constraint it would be reduced to 0.
2123-
c_e1 = constr.e1._evaluate(shape_env) if not core.is_constant_dim(constr.e1) else constr.e1 # type: ignore
2124-
c_e2 = constr.e2._evaluate(shape_env) if not core.is_constant_dim(constr.e2) else constr.e2 # type: ignore
2125-
c_diff = c_e1 - c_e2
2130+
c_diff = constr.diff._evaluate(shape_env) if not core.is_constant_dim(constr.diff) else constr.diff # type: ignore
21262131
shape_constraints.add_constraint(
21272132
constr.cmp, c_diff, 0,
21282133
error_message_pieces=[
21292134
f"Input shapes do not match the symbolic shape constraint {constr.debug_str}. "
2130-
f"Expected '{constr.e1} - {constr.e2}' to be "
2135+
f"Expected '{constr.diff}' to be "
21312136
f"{'greater or equal' if constr.cmp == Comparator.GEQ else 'equal'} to 0, "
21322137
"but found ", c_diff,
21332138

jax/_src/export/shape_poly_decision.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,19 +85,11 @@ def initialize(self) -> _DecisionByElimination:
8585
# the result (albeit, for now, without a good feedback loop to understand
8686
# how the order matters for inequalities).
8787
for constr in self.scope._explicit_constraints:
88-
if not core.is_constant_dim(constr.e1):
89-
self.add_implicit_constraints_expr(constr.e1) # type: ignore
90-
if not core.is_constant_dim(constr.e2):
91-
self.add_implicit_constraints_expr(constr.e2) # type: ignore
92-
# The equality constraints are not needed for inequality decisions,
93-
# because the LHS should always be rewritten in terms of the RHS.
94-
# In fact, adding them may break the assumption that if we eliminate
95-
# the leading term we end up with only smaller terms, because the LHS
96-
# may appear in the rest and may be rewritten to something larger.
97-
# However, we want to add the implicit constraints within.
98-
if constr.cmp == Comparator.GEQ:
99-
self.combine_and_add_constraint(constr.cmp, constr.e1 - constr.e2, 0,
100-
constr.debug_str)
88+
if not core.is_constant_dim(constr.diff):
89+
self.add_implicit_constraints_expr(constr.diff) # type: ignore
90+
91+
self.combine_and_add_constraint(constr.cmp, constr.diff, 0,
92+
constr.debug_str)
10193

10294

10395
# Clear the cache, since we have added constraints.
@@ -197,7 +189,7 @@ def combine_term_with_existing(self, t: _DimTerm, t_k: int, *,
197189
Combine a term with existing constraints.
198190
For input (t, t_k) the tuple (c_eq, c, c_s, t_s) is among the returned
199191
tuples if there exists a constraint `c =[c_eq] 0` that can be combined
200-
with `t*t_k` to eliminate `t`.
192+
with `t*t_k` to eliminate `t`, and:
201193
202194
* `c =[c_eq] 0`
203195
* The term `comb = t*t_k*t_s + c*c_s` does not contain `t`, and if
@@ -207,7 +199,7 @@ def combine_term_with_existing(self, t: _DimTerm, t_k: int, *,
207199
"""
208200
# TODO: maybe a generator is useful here instead of materializing the list
209201
acc: list[tuple[Comparator, _DimExpr, int, int]] = []
210-
# First combine with the existing term constraints
202+
# First combine with the existing term bounds
211203
t_lb, t_ub = self._term_bounds.get(t, (-np.inf, np.inf))
212204
if t_lb == t_ub:
213205
acc.append((Comparator.EQ, _DimExpr(((t, 1),), scope) - int(t_lb),

tests/shape_poly_test.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,27 +1114,38 @@ def test_constraints_eq_threefry(self):
11141114
self.assertEqual(x_reshaped, (a + a % 2) // -2)
11151115
self.assertEqual(2 * x_reshaped, a)
11161116

1117-
def test_constraints_a_minus_4d_eq(self):
1117+
def test_constraints_eq_a_minus_4d(self):
11181118
# simulates d = div(a, 4) and m = mod(a, 4)
1119-
assumptions = ["4*d == a - m", "m >= 0", "m <= 3"]
1120-
scope = shape_poly.SymbolicScope(assumptions)
1119+
constraints = ["4*d == a - m", "m >= 0", "m <= 3"]
1120+
scope = shape_poly.SymbolicScope(constraints)
11211121
a, d = shape_poly.symbolic_shape("a, d", scope=scope)
11221122
self.assertEqual(_bounds(a - 4*d), (1, 3)) # a - 4d = m >= 1
11231123
# TODO: The incompleteness is due to the way we combine external constraints
11241124
self.assertEqual(_bounds(a - 2*d),
11251125
_expect(best=(3, np.inf), current=(-np.inf, np.inf))) # a - 2d = m + 2d >= 3
11261126
# TODO: The incompleteness is due to the way we combine external constraints
11271127
self.assertEqual(_bounds(a),
1128-
_expect(best=(5, np.inf), current=(1, np.inf))) # a >= 4d + m >= 5
1128+
_expect(best=(5, np.inf), current=(4, np.inf))) # a >= 4d + m >= 5
11291129

11301130
# Now with a different order of constraints
1131-
assumptions1 = ["m1 >= 0", "m1 <= 3", "a1 == 4*d1 + m1"]
1132-
scope1 = shape_poly.SymbolicScope(assumptions1)
1131+
constraints1 = ["m1 >= 0", "m1 <= 3", "a1 == 4*d1 + m1"]
1132+
scope1 = shape_poly.SymbolicScope(constraints1)
11331133
a1, d1, m1 = shape_poly.symbolic_shape("a1, d1, m1", scope=scope1)
11341134
self.assertEqual(_bounds(a1 - 4*d1), (1, 3)) # a - 4d = m >= 1
11351135
self.assertEqual(_bounds(a1 - 2*d1), (3, np.inf)) # a - 2d = m + 2d >= 3
11361136
self.assertEqual(_bounds(a1), (5, np.inf)) # a >= 4d + m >= 5
11371137

1138+
def test_constraints_eq_geq(self):
1139+
# We ensure that an equality constraint it is usable not just for
1140+
# normalization but also for inequality reasoning.
1141+
a, b = export.symbolic_shape(
1142+
"a, b", constraints=["4 * a == b"])
1143+
self.assertGreaterEqual(b, a)
1144+
self.assertGreaterEqual(b, 3*a)
1145+
self.assertGreaterEqual(b, 4 * a)
1146+
self.assertGreaterEqual(5 * a, b)
1147+
self.assertGreaterEqual(9 * a, 2*b)
1148+
11381149
def test_constraints_error_msg(self):
11391150
a, b = shape_poly.symbolic_shape("a, b",
11401151
constraints=("a >= 5",))
@@ -1713,7 +1724,7 @@ def f(x): # x: i32[a]
17131724

17141725
with self.assertRaisesRegex(
17151726
ValueError,
1716-
re.escape("Expected '4 - a' to be greater or equal to 0, but found -1")):
1727+
re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")):
17171728
exp.call(np.arange(5, dtype=np.int32))
17181729

17191730
def test_constraints_eq_0_compile_time_check(self):

0 commit comments

Comments
 (0)