Skip to content

Commit 1ff0837

Browse files
Allow slots with NDV > LOW_NDV_THRESHOLD to be chosen as the shuffle key in decomposeRepeat (#60610)
### What problem does this PR solve?

Issue Number: close #xxx

Related PR: #xxx

Problem Summary:

### Release note

None

### Check List (For Author)

- Test <!-- At least one of them must be included. -->
    - [ ] Regression test
    - [ ] Unit Test
    - [ ] Manual test (add detailed scripts or steps below)
    - [ ] No need to test or manual test. Explain why:
        - [ ] This is a refactor/code format and no logic has been changed.
        - [ ] Previous test can cover this change.
        - [ ] No code files have been changed.
        - [ ] Other reason <!-- Add your reason? -->
- Behavior changed:
    - [ ] No.
    - [ ] Yes. <!-- Explain the behavior change -->
- Does this need documentation?
    - [ ] No.
    - [ ] Yes. <!-- Add document PR link here. eg: apache/doris-website#1214 -->

### Check List (For Reviewer who merges this PR)

- [ ] Confirm the release note
- [ ] Confirm test cases
- [ ] Confirm document
- [ ] Add branch pick label <!-- Add branch pick label that this PR should merge into -->
1 parent d7153c0 commit 1ff0837

File tree

9 files changed

+334
-82
lines changed

9 files changed

+334
-82
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/PlanContext.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,8 @@ public Statistics getChildStatistics(int index) {
8383
public StatementContext getStatementContext() {
8484
return connectContext.getStatementContext();
8585
}
86+
87+
/**
 * Returns the {@code connectContext} held by this plan context.
 *
 * @return the connect context this object was given; callers in this commit use it to read
 *         session variables
 */
public ConnectContext getConnectContext() {
88+
return connectContext;
89+
}
8690
}

fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -116,31 +116,7 @@ public List<List<PhysicalProperties>> visitPhysicalHashAggregate(
116116
if (agg.getGroupByExpressions().isEmpty() && agg.getOutputExpressions().isEmpty()) {
117117
return ImmutableList.of();
118118
}
119-
// If the origin attribute satisfies the group by key but does not meet the requirements, ban the plan.
120-
// e.g. select count(distinct a) from t group by b;
121-
// requiredChildProperty: a
122-
// but the child is already distributed by b
123-
// ban this plan
124-
PhysicalProperties originChildProperty = originChildrenProperties.get(0);
125119
PhysicalProperties requiredChildProperty = requiredProperties.get(0);
126-
PhysicalProperties hashSpec = PhysicalProperties.createHash(agg.getGroupByExpressions(), ShuffleType.REQUIRE);
127-
GroupExpression child = children.get(0);
128-
if (child.getPlan() instanceof PhysicalDistribute) {
129-
PhysicalProperties properties = new PhysicalProperties(
130-
DistributionSpecAny.INSTANCE, originChildProperty.getOrderSpec());
131-
Optional<Pair<Cost, GroupExpression>> pair = child.getOwnerGroup().getLowestCostPlan(properties);
132-
// add null check
133-
if (!pair.isPresent()) {
134-
return ImmutableList.of();
135-
}
136-
GroupExpression distributeChild = pair.get().second;
137-
PhysicalProperties distributeChildProperties = distributeChild.getOutputProperties(properties);
138-
if (distributeChildProperties.satisfy(hashSpec)
139-
&& !distributeChildProperties.satisfy(requiredChildProperty)) {
140-
return ImmutableList.of();
141-
}
142-
}
143-
144120
if (!agg.getAggregateParam().canBeBanned) {
145121
return visit(agg, context);
146122
}

fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ public Void visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> agg
468468
Set<ExprId> intersectId = Sets.intersection(new HashSet<>(parentHashExprIds),
469469
new HashSet<>(groupByExprIds));
470470
if (!intersectId.isEmpty() && intersectId.size() < groupByExprIds.size()) {
471-
if (shouldUseParent(parentHashExprIds, agg)) {
471+
if (shouldUseParent(parentHashExprIds, agg, context)) {
472472
addRequestPropertyToChildren(PhysicalProperties.createHash(
473473
Utils.fastToImmutableList(intersectId), ShuffleType.REQUIRE));
474474
}
@@ -482,7 +482,11 @@ public Void visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> agg
482482
return null;
483483
}
484484

485-
private boolean shouldUseParent(List<ExprId> parentHashExprIds, PhysicalHashAggregate<? extends Plan> agg) {
485+
private boolean shouldUseParent(List<ExprId> parentHashExprIds, PhysicalHashAggregate<? extends Plan> agg,
486+
PlanContext context) {
487+
if (!context.getConnectContext().getSessionVariable().aggShuffleUseParentKey) {
488+
return false;
489+
}
486490
Optional<GroupExpression> groupExpression = agg.getGroupExpression();
487491
if (!groupExpression.isPresent()) {
488492
return true;

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java

Lines changed: 21 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,12 @@
5454
import org.apache.doris.qe.ConnectContext;
5555
import org.apache.doris.statistics.ColumnStatistic;
5656
import org.apache.doris.statistics.Statistics;
57+
import org.apache.doris.statistics.util.StatisticsUtil;
5758

5859
import com.google.common.collect.ImmutableList;
5960
import com.google.common.collect.ImmutableSet;
6061

6162
import java.util.ArrayList;
62-
import java.util.Comparator;
6363
import java.util.HashMap;
6464
import java.util.HashSet;
6565
import java.util.List;
@@ -407,11 +407,7 @@ private int canOptimize(LogicalAggregate<? extends Plan> aggregate, ConnectConte
407407
if (groupingSets.size() <= connectContext.getSessionVariable().decomposeRepeatThreshold) {
408408
return -1;
409409
}
410-
int maxGroupIndex = findMaxGroupingSetIndex(groupingSets);
411-
if (maxGroupIndex < 0) {
412-
return -1;
413-
}
414-
return maxGroupIndex;
410+
return findMaxGroupingSetIndex(groupingSets);
415411
}
416412

417413
/**
@@ -436,6 +432,9 @@ private int findMaxGroupingSetIndex(List<List<Expression>> groupingSets) {
436432
maxGroupIndex = i;
437433
}
438434
}
435+
if (groupingSets.get(maxGroupIndex).isEmpty()) {
436+
return -1;
437+
}
439438
// Second pass: verify that the max-size grouping set contains all other grouping sets
440439
ImmutableSet<Expression> maxGroup = ImmutableSet.copyOf(groupingSets.get(maxGroupIndex));
441440
for (int i = 0; i < groupingSets.size(); ++i) {
@@ -520,32 +519,36 @@ private Optional<List<Expression>> choosePreAggShuffleKeyPartitionExprs(
520519
switch (repeat.getRepeatType()) {
521520
case CUBE:
522521
// Prefer larger NDV to improve balance
523-
chosen = chooseByNdv(maxGroupByList, inputStats, totalInstanceNum);
522+
chosen = chooseOneBalancedKey(maxGroupByList, inputStats, totalInstanceNum);
524523
break;
525524
case GROUPING_SETS:
526525
chosen = chooseByAppearanceThenNdv(repeat.getGroupingSets(), maxGroupIndex, maxGroupByList,
527526
inputStats, totalInstanceNum);
528527
break;
529528
case ROLLUP:
530-
chosen = chooseByRollupPrefixThenNdv(maxGroupByList, inputStats, totalInstanceNum);
529+
chosen = chooseOneBalancedKey(maxGroupByList, inputStats, totalInstanceNum);
531530
break;
532531
default:
533532
chosen = Optional.empty();
534533
}
535534
return chosen.map(ImmutableList::of);
536535
}
537536

538-
private Optional<Expression> chooseByNdv(List<Expression> candidates, Statistics inputStats, int totalInstanceNum) {
537+
private Optional<Expression> chooseOneBalancedKey(List<Expression> candidates, Statistics inputStats,
538+
int totalInstanceNum) {
539539
if (inputStats == null) {
540540
return Optional.empty();
541541
}
542-
Comparator<Expression> cmp = Comparator.comparingDouble(e -> estimateNdv(e, inputStats));
543-
Optional<Expression> choose = candidates.stream().max(cmp);
544-
if (choose.isPresent() && estimateNdv(choose.get(), inputStats) > totalInstanceNum) {
545-
return choose;
546-
} else {
547-
return Optional.empty();
542+
for (Expression candidate : candidates) {
543+
ColumnStatistic columnStatistic = inputStats.findColumnStatistics(candidate);
544+
if (columnStatistic == null || columnStatistic.isUnKnown()) {
545+
continue;
546+
}
547+
if (StatisticsUtil.isBalanced(columnStatistic, inputStats.getRowCount(), totalInstanceNum)) {
548+
return Optional.of(candidate);
549+
}
548550
}
551+
return Optional.empty();
549552
}
550553

551554
/**
@@ -568,42 +571,17 @@ private Optional<Expression> chooseByAppearanceThenNdv(List<List<Expression>> gr
568571
}
569572
}
570573
}
571-
Map<Integer, List<Expression>> countToCandidate = new TreeMap<>();
574+
TreeMap<Integer, List<Expression>> countToCandidate = new TreeMap<>();
572575
for (Map.Entry<Expression, Integer> entry : appearCount.entrySet()) {
573576
countToCandidate.computeIfAbsent(entry.getValue(), v -> new ArrayList<>()).add(entry.getKey());
574577
}
575-
for (Map.Entry<Integer, List<Expression>> entry : countToCandidate.entrySet()) {
576-
Optional<Expression> chosen = chooseByNdv(entry.getValue(), inputStats, totalInstanceNum);
578+
for (Map.Entry<Integer, List<Expression>> entry : countToCandidate.descendingMap().entrySet()) {
579+
Optional<Expression> chosen = chooseOneBalancedKey(entry.getValue(), inputStats, totalInstanceNum);
577580
if (chosen.isPresent()) {
578581
return chosen;
579582
}
580583
}
581584
return Optional.empty();
582-
583-
}
584-
585-
/**
586-
* ROLLUP: prefer earliest prefix key; if NDV is too low, fallback to next prefix.
587-
*/
588-
private Optional<Expression> chooseByRollupPrefixThenNdv(List<Expression> candidates, Statistics inputStats,
589-
int totalInstanceNum) {
590-
for (Expression c : candidates) {
591-
if (estimateNdv(c, inputStats) >= totalInstanceNum) {
592-
return Optional.of(c);
593-
}
594-
}
595-
return Optional.empty();
596-
}
597-
598-
private double estimateNdv(Expression expr, Statistics stats) {
599-
if (stats == null) {
600-
return -1D;
601-
}
602-
ColumnStatistic col = stats.findColumnStatistics(expr);
603-
if (col == null || col.isUnKnown()) {
604-
return -1D;
605-
}
606-
return col.ndv;
607585
}
608586

609587
/**

fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,8 @@ public class SessionVariable implements Serializable, Writable {
839839
public static final String SKEW_REWRITE_JOIN_SALT_EXPLODE_FACTOR = "skew_rewrite_join_salt_explode_factor";
840840

841841
public static final String SKEW_REWRITE_AGG_BUCKET_NUM = "skew_rewrite_agg_bucket_num";
842+
public static final String AGG_SHUFFLE_USE_PARENT_KEY = "agg_shuffle_use_parent_key";
843+
842844
public static final String DECOMPOSE_REPEAT_THRESHOLD = "decompose_repeat_threshold";
843845
public static final String DECOMPOSE_REPEAT_SHUFFLE_INDEX_IN_MAX_GROUP
844846
= "decompose_repeat_shuffle_index_in_max_group";
@@ -850,6 +852,7 @@ public class SessionVariable implements Serializable, Writable {
850852
+ "proportion as hot values, up to HOT_VALUE_COLLECT_COUNT."})
851853
public int hotValueCollectCount = 10; // Select the values that account for at least 10% of the column
852854

855+
853856
public void setHotValueCollectCount(int count) {
854857
this.hotValueCollectCount = count;
855858
}
@@ -2791,6 +2794,12 @@ public static boolean isEagerAggregationOnJoin() {
27912794
}, checker = "checkSkewRewriteAggBucketNum")
27922795
public int skewRewriteAggBucketNum = 1024;
27932796

2797+
@VariableMgr.VarAttr(name = AGG_SHUFFLE_USE_PARENT_KEY, description = {
2798+
"在聚合算子进行 shuffle 时,是否使用父节点的分组键进行 shuffle",
2799+
"Whether to use the parent node's grouping key for shuffling during the aggregation operator"
2800+
}, needForward = false)
2801+
public boolean aggShuffleUseParentKey = true;
2802+
27942803
@VariableMgr.VarAttr(name = ENABLE_PREFER_CACHED_ROWSET, needForward = false,
27952804
description = {"是否启用 prefer cached rowset 功能",
27962805
"Whether to enable prefer cached rowset feature"})

fe/fe-core/src/main/java/org/apache/doris/statistics/util/StatisticsUtil.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
import org.apache.doris.nereids.trees.expressions.literal.TimestampTzLiteral;
6565
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
6666
import org.apache.doris.nereids.types.DataType;
67+
import org.apache.doris.nereids.util.AggregateUtils;
6768
import org.apache.doris.qe.AuditLogHelper;
6869
import org.apache.doris.qe.AutoCloseConnectContext;
6970
import org.apache.doris.qe.ConnectContext;
@@ -1322,4 +1323,23 @@ public static LinkedHashMap<Literal, Float> getHotValues(String stringValues, Ty
13221323
}
13231324
return null;
13241325
}
1326+
1327+
/**
 * Heuristic check of whether a column is a reasonable shuffle key, i.e. whether hashing on it
 * is expected to spread rows evenly over {@code instanceNum} instances.
 *
 * <p>The column is considered balanced when all of the following hold:
 * <ul>
 *   <li>the rows left after removing the largest single "hot" bucket (the most frequent hot
 *       value or the nulls, whichever is larger), divided per instance, are more than twice
 *       the size of that hot bucket;</li>
 *   <li>{@code ndv > instanceNum * 3};</li>
 *   <li>{@code ndv > AggregateUtils.LOW_NDV_THRESHOLD}.</li>
 * </ul>
 *
 * @param columnStatistic statistics of the candidate column; its {@code ndv}, {@code numNulls}
 *                        and hot values are read
 * @param rowCount estimated number of input rows
 * @param instanceNum number of instances the data would be shuffled to
 * @return {@code true} if the column passes all three checks above
 */
public static boolean isBalanced(ColumnStatistic columnStatistic, double rowCount, int instanceNum) {
1328+
double ndv = columnStatistic.ndv;
1329+
double maxHotValueCntIncludeNull;
1330+
Map<Literal, Float> hotValues = columnStatistic.getHotValues();
1331+
// When hotValues is null or empty (no known hot values), treat nulls as the only hot value.
// NOTE(review): the null check calls getHotValues() a second time instead of reusing the
// local hotValues; this is only safe if the getter returns the same value on both calls.
1332+
if (columnStatistic.getHotValues() == null || hotValues.isEmpty()) {
1333+
maxHotValueCntIncludeNull = columnStatistic.numNulls;
1334+
} else {
// The largest hot-value rate is scaled by rowCount to obtain a row count.
// Assumes the stored Float is a fraction of rows in [0, 1] rather than a percentage — TODO confirm.
1335+
double rate = hotValues.values().stream().mapToDouble(Float::doubleValue).max().orElse(0);
// Take the larger of the hottest value's estimated row count and the null count.
1336+
maxHotValueCntIncludeNull = rate * rowCount > columnStatistic.numNulls
1337+
? rate * rowCount : columnStatistic.numNulls;
1338+
}
// Rows that remain after removing the hot bucket, spread evenly over the instances.
// NOTE(review): instanceNum is not validated here; a value of 0 makes this a floating-point
// division by zero (Infinity/NaN rather than an exception) — confirm callers never pass 0.
1339+
double rowsPerInstance = (rowCount - maxHotValueCntIncludeNull) / instanceNum;
// No hot bucket at all means there is nothing to skew the distribution, so use the maximum factor.
1340+
double balanceFactor = maxHotValueCntIncludeNull == 0
1341+
? Double.MAX_VALUE : rowsPerInstance / maxHotValueCntIncludeNull;
1342+
// The larger this factor is, the more balanced the data.
1343+
return balanceFactor > 2.0 && ndv > instanceNum * 3 && ndv > AggregateUtils.LOW_NDV_THRESHOLD;
1344+
}
13251345
}

fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,4 +368,96 @@ void testWindowWithNoPartitionKeyAndNoOrderKey() {
368368
expected.add(Lists.newArrayList(PhysicalProperties.GATHER));
369369
Assertions.assertEquals(expected, actual);
370370
}
371+
372+
/**
 * With {@code aggShuffleUseParentKey} set to {@code false}, a global aggregate grouped by
 * (key1, key2) whose parent requires a hash distribution on key1 only must produce exactly one
 * child property request: a hash on all group-by keys (key1, key2). No separate request using
 * the parent's key alone is expected.
 *
 * <p>Relies on {@code logicalProperties}, {@code groupPlan} and {@code jobContext}, which are
 * declared outside this method (presumably mocked test fixtures — not visible here).
 */
@Test
373+
void testAggregateWithAggShuffleUseParentKeyDisabled() {
374+
// Create ConnectContext with aggShuffleUseParentKey = false
375+
ConnectContext testConnectContext = new ConnectContext();
376+
testConnectContext.getSessionVariable().aggShuffleUseParentKey = false;
377+
378+
SlotReference key1 = new SlotReference(new ExprId(0), "col1", IntegerType.INSTANCE, true, ImmutableList.of());
379+
SlotReference key2 = new SlotReference(new ExprId(1), "col2", IntegerType.INSTANCE, true, ImmutableList.of());
380+
PhysicalHashAggregate<GroupPlan> aggregate = new PhysicalHashAggregate<>(
381+
Lists.newArrayList(key1, key2),
382+
Lists.newArrayList(key1, key2),
383+
new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT),
384+
true,
385+
logicalProperties,
386+
groupPlan
387+
);
388+
GroupExpression groupExpression = new GroupExpression(aggregate);
// The Group instance itself is discarded; it is constructed only for its effect on
// groupExpression (presumably registering it with an owner group — confirm in Group).
389+
new Group(null, groupExpression, null);
390+
391+
// Create a parent hash distribution with key1 only
392+
PhysicalProperties parentProperties = PhysicalProperties.createHash(
393+
Lists.newArrayList(key1.getExprId()), ShuffleType.REQUIRE);
394+
395+
// Make the mocked job context report the parent's requirement (hash on key1).
new Expectations() {
396+
{
397+
jobContext.getRequiredProperties();
398+
result = parentProperties;
399+
}
400+
};
401+
402+
RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(testConnectContext, jobContext);
403+
List<List<PhysicalProperties>> actual
404+
= requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression);
405+
406+
// When aggShuffleUseParentKey is false, should only use all groupByExpressions (key1, key2)
407+
// and not use parent key (key1) separately
408+
List<List<PhysicalProperties>> expected = Lists.newArrayList();
409+
expected.add(Lists.newArrayList(PhysicalProperties.createHash(
410+
Lists.newArrayList(key1.getExprId(), key2.getExprId()), ShuffleType.REQUIRE)));
411+
Assertions.assertEquals(1, actual.size());
412+
Assertions.assertEquals(expected, actual);
413+
}
414+
415+
/**
 * With {@code aggShuffleUseParentKey} set to {@code true}, a global aggregate grouped by
 * (key1, key2) whose parent requires a hash distribution on key1 only must produce exactly two
 * child property requests: a hash on the parent's key (key1) and a hash on all group-by keys
 * (key1, key2). The order of the two requests is not asserted.
 *
 * <p>Relies on {@code logicalProperties}, {@code groupPlan} and {@code jobContext}, which are
 * declared outside this method (presumably mocked test fixtures — not visible here).
 */
@Test
416+
void testAggregateWithAggShuffleUseParentKeyEnabled() {
417+
// Create ConnectContext with aggShuffleUseParentKey = true (default value)
418+
ConnectContext testConnectContext = new ConnectContext();
419+
testConnectContext.getSessionVariable().aggShuffleUseParentKey = true;
420+
421+
SlotReference key1 = new SlotReference(new ExprId(0), "col1", IntegerType.INSTANCE, true, ImmutableList.of());
422+
SlotReference key2 = new SlotReference(new ExprId(1), "col2", IntegerType.INSTANCE, true, ImmutableList.of());
423+
PhysicalHashAggregate<GroupPlan> aggregate = new PhysicalHashAggregate<>(
424+
Lists.newArrayList(key1, key2),
425+
Lists.newArrayList(key1, key2),
426+
new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT),
427+
true,
428+
logicalProperties,
429+
groupPlan
430+
);
431+
GroupExpression groupExpression = new GroupExpression(aggregate);
// The Group instance itself is discarded; it is constructed only for its effect on
// groupExpression (presumably registering it with an owner group — confirm in Group).
432+
new Group(null, groupExpression, null);
433+
434+
// Create a parent hash distribution with key1 only
435+
PhysicalProperties parentProperties = PhysicalProperties.createHash(
436+
Lists.newArrayList(key1.getExprId()), ShuffleType.REQUIRE);
437+
438+
// Make the mocked job context report the parent's requirement (hash on key1).
new Expectations() {
439+
{
440+
jobContext.getRequiredProperties();
441+
result = parentProperties;
442+
}
443+
};
// Stub GroupExpression.childStatistics to return null for every child index.
// Presumably this steers shouldUseParent down a path that does not need real statistics;
// the rest of shouldUseParent is not visible here — TODO confirm.
444+
new MockUp<org.apache.doris.nereids.memo.GroupExpression>() {
445+
@mockit.Mock
446+
org.apache.doris.statistics.Statistics childStatistics(int idx) {
447+
return null;
448+
}
449+
};
450+
RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(testConnectContext, jobContext);
451+
List<List<PhysicalProperties>> actual
452+
= requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression);
453+
454+
// When aggShuffleUseParentKey is true, shouldUseParent may return true
455+
// If shouldUseParent returns true, it will add parent key (key1) first, then all groupByExpressions (key1, key2)
// NOTE(review): the failure message says "at least one", but the assertion requires exactly 2
// requests; the message text should be aligned with the check.
456+
Assertions.assertEquals(2, actual.size(), "Should have at least one property request");
457+
PhysicalProperties parentProp = PhysicalProperties.createHash(
458+
Lists.newArrayList(key1.getExprId()), ShuffleType.REQUIRE);
459+
PhysicalProperties aggProp = PhysicalProperties.createHash(
460+
Lists.newArrayList(key1.getExprId(), key2.getExprId()), ShuffleType.REQUIRE);
// Membership check only: both requests must be present, in any order.
461+
Assertions.assertTrue(actual.contains(ImmutableList.of(aggProp)) && actual.contains(ImmutableList.of(parentProp)));
462+
}
371463
}

0 commit comments

Comments
 (0)