Skip to content

Commit c2bed40

Browse files
fix
1 parent 2e5a428 commit c2bed40

File tree

3 files changed

+51
-14
lines changed

3 files changed

+51
-14
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ public boolean isFuncDeps(Set<Slot> dominate, Set<Slot> dependency) {
183183
return items.contains(new FuncDepsItem(dominate, dependency));
184184
}
185185

186+
// This also checks whether the two slot sets form a bijection.
186187
public boolean isCircleDeps(Set<Slot> dominate, Set<Slot> dependency) {
187188
return items.contains(new FuncDepsItem(dominate, dependency))
188189
&& items.contains(new FuncDepsItem(dependency, dominate));
@@ -201,16 +202,30 @@ public Map<Set<Slot>, Set<Set<Slot>>> getREdges() {
201202
}
202203

203204
/**
204-
* find the determinants of dependencies
205+
* Finds all slot sets that have a bijective relationship with the given slot set.
206+
* Given edges containing:
207+
* {A} -> {{B}, {C}}
208+
* {B} -> {{A}, {D}}
209+
* {C} -> {{D}}
210+
* When slot = {A}, returns {{B}} because {A} and {B} mutually determine each other.
211+
* {C} is not returned because {C} does not determine {A} (one-way dependency only).
205212
*/
206-
public Set<Set<Slot>> findDeterminats(Set<Slot> dependency) {
207-
Set<Set<Slot>> determinants = new HashSet<>();
208-
for (FuncDepsItem item : items) {
209-
if (item.dependencies.equals(dependency)) {
210-
determinants.add(item.determinants);
213+
public Set<Set<Slot>> findBijectionSlots(Set<Slot> slot) {
214+
Set<Set<Slot>> bijectionSlots = new HashSet<>();
215+
if (!edges.containsKey(slot)) {
216+
return bijectionSlots;
217+
}
218+
for (Set<Slot> dep : edges.get(slot)) {
219+
if (!edges.containsKey(dep)) {
220+
continue;
221+
}
222+
for (Set<Slot> det : edges.get(dep)) {
223+
if (det.equals(slot)) {
224+
bijectionSlots.add(dep);
225+
}
211226
}
212227
}
213-
return determinants;
228+
return bijectionSlots;
214229
}
215230

216231
@Override

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import java.util.Map;
4848
import java.util.Map.Entry;
4949
import java.util.Set;
50+
import java.util.stream.Collectors;
5051

5152
/**
5253
* Push down agg through join with foreign key:
@@ -131,13 +132,27 @@ private LogicalAggregate<?> eliminatePrimaryOutput(LogicalAggregate<?> agg, Plan
131132
if (primary.getOutputSet().stream().noneMatch(aggInputs::contains)) {
132133
return agg;
133134
}
134-
Set<Slot> primaryOutputSet = primary.getOutputSet();
135-
Set<Slot> primarySlots = Sets.intersection(aggInputs, primaryOutputSet);
135+
// Firstly, using fd to eliminate group by key.
136+
// group by primary_table_pk, primary_table_other
137+
// -> group by primary_table_pk
138+
Set<Set<Slot>> groupBySlots = new HashSet<>();
139+
for (Expression expression : agg.getGroupByExpressions()) {
140+
groupBySlots.add(expression.getInputSlots());
141+
}
136142
DataTrait dataTrait = child.getLogicalProperties().getTrait();
137143
FuncDeps funcDeps = dataTrait.getAllValidFuncDeps(Sets.union(foreign.getOutputSet(), primary.getOutputSet()));
144+
Set<Slot> foreignOutput = Sets.intersection(agg.getOutputSet(), foreign.getOutputSet());
145+
Set<Set<Slot>> minGroupBySlots = funcDeps.eliminateDeps(groupBySlots, foreignOutput);
146+
List<Slot> minGroupBySlotList = minGroupBySlots.stream().flatMap(Set::stream).collect(Collectors.toList());
147+
148+
// Second, record bijective slots in a map: {primary_table_pk : foreign_table_fk}
149+
// Bijective slots are mutually interchangeable within GROUP BY keys.
150+
// e.g. group by primary_table_pk is equivalent to group by foreign_table_fk
151+
Set<Slot> primaryOutputSet = primary.getOutputSet();
152+
Set<Slot> primarySlots = Sets.intersection(aggInputs, primaryOutputSet);
138153
HashMap<Slot, Slot> primaryToForeignDeps = new HashMap<>();
139154
for (Slot slot : primarySlots) {
140-
Set<Set<Slot>> replacedSlotSets = funcDeps.findDeterminats(ImmutableSet.of(slot));
155+
Set<Set<Slot>> replacedSlotSets = funcDeps.findBijectionSlots(ImmutableSet.of(slot));
141156
for (Set<Slot> replacedSlots : replacedSlotSets) {
142157
if (primaryOutputSet.stream().noneMatch(replacedSlots::contains)
143158
&& replacedSlots.size() == 1) {
@@ -147,7 +162,9 @@ private LogicalAggregate<?> eliminatePrimaryOutput(LogicalAggregate<?> agg, Plan
147162
}
148163
}
149164

150-
Set<Expression> newGroupBySlots = constructNewGroupBy(agg, primaryOutputSet, primaryToForeignDeps);
165+
// Thirdly, construct new Agg below join.
166+
Set<Expression> newGroupBySlots = constructNewGroupBy(minGroupBySlotList, primaryOutputSet,
167+
primaryToForeignDeps);
151168
List<NamedExpression> newOutput = constructNewOutput(
152169
agg, primaryOutputSet, primaryToForeignDeps, funcDeps, primary);
153170
if (newGroupBySlots == null || newOutput == null) {
@@ -156,10 +173,10 @@ private LogicalAggregate<?> eliminatePrimaryOutput(LogicalAggregate<?> agg, Plan
156173
return agg.withGroupByAndOutput(ImmutableList.copyOf(newGroupBySlots), ImmutableList.copyOf(newOutput));
157174
}
158175

159-
private @Nullable Set<Expression> constructNewGroupBy(LogicalAggregate<?> agg, Set<Slot> primaryOutputs,
160-
Map<Slot, Slot> primaryToForeignBiDeps) {
176+
private @Nullable Set<Expression> constructNewGroupBy(List<? extends Expression> gbyExpression,
177+
Set<Slot> primaryOutputs, Map<Slot, Slot> primaryToForeignBiDeps) {
161178
Set<Expression> newGroupBySlots = new HashSet<>();
162-
for (Expression expression : agg.getGroupByExpressions()) {
179+
for (Expression expression : gbyExpression) {
163180
if (!(expression instanceof Slot)) {
164181
return null;
165182
}

regression-test/framework/src/main/groovy/org/apache/doris/regression/suite/Suite.groovy

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3537,4 +3537,9 @@ class Suite implements GroovyInterceptable {
35373537
GlobalLock.unlock(lockName)
35383538
}
35393539
}
3540+
3541+
def explain_and_result = { tag, sql ->
3542+
"qt_${tag}_shape" "explain shape plan ${sql}"
3543+
"order_qt_${tag}_result" "${sql}"
3544+
}
35403545
}

0 commit comments

Comments
 (0)