Skip to content

Commit b33664e

Browse files
committed
Improve pattern matching performance.
1 parent 543bec1 commit b33664e

File tree

2 files changed

+142
-12
lines changed

2 files changed

+142
-12
lines changed

vm/vm/main/unify.cc

Lines changed: 125 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,6 @@ struct StructuralDualWalk {
6262
inline
6363
void undoBindings(VM vm);
6464

65-
inline
66-
void doCapture(VM vm, RichNode value, RichNode capture);
67-
6865
inline
6966
void doConjunction(VM vm, RichNode value, RichNode conjunction);
7067

@@ -103,6 +100,130 @@ bool fullPatternMatch(VM vm, RichNode value, RichNode pattern,
103100
return walk.run(vm, value, pattern);
104101
}
105102

103+
////////////////////////////
104+
// Quick pattern matching //
105+
////////////////////////////
106+
107+
static void doCapture(VM vm, RichNode value, RichNode capture,
108+
StaticArray<UnstableNode> captures) {
109+
nativeint index = capture.as<PatMatCapture>().index();
110+
if (index >= 0)
111+
captures[(size_t) index].copy(vm, value);
112+
}
113+
114+
// This function identifies some common patterns and performs the match
115+
// directly, before needing to fall back to the full stack-walking matcher.
116+
PatternMatchResult quickPatternMatch(VM vm, RichNode value, RichNode pattern,
117+
StaticArray<UnstableNode> captures) {
118+
if (value.isSameNode(pattern))
119+
return PatternMatchResult::succeed;
120+
121+
auto valueType = value.type();
122+
auto patternType = pattern.type();
123+
124+
// Bind a bare PatMatCapture directly.
125+
if (patternType == PatMatCapture::type()) {
126+
doCapture(vm, value, pattern, captures);
127+
return PatternMatchResult::succeed;
128+
}
129+
130+
auto valueBehavior = valueType.getStructuralBehavior();
131+
auto patternBehavior = patternType.getStructuralBehavior();
132+
133+
// The following are taken from `equals()`...
134+
if (valueBehavior == sbVariable || patternBehavior == sbVariable)
135+
return PatternMatchResult::unknown;
136+
137+
// Let full pattern matching work with more complex pat-mat objects.
138+
if (patternType == PatMatConjunction::type())
139+
return PatternMatchResult::unknown;
140+
if (patternType == PatMatOpenRecord::type())
141+
return PatternMatchResult::unknown;
142+
143+
// At this point, we can safely short-circuit anything with a different shape.
144+
if (valueType != patternType)
145+
return PatternMatchResult::failed;
146+
147+
switch (valueBehavior) {
148+
// These two are the same as `equals()`.
149+
case sbValue:
150+
if (ValueEquatable(value).equals(vm, pattern))
151+
return PatternMatchResult::succeed;
152+
else
153+
return PatternMatchResult::failed;
154+
155+
case sbTokenEq:
156+
return PatternMatchResult::failed;
157+
158+
case sbStructural: {
159+
// The majority of pattern matches are tuples like r(X Y Z), or cons like H|T
160+
if (valueType == Tuple::type()) {
161+
162+
// For tuples, first ensure labels and widths are the same.
163+
auto valueTuple = value.as<Tuple>();
164+
auto patternTuple = pattern.as<Tuple>();
165+
auto label = valueTuple.label(vm);
166+
auto width = valueTuple.width(vm);
167+
if (!patternTuple.testTuple(vm, label, width))
168+
return PatternMatchResult::failed;
169+
170+
// Then check that the patterns only contain PatMatCaptures, otherwise
171+
// we fall back to fullPatternMatch().
172+
auto patternsArray = patternTuple.getElementsArray();
173+
for (size_t i = 0; i < width; ++ i) {
174+
if (!RichNode(patternsArray[i]).is<PatMatCapture>())
175+
return PatternMatchResult::unknown;
176+
}
177+
178+
// Finally perform the capture and return.
179+
auto valuesArray = valueTuple.getElementsArray();
180+
for (size_t i = 0; i < width; ++ i) {
181+
doCapture(vm, valuesArray[i], patternsArray[i], captures);
182+
}
183+
return PatternMatchResult::succeed;
184+
185+
} else if (valueType == Cons::type()) {
186+
187+
// Cons matching usually comes in these forms:
188+
// 1. H|T
189+
// 2. X#Y|T
190+
// 3. a|T
191+
// So here we ensure the tail is PatMatCapture, and the head is a value
192+
// type, PatMatCapture (also a value type) or a tuple (but not cons).
193+
// Then we perform capturing on the head and tail respectively.
194+
195+
auto patternsArray = pattern.as<Cons>().getElementsArray();
196+
auto valuesArray = value.as<Cons>().getElementsArray();
197+
198+
if (!RichNode(patternsArray[1]).is<PatMatCapture>())
199+
return PatternMatchResult::unknown;
200+
201+
RichNode patternHead(patternsArray[0]);
202+
auto patternHeadType = patternHead.type();
203+
if (patternHeadType != Tuple::type() && patternHeadType.getStructuralBehavior() != sbValue)
204+
return PatternMatchResult::unknown;
205+
206+
// Since head is never Cons, the recursion depth will be limited to 2.
207+
auto res = quickPatternMatch(vm, valuesArray[0], patternHead, captures);
208+
if (res == PatternMatchResult::succeed) {
209+
// Perform capture only if we are sure the heads match.
210+
doCapture(vm, valuesArray[1], patternsArray[1], captures);
211+
}
212+
return res;
213+
214+
} else {
215+
216+
// We don't care about more complex cases.
217+
return PatternMatchResult::unknown;
218+
219+
}
220+
}
221+
222+
default:
223+
return PatternMatchResult::unknown;
224+
}
225+
}
226+
106227
////////////////////
107228
// The real thing //
108229
////////////////////
@@ -263,7 +384,7 @@ bool StructuralDualWalk::processPair(VM vm, RichNode left, RichNode right) {
263384
// Handle captures
264385
if (kind == wkPatternMatch) {
265386
if (rightType == PatMatCapture::type()) {
266-
doCapture(vm, left, right);
387+
doCapture(vm, left, right, captures);
267388
return true;
268389
} else if (rightType == PatMatConjunction::type()) {
269390
doConjunction(vm, left, right);
@@ -361,13 +482,6 @@ void StructuralDualWalk::undoBindings(VM vm) {
361482
}
362483
}
363484

364-
void StructuralDualWalk::doCapture(VM vm, RichNode value, RichNode capture) {
365-
nativeint index = capture.as<PatMatCapture>().index();
366-
367-
if (index >= 0)
368-
captures[(size_t) index].copy(vm, value);
369-
}
370-
371485
void StructuralDualWalk::doConjunction(VM vm, RichNode value,
372486
RichNode conjunction) {
373487
auto conj = conjunction.as<PatMatConjunction>();

vm/vm/main/unify.hh

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,22 @@ void WalkStack::clear(VM vm) {
6565
// Global routines //
6666
/////////////////////
6767

68+
enum class PatternMatchResult {
69+
failed,
70+
succeed,
71+
unknown,
72+
};
73+
6874
void fullUnify(VM vm, RichNode left, RichNode right);
6975

7076
bool fullEquals(VM vm, RichNode left, RichNode right);
7177

7278
bool fullPatternMatch(VM vm, RichNode value, RichNode pattern,
7379
StaticArray<UnstableNode> captures);
7480

81+
PatternMatchResult quickPatternMatch(VM vm, RichNode value, RichNode pattern,
82+
StaticArray<UnstableNode> captures);
83+
7584
#ifndef MOZART_GENERATOR
7685

7786
void unify(VM vm, RichNode left, RichNode right) {
@@ -131,7 +140,14 @@ bool equals(VM vm, RichNode left, RichNode right) {
131140

132141
bool patternMatch(VM vm, RichNode value, RichNode pattern,
133142
StaticArray<UnstableNode> captures) {
134-
return fullPatternMatch(vm, value, pattern, captures);
143+
switch (quickPatternMatch(vm, value, pattern, captures)) {
144+
case PatternMatchResult::failed:
145+
return false;
146+
case PatternMatchResult::succeed:
147+
return true;
148+
default: // i.e. PatternMatchResult::unknown
149+
return fullPatternMatch(vm, value, pattern, captures);
150+
}
135151
}
136152

137153
#endif // MOZART_GENERATOR

0 commit comments

Comments
 (0)