@@ -92,8 +92,11 @@ class _SymbolicConstraint:
9292 # Either e1 == e2 if cmp == Comparator.EQ else e1 >= e2
9393 cmp : Comparator
9494 debug_str : str # The form in which the user expressed it, for error messages
95- e1 : DimSize # This has been normalized w.r.t. previous constraints only
96- e2 : DimSize # This has been normalized w.r.t. previous constraints only
95+ # e1, e2, and diff == e1 - e2, are normalized w.r.t. previous constraints only
96+ e1 : DimSize
97+ e2 : DimSize
98+ # we pre-compute diff to avoid having the normalization rule kick in later.
99+ diff : DimSize
97100
98101 def __repr__ (self ):
99102 return f"Constraint({ self .debug_str } )"
@@ -1061,29 +1064,33 @@ def _parse_and_process_explicit_constraint(self, c_str: str):
10611064 if cmp == Comparator .GEQ and not is_geq :
10621065 e1 , e2 = e2 , e1
10631066
1064- diff = e1 - e2
1065- if (diff_const := _DimExpr ._to_constant (diff )) is not None :
1066- if ((cmp == Comparator .EQ and diff_const != 0 ) or
1067- (cmp == Comparator .GEQ and diff_const < 0 )):
1068- raise ValueError (f"Unsatisfiable explicit constraint: { c_str } " )
1067+ # Compute e1 - e2 before we add to normalization rules
1068+ constr = _SymbolicConstraint (debug_str = c_str , cmp = cmp , e1 = e1 , e2 = e2 ,
1069+ diff = e1 - e2 )
1070+ self ._process_explicit_constraint (constr )
1071+
1072+ def _process_explicit_constraint (self , constr : _SymbolicConstraint ):
1073+ if (diff_const := _DimExpr ._to_constant (constr .diff )) is not None :
1074+ if ((constr .cmp == Comparator .EQ and diff_const != 0 ) or
1075+ (constr .cmp == Comparator .GEQ and diff_const < 0 )):
1076+ raise ValueError (f"Unsatisfiable explicit constraint: { constr .debug_str } " )
10691077 return
10701078
1071- if cmp == Comparator .EQ :
1072- if not isinstance (e1 , _DimExpr ):
1079+ if constr . cmp == Comparator .EQ :
1080+ if not isinstance (constr . e1 , _DimExpr ):
10731081 raise ValueError ("Invalid equality constraint: {e1} == {e2}. "
10741082 "The left-hand-side must be of the form `term * coefficient`." )
1075- (before , before_k ), * rest = e1 ._sorted_terms
1083+ (before , before_k ), * rest = constr . e1 ._sorted_terms
10761084 if rest :
10771085 raise ValueError ("Invalid equality constraint: {e1} == {e2}. "
10781086 "The left-hand-side must be of the form `term * coefficient`." )
10791087
1080- after = _ensure_poly (e2 , "parse_constraint" , e1 .scope ) # type: ignore[name-error,unused-ignore]
1088+ after = _ensure_poly (constr . e2 , "parse_constraint" , constr . e1 .scope ) # type: ignore[name-error,unused-ignore]
10811089 if before in self ._normalization_rules :
10821090 raise NotImplementedError (
10831091 f"Found multiple equality constraints with the same left-hand-side: { before } " )
10841092 self ._normalization_rules [before ] = (after , before_k )
10851093
1086- constr = _SymbolicConstraint (debug_str = c_str , cmp = cmp , e1 = e1 , e2 = e2 )
10871094 self ._explicit_constraints .append (constr )
10881095
10891096 def _check_same_scope (self , other : _DimExpr ,
@@ -2120,14 +2127,12 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv):
21202127 for constr in scope ._explicit_constraints :
21212128 # We can't just construct constr.e1 - constr.e2 because for an equality
21222129 # constraint it would be reduced to 0.
2123- c_e1 = constr .e1 ._evaluate (shape_env ) if not core .is_constant_dim (constr .e1 ) else constr .e1 # type: ignore
2124- c_e2 = constr .e2 ._evaluate (shape_env ) if not core .is_constant_dim (constr .e2 ) else constr .e2 # type: ignore
2125- c_diff = c_e1 - c_e2
2130+ c_diff = constr .diff ._evaluate (shape_env ) if not core .is_constant_dim (constr .diff ) else constr .diff # type: ignore
21262131 shape_constraints .add_constraint (
21272132 constr .cmp , c_diff , 0 ,
21282133 error_message_pieces = [
21292134 f"Input shapes do not match the symbolic shape constraint { constr .debug_str } . "
2130- f"Expected '{ constr .e1 } - { constr . e2 } ' to be "
2135+ f"Expected '{ constr .diff } ' to be "
21312136 f"{ 'greater or equal' if constr .cmp == Comparator .GEQ else 'equal' } to 0, "
21322137 "but found " , c_diff ,
21332138
0 commit comments