5454import com.google.common.collect.HashMultimap;
5555import com.google.common.collect.ImmutableList;
5656import com.google.common.collect.Multimap;
57+ import com.google.common.collect.Sets;
5758
5859import org.checkerframework.checker.nullness.qual.Nullable;
5960
6061import java.util.ArrayDeque;
6162import java.util.ArrayList;
6263import java.util.Collection;
64+ import java.util.Deque;
6365import java.util.HashMap;
6466import java.util.HashSet;
67+ import java.util.IdentityHashMap;
6568import java.util.Iterator;
6669import java.util.LinkedHashSet;
6770import java.util.List;
@@ -104,6 +107,8 @@ public class HepPlanner extends AbstractRelOptPlanner {
104107
105108 private final boolean noDag;
106109
110+ private boolean largePlanMode = false;
111+
107112 /**
108113 * Query graph, with edges directed from parent to child. This is a
109114 * single-rooted DAG, possibly with additional roots corresponding to
@@ -183,10 +188,48 @@ public HepPlanner(
183188 this.noDag = noDag;
184189 }
185190
  /**
   * Creates a {@code HepPlanner} capable of executing multiple {@link HepProgram}s
   * with (noDag = false, isLargePlanMode = true, enableFiredRulesCache = true).
   *
   * <p>Unlike planners that require {@link #setRoot(RelNode)} for every optimization
   * pass, this planner preserves the internal graph structure and optimized plan
   * across successive executions. This allows multi-phase optimization, where the
   * output of one {@link HepProgram} serves as the immediate starting point for
   * the next.
   *
   * <p><b>Usage example:</b>
   * <pre>{@code
   * HepPlanner planner = new HepPlanner();
   * planner.setRoot(initPlanRoot);
   * planner.executeProgram(phase1Program);
   * planner.dumpRuleAttemptsInfo(); // optional
   * planner.clear(); // clears the rules and rule-match caches; the graph is reused
   * // other logic ...
   * planner.executeProgram(phase2Program);
   * planner.clear();
   * ...
   * RelNode optimized = planner.buildFinalPlan();
   * }</pre>
   *
   * @see #setRoot(RelNode)
   * @see #executeProgram(HepProgram)
   * @see #dumpRuleAttemptsInfo()
   * @see #clear()
   * @see #buildFinalPlan()
   */
  public HepPlanner() {
    this(HepProgram.builder().build(), null, false, null, RelOptCostImpl.FACTORY);
    this.largePlanMode = true;
    this.enableFiredRulesCache = true;
  }
225+
186226 //~ Methods ----------------------------------------------------------------
187227
  @Override public void setRoot(RelNode rel) {
    // initRelToVertexCache lets addRelToGraph quickly skip RelNode instances it
    // has already converted (common sub-expressions) before recursing into their
    // inputs; it is only used in large-plan DAG mode.
    IdentityHashMap<RelNode, HepRelVertex> initRelToVertexCache = (isLargePlanMode() && !noDag)
        ? new IdentityHashMap<>() : null;
    root = addRelToGraph(rel, initRelToVertexCache);
    dumpGraph();
  }
192235
@@ -204,6 +247,14 @@ public HepPlanner(
204247 this.firedRulesCacheIndex.clear();
205248 }
206249
250+ public boolean isLargePlanMode() {
251+ return largePlanMode;
252+ }
253+
254+ public void setLargePlanMode(final boolean largePlanMode) {
255+ this.largePlanMode = largePlanMode;
256+ }
257+
207258 @Override public RelNode changeTraits(RelNode rel, RelTraitSet toTraits) {
208259 // Ignore traits, except for the root, where we remember
209260 // what the final conversion should be.
@@ -214,6 +265,11 @@ public HepPlanner(
214265 }
215266
216267 @Override public RelNode findBestExp() {
268+ if (isLargePlanMode()) {
269+ throw new UnsupportedOperationException("findBestExp is not supported in large plan mode"
270+ + ", please use buildFinalPlan() to get the final plan.");
271+ }
272+
217273 requireNonNull(root, "'root' must not be null");
218274
219275 executeProgram(mainProgram);
@@ -224,6 +280,10 @@ public HepPlanner(
224280 return buildFinalPlan(requireNonNull(root, "'root' must not be null"));
225281 }
226282
283+ public RelNode buildFinalPlan() {
284+ return buildFinalPlan(requireNonNull(root, "'root' must not be null"));
285+ }
286+
227287 /**
228288 * Enables or disables the fire-rule cache.
229289 *
@@ -237,7 +297,7 @@ public void setEnableFiredRulesCache(boolean enable) {
237297
238298 /** Top-level entry point for a program. Initializes state and then invokes
239299 * the program. */
240- private void executeProgram(HepProgram program) {
300+ public void executeProgram(HepProgram program) {
241301 final HepInstruction.PrepareContext px =
242302 HepInstruction.PrepareContext.create(this);
243303 final HepState state = program.prepare(px);
@@ -249,7 +309,7 @@ void executeProgram(HepProgram instruction, HepProgram.State state) {
249309 state.instructionStates.forEach(instructionState -> {
250310 instructionState.execute();
251311 int delta = nTransformations - nTransformationsLastGC;
252- if (delta > graphSizeLastGC) {
312+ if (!isLargePlanMode() && delta > graphSizeLastGC) {
253313 // The number of transformations performed since the last
254314 // garbage collection is greater than the number of vertices in
255315 // the graph at that time. That means there should be a
@@ -492,12 +552,23 @@ private Iterator<HepRelVertex> getGraphIterator(
492552 HepProgram.State programState, HepRelVertex start) {
493553 switch (requireNonNull(programState.matchOrder, "programState.matchOrder")) {
494554 case ARBITRARY:
555+ if (isLargePlanMode()) {
556+ return BreadthFirstIterator.of(graph, start).iterator();
557+ }
558+ return DepthFirstIterator.of(graph, start).iterator();
495559 case DEPTH_FIRST:
560+ if (isLargePlanMode()) {
561+ throw new UnsupportedOperationException("DepthFirstIterator is too slow for large plan mode"
562+ + ", please setLargePlanMode(false) if you don't want to use this mode.");
563+ }
496564 return DepthFirstIterator.of(graph, start).iterator();
497565 case TOP_DOWN:
498566 case BOTTOM_UP:
499567 assert start == root;
500- collectGarbage();
568+ if (!isLargePlanMode()) {
569+          // NOTE: Planner already runs GC for every subtree removed by a transformation
570+ collectGarbage();
571+ }
501572 return TopologicalOrderIterator.of(graph, programState.matchOrder).iterator();
502573 default:
503574 throw new
@@ -774,7 +845,8 @@ private HepRelVertex applyTransformationResults(
774845 parents.add(parent);
775846 }
776847
777- HepRelVertex newVertex = addRelToGraph(bestRel);
848+ HepRelVertex newVertex = addRelToGraph(bestRel, null);
849+ Set<HepRelVertex> garbageVertexSet = new LinkedHashSet<>();
778850
779851 // There's a chance that newVertex is the same as one
780852 // of the parents due to common subexpression recognition
@@ -785,10 +857,12 @@ private HepRelVertex applyTransformationResults(
785857 if (iParentMatch != -1) {
786858 newVertex = parents.get(iParentMatch);
787859 } else {
788- contractVertices(newVertex, vertex, parents);
860+ contractVertices(newVertex, vertex, parents, garbageVertexSet );
789861 }
790862
791- if (getListener() != null) {
863+ if (isLargePlanMode()) {
864+ collectGarbage(garbageVertexSet);
865+ } else if (getListener() != null) {
792866 // Assume listener doesn't want to see garbage.
793867 collectGarbage();
794868 }
@@ -824,19 +898,26 @@ private HepRelVertex applyTransformationResults(
824898 }
825899
826900 private HepRelVertex addRelToGraph(
827- RelNode rel) {
901+ RelNode rel, @Nullable IdentityHashMap<RelNode, HepRelVertex> initRelToVertexCache ) {
828902 // Check if a transformation already produced a reference
829903 // to an existing vertex.
830904 if (graph.vertexSet().contains(rel)) {
831905 return (HepRelVertex) rel;
832906 }
833907
908+ // Fast equiv vertex for set root, before add children.
909+ if (initRelToVertexCache != null && initRelToVertexCache.containsKey(rel)) {
910+ HepRelVertex vertex = initRelToVertexCache.get(rel);
911+ assert vertex != null;
912+ return vertex;
913+ }
914+
834915 // Recursively add children, replacing this rel's inputs
835916 // with corresponding child vertices.
836917 final List<RelNode> inputs = rel.getInputs();
837918 final List<RelNode> newInputs = new ArrayList<>();
838919 for (RelNode input1 : inputs) {
839- HepRelVertex childVertex = addRelToGraph(input1);
920+ HepRelVertex childVertex = addRelToGraph(input1, initRelToVertexCache );
840921 newInputs.add(childVertex);
841922 }
842923
@@ -868,14 +949,19 @@ private HepRelVertex addRelToGraph(
868949 graph.addEdge(newVertex, (HepRelVertex) input);
869950 }
870951
952+ if (initRelToVertexCache != null) {
953+ initRelToVertexCache.put(rel, newVertex);
954+ }
955+
871956 nTransformations++;
872957 return newVertex;
873958 }
874959
875960 private void contractVertices(
876961 HepRelVertex preservedVertex,
877962 HepRelVertex discardedVertex,
878- List<HepRelVertex> parents) {
963+ List<HepRelVertex> parents,
964+ Set<HepRelVertex> garbageVertexSet) {
879965 if (preservedVertex == discardedVertex) {
880966 // Nop.
881967 return;
@@ -897,17 +983,32 @@ private void contractVertices(
897983 }
898984 clearCache(parent);
899985 graph.removeEdge(parent, discardedVertex);
986+
987+ if (!noDag && isLargePlanMode()) {
988+ // Recursive merge parent path
989+ HepRelVertex addedVertex = mapDigestToVertex.get(parentRel.getRelDigest());
990+ if (addedVertex != null && addedVertex != parent) {
991+ List<HepRelVertex> parentCopy = // contractVertices will change predecessorList
992+ new ArrayList<>(Graphs.predecessorListOf(graph, parent));
993+ contractVertices(addedVertex, parent, parentCopy, garbageVertexSet);
994+ continue;
995+ }
996+ }
997+
900998 graph.addEdge(parent, preservedVertex);
901999 updateVertex(parent, parentRel);
9021000 }
9031001
9041002 // NOTE: we don't actually do graph.removeVertex(discardedVertex),
9051003 // because it might still be reachable from preservedVertex.
9061004 // Leave that job for garbage collection.
1005+    // If isLargePlanMode is true, we do fine-grained GC in tryCleanVertices
1006+    // by tracking the discarded vertex subtree's inward references.
9071007
9081008 if (discardedVertex == root) {
9091009 root = preservedVertex;
9101010 }
1011+ garbageVertexSet.add(discardedVertex);
9111012 }
9121013
9131014 /**
@@ -992,6 +1093,58 @@ private RelNode buildFinalPlan(HepRelVertex vertex) {
9921093 return rel;
9931094 }
9941095
1096+ /** Try remove discarded vertices recursively. */
1097+ private void tryCleanVertices(HepRelVertex vertex) {
1098+ if (vertex == root || !graph.vertexSet().contains(vertex)
1099+ || !graph.getInwardEdges(vertex).isEmpty()) {
1100+ return;
1101+ }
1102+
1103+ // rel is the no inward edges subtree root.
1104+ RelNode rel = vertex.getCurrentRel();
1105+ notifyDiscard(rel);
1106+
1107+ Set<HepRelVertex> outVertices = new LinkedHashSet<>();
1108+ List<DefaultEdge> outEdges = graph.getOutwardEdges(vertex);
1109+ for (DefaultEdge outEdge : outEdges) {
1110+ outVertices.add((HepRelVertex) outEdge.target);
1111+ }
1112+
1113+ for (HepRelVertex child : outVertices) {
1114+ graph.removeEdge(vertex, child);
1115+ }
1116+ assert graph.getInwardEdges(vertex).isEmpty();
1117+ assert graph.getOutwardEdges(vertex).isEmpty();
1118+ graph.vertexSet().remove(vertex);
1119+ mapDigestToVertex.remove(rel.getRelDigest());
1120+
1121+ for (HepRelVertex child : outVertices) {
1122+ tryCleanVertices(child);
1123+ }
1124+ clearCache(vertex);
1125+
1126+ if (enableFiredRulesCache) {
1127+ for (List<Integer> relIds : firedRulesCacheIndex.get(rel.getId())) {
1128+ firedRulesCache.removeAll(relIds);
1129+ }
1130+ }
1131+ }
1132+
1133+ private void collectGarbage(final Set<HepRelVertex> garbageVertexSet) {
1134+ for (HepRelVertex vertex : garbageVertexSet) {
1135+ tryCleanVertices(vertex);
1136+ }
1137+
1138+ if (LOGGER.isTraceEnabled()) {
1139+ int currentGraphSize = graph.vertexSet().size();
1140+ collectGarbage();
1141+ int currentGraphSize2 = graph.vertexSet().size();
1142+ if (currentGraphSize != currentGraphSize2) {
1143+ throw new AssertionError("Graph size changed after garbage collection");
1144+ }
1145+ }
1146+ }
1147+
9951148 private void collectGarbage() {
9961149 if (nTransformations == nTransformationsLastGC) {
9971150 // No modifications have taken place since the last gc,
@@ -1061,12 +1214,48 @@ private void assertNoCycles() {
10611214 + cyclicVertices);
10621215 }
10631216
1217+ private void assertGraphConsistent() {
1218+ int liveNum = 0;
1219+ for (HepRelVertex vertex : BreadthFirstIterator.of(graph, requireNonNull(root, "root"))) {
1220+ if (graph.getOutwardEdges(vertex).size()
1221+ != Sets.newHashSet(requireNonNull(vertex, "vertex").getCurrentRel().getInputs()).size()) {
1222+ throw new AssertionError("HepPlanner:outward edge num is different "
1223+ + "with input node num, " + vertex);
1224+ }
1225+ for (DefaultEdge edge : graph.getInwardEdges(vertex)) {
1226+ if (!((HepRelVertex) edge.source).getCurrentRel().getInputs().contains(vertex)) {
1227+ throw new AssertionError("HepPlanner:inward edge target is not in input node list, "
1228+ + vertex);
1229+ }
1230+ }
1231+ liveNum++;
1232+ }
1233+
1234+ Set<RelNode> validSet = new HashSet<>();
1235+ Deque<RelNode> nodes = new ArrayDeque<>();
1236+ nodes.push(requireNonNull(requireNonNull(root, "root").getCurrentRel()));
1237+ while (!nodes.isEmpty()) {
1238+ RelNode node = nodes.pop();
1239+ validSet.add(node);
1240+ for (RelNode input : node.getInputs()) {
1241+ nodes.push(((HepRelVertex) input).getCurrentRel());
1242+ }
1243+ }
1244+
1245+ if (liveNum == validSet.size()) {
1246+ return;
1247+ }
1248+ throw new AssertionError("HepPlanner:Query graph live node num is different with root"
1249+ + " input valid node num, liveNodeNum: " + liveNum + ", validNodeNum: " + validSet.size());
1250+ }
1251+
10641252 private void dumpGraph() {
10651253 if (!LOGGER.isTraceEnabled()) {
10661254 return;
10671255 }
10681256
10691257 assertNoCycles();
1258+ assertGraphConsistent();
10701259
10711260 HepRelVertex root = this.root;
10721261 if (root == null) {
0 commit comments