Skip to content

Commit bf2e1c2

Browse files
authored
[opt](distribution) support bucket shuffle for set operation (#59006)
support bucket shuffle for set operation(union/intersect/except)
1 parent 7d83f41 commit bf2e1c2

36 files changed

+818
-188
lines changed

be/src/pipeline/exec/set_probe_sink_operator.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,11 @@ class SetProbeSinkOperatorX final : public DataSinkOperatorX<SetProbeSinkLocalSt
8585
_cur_child_id(child_id),
8686
_is_colocate(is_intersect ? tnode.intersect_node.is_colocate
8787
: tnode.except_node.is_colocate),
88-
_partition_exprs(is_intersect ? tnode.intersect_node.result_expr_lists[child_id]
89-
: tnode.except_node.result_expr_lists[child_id]) {}
88+
_partition_exprs(
89+
tnode.__isset.distribute_expr_lists
90+
? tnode.distribute_expr_lists[child_id]
91+
: (is_intersect ? tnode.intersect_node.result_expr_lists[child_id]
92+
: tnode.except_node.result_expr_lists[child_id])) {}
9093

9194
#ifdef BE_TEST
9295
SetProbeSinkOperatorX(int cur_child_id)

be/src/pipeline/exec/set_sink_operator.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,11 @@ class SetSinkOperatorX final : public DataSinkOperatorX<SetSinkLocalState<is_int
8585
: tnode.except_node.result_expr_lists.size()),
8686
_is_colocate(is_intersect ? tnode.intersect_node.is_colocate
8787
: tnode.except_node.is_colocate),
88-
_partition_exprs(is_intersect ? tnode.intersect_node.result_expr_lists[child_id]
89-
: tnode.except_node.result_expr_lists[child_id]),
88+
_partition_exprs(tnode.__isset.distribute_expr_lists
89+
? tnode.distribute_expr_lists[child_id]
90+
: (is_intersect
91+
? tnode.intersect_node.result_expr_lists[child_id]
92+
: tnode.except_node.result_expr_lists[child_id])),
9093
_runtime_filter_descs(tnode.runtime_filters) {
9194
DCHECK_EQ(child_id, _cur_child_id);
9295
DCHECK_GT(_child_quantity, 1);

be/src/pipeline/pipeline_fragment_context.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,12 +1597,14 @@ Status PipelineFragmentContext::_create_operator(ObjectPool* pool, const TPlanNo
15971597
RETURN_IF_ERROR(_build_operators_for_set_operation_node<true>(
15981598
pool, tnode, descs, op, cur_pipe, parent_idx, child_idx,
15991599
followed_by_shuffled_operator));
1600+
_require_bucket_distribution = tnode.intersect_node.is_colocate;
16001601
break;
16011602
}
16021603
case TPlanNodeType::EXCEPT_NODE: {
16031604
RETURN_IF_ERROR(_build_operators_for_set_operation_node<false>(
16041605
pool, tnode, descs, op, cur_pipe, parent_idx, child_idx,
16051606
followed_by_shuffled_operator));
1607+
_require_bucket_distribution = tnode.except_node.is_colocate;
16061608
break;
16071609
}
16081610
case TPlanNodeType::REPEAT_NODE: {

fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,12 @@
183183
import org.apache.doris.planner.DataPartition;
184184
import org.apache.doris.planner.DataStreamSink;
185185
import org.apache.doris.planner.DictionarySink;
186+
import org.apache.doris.planner.DistributionMode;
186187
import org.apache.doris.planner.EmptySetNode;
187188
import org.apache.doris.planner.ExceptNode;
188189
import org.apache.doris.planner.ExchangeNode;
189190
import org.apache.doris.planner.GroupCommitBlockSink;
190191
import org.apache.doris.planner.HashJoinNode;
191-
import org.apache.doris.planner.HashJoinNode.DistributionMode;
192192
import org.apache.doris.planner.HiveTableSink;
193193
import org.apache.doris.planner.IcebergTableSink;
194194
import org.apache.doris.planner.IntersectNode;
@@ -2269,6 +2269,14 @@ && findOlapScanNodesByPassExchangeAndJoinNode(setOperationFragment.getPlanRoot()
22692269
setOperationNode.setColocate(true);
22702270
}
22712271

2272+
for (Plan child : setOperation.children()) {
2273+
PhysicalPlan childPhysicalPlan = (PhysicalPlan) child;
2274+
if (JoinUtils.isStorageBucketed(childPhysicalPlan.getPhysicalProperties())) {
2275+
setOperationNode.setDistributionMode(DistributionMode.BUCKET_SHUFFLE);
2276+
break;
2277+
}
2278+
}
2279+
22722280
return setOperationFragment;
22732281
}
22742282

fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/RuntimeFilterTranslator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
import org.apache.doris.nereids.types.DataType;
3232
import org.apache.doris.planner.CTEScanNode;
3333
import org.apache.doris.planner.DataStreamSink;
34+
import org.apache.doris.planner.DistributionMode;
3435
import org.apache.doris.planner.HashJoinNode;
35-
import org.apache.doris.planner.HashJoinNode.DistributionMode;
3636
import org.apache.doris.planner.JoinNodeBase;
3737
import org.apache.doris.planner.RuntimeFilter.RuntimeFilterTarget;
3838
import org.apache.doris.planner.ScanNode;

fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/PlanPatterns.java

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,12 @@
3434
import org.apache.doris.nereids.trees.plans.logical.LogicalUnary;
3535
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
3636
import org.apache.doris.nereids.trees.plans.physical.PhysicalBinary;
37+
import org.apache.doris.nereids.trees.plans.physical.PhysicalExcept;
38+
import org.apache.doris.nereids.trees.plans.physical.PhysicalIntersect;
3739
import org.apache.doris.nereids.trees.plans.physical.PhysicalLeaf;
3840
import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation;
3941
import org.apache.doris.nereids.trees.plans.physical.PhysicalUnary;
42+
import org.apache.doris.nereids.trees.plans.physical.PhysicalUnion;
4043

4144
import java.util.Arrays;
4245

@@ -222,7 +225,7 @@ default PatternDescriptor<LogicalExcept> logicalExcept() {
222225
}
223226

224227
/**
225-
* create a logicalUnion pattern.
228+
* create a logicalIntersect pattern.
226229
*/
227230
default PatternDescriptor<LogicalIntersect>
228231
logicalIntersect(
@@ -323,4 +326,69 @@ default PatternDescriptor<InlineTable> inlineTable() {
323326
default PatternDescriptor<OneRowRelation> oneRowRelation() {
324327
return new PatternDescriptor(new TypePattern(OneRowRelation.class), defaultPromise());
325328
}
329+
330+
/**
331+
* create a physicalUnion multi.
332+
*/
333+
default PatternDescriptor<PhysicalUnion> physicalUnion(
334+
PatternDescriptor... children) {
335+
return new PatternDescriptor(
336+
new TypePattern(PhysicalUnion.class,
337+
Arrays.stream(children)
338+
.map(PatternDescriptor::getPattern)
339+
.toArray(Pattern[]::new)),
340+
defaultPromise());
341+
}
342+
343+
/**
344+
* create a physicalUnion multi.
345+
*/
346+
default PatternDescriptor<PhysicalUnary> physicalUnion() {
347+
return new PatternDescriptor(
348+
new TypePattern(PhysicalUnary.class, multi().pattern),
349+
defaultPromise());
350+
}
351+
352+
/**
353+
* create a physicalExcept pattern.
354+
*/
355+
default PatternDescriptor<PhysicalExcept> physicalExcept(PatternDescriptor... children) {
356+
return new PatternDescriptor(
357+
new TypePattern(PhysicalExcept.class,
358+
Arrays.stream(children)
359+
.map(PatternDescriptor::getPattern)
360+
.toArray(Pattern[]::new)),
361+
defaultPromise());
362+
}
363+
364+
/**
365+
* create a physicalExcept multi.
366+
*/
367+
default PatternDescriptor<PhysicalExcept> physicalExcept() {
368+
return new PatternDescriptor(
369+
new TypePattern(PhysicalExcept.class, multi().pattern),
370+
defaultPromise());
371+
}
372+
373+
/**
374+
* create a physicalIntersect multi.
375+
*/
376+
default PatternDescriptor<PhysicalIntersect> physicalIntersect(
377+
PatternDescriptor... children) {
378+
return new PatternDescriptor(
379+
new TypePattern(PhysicalIntersect.class,
380+
Arrays.stream(children)
381+
.map(PatternDescriptor::getPattern)
382+
.toArray(Pattern[]::new)),
383+
defaultPromise());
384+
}
385+
386+
/**
387+
* create a physicalIntersect multi.
388+
*/
389+
default PatternDescriptor<PhysicalIntersect> physicalIntersect() {
390+
return new PatternDescriptor(
391+
new TypePattern(PhysicalIntersect.class, multi().pattern),
392+
defaultPromise());
393+
}
326394
}

fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.doris.nereids.trees.expressions.SlotReference;
3131
import org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction;
3232
import org.apache.doris.nereids.trees.plans.Plan;
33+
import org.apache.doris.nereids.trees.plans.algebra.Union;
3334
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalSort;
3435
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
3536
import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEAnchor;
@@ -70,8 +71,10 @@
7071
import com.google.common.collect.Maps;
7172
import com.google.common.collect.Sets;
7273

74+
import java.util.ArrayList;
7375
import java.util.Arrays;
7476
import java.util.HashSet;
77+
import java.util.LinkedHashMap;
7578
import java.util.List;
7679
import java.util.Map;
7780
import java.util.Objects;
@@ -445,6 +448,54 @@ public PhysicalProperties visitPhysicalSetOperation(PhysicalSetOperation setOper
445448
if (childrenDistribution.stream().allMatch(DistributionSpecGather.class::isInstance)) {
446449
return PhysicalProperties.GATHER;
447450
}
451+
452+
int distributeToChildIndex
453+
= setOperation.<Integer>getMutableState(PhysicalSetOperation.DISTRIBUTE_TO_CHILD_INDEX).orElse(-1);
454+
if (distributeToChildIndex >= 0
455+
&& childrenDistribution.get(distributeToChildIndex) instanceof DistributionSpecHash) {
456+
DistributionSpecHash childDistribution
457+
= (DistributionSpecHash) childrenDistribution.get(distributeToChildIndex);
458+
List<SlotReference> childToIndex = setOperation.getRegularChildrenOutputs().get(distributeToChildIndex);
459+
Map<ExprId, Integer> idToOutputIndex = new LinkedHashMap<>();
460+
for (int j = 0; j < childToIndex.size(); j++) {
461+
idToOutputIndex.put(childToIndex.get(j).getExprId(), j);
462+
}
463+
464+
List<ExprId> orderedShuffledColumns = childDistribution.getOrderedShuffledColumns();
465+
List<ExprId> setOperationDistributeColumnIds = new ArrayList<>();
466+
for (ExprId tableDistributeColumnId : orderedShuffledColumns) {
467+
Integer index = idToOutputIndex.get(tableDistributeColumnId);
468+
if (index == null) {
469+
break;
470+
}
471+
setOperationDistributeColumnIds.add(setOperation.getOutput().get(index).getExprId());
472+
}
473+
// check whether the set operation output all distribution columns of the child
474+
if (setOperationDistributeColumnIds.size() == orderedShuffledColumns.size()) {
475+
boolean isUnion = setOperation instanceof Union;
476+
boolean shuffleToRight = distributeToChildIndex > 0;
477+
if (!isUnion && shuffleToRight) {
478+
return new PhysicalProperties(
479+
new DistributionSpecHash(
480+
setOperationDistributeColumnIds,
481+
ShuffleType.EXECUTION_BUCKETED
482+
)
483+
);
484+
} else {
485+
// keep the distribution as the child
486+
return new PhysicalProperties(
487+
new DistributionSpecHash(
488+
setOperationDistributeColumnIds,
489+
childDistribution.getShuffleType(),
490+
childDistribution.getTableId(),
491+
childDistribution.getSelectedIndexId(),
492+
childDistribution.getPartitionIds()
493+
)
494+
);
495+
}
496+
}
497+
}
498+
448499
for (int i = 0; i < childrenDistribution.size(); i++) {
449500
DistributionSpec childDistribution = childrenDistribution.get(i);
450501
if (!(childDistribution instanceof DistributionSpecHash)) {
@@ -455,6 +506,7 @@ public PhysicalProperties visitPhysicalSetOperation(PhysicalSetOperation setOper
455506
return new PhysicalProperties(childDistribution);
456507
}
457508
}
509+
458510
DistributionSpecHash distributionSpecHash = (DistributionSpecHash) childDistribution;
459511
int[] offsetsOfCurrentChild = new int[distributionSpecHash.getOrderedShuffledColumns().size()];
460512
for (int j = 0; j < setOperation.getRegularChildOutput(i).size(); j++) {

0 commit comments

Comments
 (0)