@@ -122,19 +122,42 @@ bool Term::containsNameSymbols() const {
122122 return false ;
123123}
124124
125- // / Shortlex order on symbol ranges.
125+ // / Weighted shortlex order on symbol ranges, used for implementing
126+ // / Term::compare() and MutableTerm::compare().
126127// /
127- // / First we compare length, then perform a lexicographic comparison
128- // / on symbols if the two ranges have the same length.
128+ // / We first compute a weight vector for both terms and compare the
129+ // / vectors lexicographically:
130+ // / - Weight of generic param symbols
131+ // / - Number of name symbols
132+ // / - Number of element symbols
129133// /
130- // / This is used to implement Term::compare() and MutableTerm::compare()
131- // / below.
132- static std::optional<int > shortlexCompare (const Symbol *lhsBegin,
133- const Symbol *lhsEnd,
134- const Symbol *rhsBegin,
135- const Symbol *rhsEnd,
136- RewriteContext &ctx) {
137- // First, compare the number of name and pack element symbols.
134+ // / If the terms have the same weight, we compare length.
135+ // /
136+ // / If the terms have the same weight and length, we perform a
137+ // / lexicographic comparison on symbols.
138+ // /
139+ static std::optional<int > compareImpl (const Symbol *lhsBegin,
140+ const Symbol *lhsEnd,
141+ const Symbol *rhsBegin,
142+ const Symbol *rhsEnd,
143+ RewriteContext &ctx) {
144+ ASSERT (lhsBegin != lhsEnd);
145+ ASSERT (rhsBegin != rhsEnd);
146+
147+ // First compare weights on generic parameters. The implicit
148+ // assumption here is we don't form terms with generic parameter
149+ // symbols in the middle, which is true. Otherwise, we'd need
150+ // to add up their weights like we do below for name symbols,
151+ // of course.
152+ if (lhsBegin->getKind () == Symbol::Kind::GenericParam &&
153+ rhsBegin->getKind () == Symbol::Kind::GenericParam) {
154+ unsigned lhsWeight = lhsBegin->getGenericParam ()->getWeight ();
155+ unsigned rhsWeight = rhsBegin->getGenericParam ()->getWeight ();
156+ if (lhsWeight != rhsWeight)
157+ return lhsWeight > rhsWeight ? 1 : -1 ;
158+ }
159+
160+ // Compare the number of name and pack element symbols.
138161 unsigned lhsNameCount = 0 ;
139162 unsigned lhsPackElementCount = 0 ;
140163 for (auto *iter = lhsBegin; iter != lhsEnd; ++iter) {
@@ -192,17 +215,17 @@ static std::optional<int> shortlexCompare(const Symbol *lhsBegin,
192215 return 0 ;
193216}
194217
195- // / Shortlex order on terms. Returns None if the terms are identical except
218+ // / Reduction order on terms. Returns None if the terms are identical except
196219// / for an incomparable superclass or concrete type symbol at the end.
197220std::optional<int > Term::compare (Term other, RewriteContext &ctx) const {
198- return shortlexCompare (begin (), end (), other.begin (), other.end (), ctx);
221+ return compareImpl (begin (), end (), other.begin (), other.end (), ctx);
199222}
200223
201- // / Shortlex order on mutable terms. Returns None if the terms are identical
224+ // / Reduction order on mutable terms. Returns None if the terms are identical
202225// / except for an incomparable superclass or concrete type symbol at the end.
203226std::optional<int > MutableTerm::compare (const MutableTerm &other,
204227 RewriteContext &ctx) const {
205- return shortlexCompare (begin (), end (), other.begin (), other.end (), ctx);
228+ return compareImpl (begin (), end (), other.begin (), other.end (), ctx);
206229}
207230
208231// / Replace the subterm in the range [from,to) of this term with \p rhs.
0 commit comments