Skip to content

Commit 261b923

Browse files
mbasmanov authored and meta-codesync[bot] committed
feat: Add PlanNode::requiresSingleThread() API (#16753)
Summary:
Pull Request resolved: #16753

Add a virtual method to PlanNode that returns true if the node requires single-threaded execution (maxDrivers = 1).

Override it in ValuesNode, ArrowStreamNode, final TopNNode, final LimitNode, final OrderByNode, LocalMergeNode, MergeExchangeNode, MergeJoinNode, TableWriteMergeNode, LocalPartitionNode (Gather), MixedUnionNode, null-aware right semi project HashJoinNode, and TableWriteNode (when !supportsMultiThreading).

Refactor LocalPlanner::maxDrivers to use the new API instead of per-type dynamic_casts for single-thread checks. Config-dependent logic (writer counts, repartition partition count) remains unchanged.

Differential Revision: D96366381
1 parent 22d35f6 commit 261b923

File tree

2 files changed

+73
-72
lines changed

2 files changed

+73
-72
lines changed

velox/core/PlanNode.h

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -206,6 +206,14 @@ class PlanNode : public ISerializable {
206206
return false;
207207
}
208208

209+
/// Returns true if this plan node requires single-threaded execution
210+
/// (maxDrivers = 1). For example, ValuesNode, final OrderByNode, final
211+
/// LimitNode, MergeExchangeNode, LocalMergeNode, and
212+
/// LocalPartitionNode(Gather) all require single-threaded execution.
213+
virtual bool requiresSingleThread() const {
214+
return false;
215+
}
216+
209217
/// Returns true if this plan node operator supports task barrier processing.
210218
/// To support barrier processing, the operator must be able to drain its
211219
/// buffered output when it receives the drain signal at split boundary. Not
@@ -391,6 +399,10 @@ class ValuesNode : public PlanNode {
391399
void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context)
392400
const override;
393401

402+
bool requiresSingleThread() const override {
403+
return !parallelizable_;
404+
}
405+
394406
const std::vector<RowVectorPtr>& values() const {
395407
return values_;
396408
}
@@ -493,6 +505,10 @@ class ArrowStreamNode : public PlanNode {
493505
void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context)
494506
const override;
495507

508+
bool requiresSingleThread() const override {
509+
return true;
510+
}
511+
496512
const std::shared_ptr<ArrowArrayStream>& arrowStream() const {
497513
return arrowStream_;
498514
}
@@ -1713,6 +1729,11 @@ class TableWriteNode : public PlanNode {
17131729
return columnStatsSpec_;
17141730
}
17151731

1732+
bool requiresSingleThread() const override {
1733+
return !insertTableHandle_->connectorInsertTableHandle()
1734+
->supportsMultiThreading();
1735+
}
1736+
17161737
bool canSpill(const QueryConfig& queryConfig) const override {
17171738
return queryConfig.writerSpillEnabled();
17181739
}
@@ -1846,6 +1867,10 @@ class TableWriteMergeNode : public PlanNode {
18461867
return columnStatsSpec_;
18471868
}
18481869

1870+
bool requiresSingleThread() const override {
1871+
return true;
1872+
}
1873+
18491874
const std::vector<PlanNodePtr>& sources() const override {
18501875
return sources_;
18511876
}
@@ -2351,6 +2376,10 @@ class MergeExchangeNode : public ExchangeNode {
23512376
return sortingOrders_;
23522377
}
23532378

2379+
bool requiresSingleThread() const override {
2380+
return true;
2381+
}
2382+
23542383
void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context)
23552384
const override;
23562385

@@ -2449,6 +2478,10 @@ class LocalMergeNode : public PlanNode {
24492478
void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context)
24502479
const override;
24512480

2481+
bool requiresSingleThread() const override {
2482+
return true;
2483+
}
2484+
24522485
const std::vector<FieldAccessTypedExprPtr>& sortingKeys() const {
24532486
return sortingKeys_;
24542487
}
@@ -2660,6 +2693,10 @@ class LocalPartitionNode : public PlanNode {
26602693
return scaleWriter_;
26612694
}
26622695

2696+
bool requiresSingleThread() const override {
2697+
return type_ == Type::kGather;
2698+
}
2699+
26632700
bool supportsBarrier() const override {
26642701
return !scaleWriter_;
26652702
}
@@ -3433,6 +3470,10 @@ class HashJoinNode : public AbstractJoinNode {
34333470
queryConfig.joinSpillEnabled();
34343471
}
34353472

3473+
bool requiresSingleThread() const override {
3474+
return isRightSemiProjectJoin() && nullAware_;
3475+
}
3476+
34363477
bool isNullAware() const {
34373478
return nullAware_;
34383479
}
@@ -3507,6 +3548,10 @@ class MergeJoinNode : public AbstractJoinNode {
35073548
}
35083549
};
35093550

3551+
bool requiresSingleThread() const override {
3552+
return true;
3553+
}
3554+
35103555
std::string_view name() const override {
35113556
return "MergeJoin";
35123557
}
@@ -4084,6 +4129,10 @@ class OrderByNode : public PlanNode {
40844129
void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context)
40854130
const override;
40864131

4132+
bool requiresSingleThread() const override {
4133+
return !isPartial_;
4134+
}
4135+
40874136
// True if this node only sorts a portion of the final result. If it is
40884137
// true, a local merge or merge exchange is required to merge the sorted
40894138
// runs.
@@ -4431,6 +4480,10 @@ class TopNNode : public PlanNode {
44314480
void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context)
44324481
const override;
44334482

4483+
bool requiresSingleThread() const override {
4484+
return !isPartial_;
4485+
}
4486+
44344487
int32_t count() const {
44354488
return count_;
44364489
}
@@ -4566,6 +4619,10 @@ class LimitNode : public PlanNode {
45664619
return count_;
45674620
}
45684621

4622+
bool requiresSingleThread() const override {
4623+
return !isPartial_;
4624+
}
4625+
45694626
bool isPartial() const {
45704627
return isPartial_;
45714628
}
@@ -5843,6 +5900,10 @@ class MixedUnionNode : public PlanNode {
58435900
return batchSizesPerSource_[sourceIndex];
58445901
}
58455902

5903+
bool requiresSingleThread() const override {
5904+
return true;
5905+
}
5906+
58465907
void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context)
58475908
const override;
58485909

velox/exec/LocalPlanner.cpp

Lines changed: 12 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -283,8 +283,7 @@ void plan(
283283
// Sometimes consumer limits the number of drivers its producer can run.
284284
uint32_t maxDriversForConsumer(
285285
const std::shared_ptr<const core::PlanNode>& node) {
286-
if (std::dynamic_pointer_cast<const core::MergeJoinNode>(node)) {
287-
// MergeJoinNode must run single-threaded.
286+
if (node && node->requiresSingleThread()) {
288287
return 1;
289288
}
290289
return std::numeric_limits<uint32_t>::max();
@@ -298,83 +297,24 @@ uint32_t maxDrivers(
298297
return count;
299298
}
300299
for (auto& node : driverFactory.planNodes) {
301-
if (auto topN = std::dynamic_pointer_cast<const core::TopNNode>(node)) {
302-
if (!topN->isPartial()) {
303-
// final topN must run single-threaded
304-
return 1;
305-
}
306-
} else if (
307-
auto values = std::dynamic_pointer_cast<const core::ValuesNode>(node)) {
308-
// values node must run single-threaded, unless in test context
309-
if (!values->testingIsParallelizable()) {
310-
return 1;
311-
}
312-
} else if (std::dynamic_pointer_cast<const core::ArrowStreamNode>(node)) {
313-
// ArrowStream node must run single-threaded.
300+
if (node->requiresSingleThread()) {
314301
return 1;
315-
} else if (
316-
auto limit = std::dynamic_pointer_cast<const core::LimitNode>(node)) {
317-
// final limit must run single-threaded
318-
if (!limit->isPartial()) {
319-
return 1;
320-
}
321-
} else if (
322-
auto orderBy =
323-
std::dynamic_pointer_cast<const core::OrderByNode>(node)) {
324-
// final orderby must run single-threaded
325-
if (!orderBy->isPartial()) {
326-
return 1;
327-
}
328-
} else if (
329-
auto localExchange =
302+
}
303+
304+
if (auto localExchange =
330305
std::dynamic_pointer_cast<const core::LocalPartitionNode>(node)) {
331-
// Local gather must run single-threaded.
332-
switch (localExchange->type()) {
333-
case core::LocalPartitionNode::Type::kGather:
334-
return 1;
335-
case core::LocalPartitionNode::Type::kRepartition:
336-
count = std::min(queryConfig.maxLocalExchangePartitionCount(), count);
337-
break;
338-
default:
339-
VELOX_UNREACHABLE("Unexpected local exchange type");
306+
// Repartition limits parallelism to the partition count.
307+
if (localExchange->type() ==
308+
core::LocalPartitionNode::Type::kRepartition) {
309+
count = std::min(queryConfig.maxLocalExchangePartitionCount(), count);
340310
}
341-
} else if (std::dynamic_pointer_cast<const core::LocalMergeNode>(node)) {
342-
// Local merge must run single-threaded.
343-
return 1;
344-
} else if (std::dynamic_pointer_cast<const core::MixedUnionNode>(node)) {
345-
// Mixed union must run single-threaded.
346-
return 1;
347-
} else if (std::dynamic_pointer_cast<const core::MergeExchangeNode>(node)) {
348-
// Merge exchange must run single-threaded.
349-
return 1;
350-
} else if (std::dynamic_pointer_cast<const core::MergeJoinNode>(node)) {
351-
// Merge join must run single-threaded.
352-
return 1;
353-
} else if (
354-
auto join = std::dynamic_pointer_cast<const core::HashJoinNode>(node)) {
355-
// Null-aware right semi project doesn't support multi-threaded
356-
// execution.
357-
if (join->isRightSemiProjectJoin() && join->isNullAware()) {
358-
return 1;
359-
}
360-
} else if (std::dynamic_pointer_cast<const core::TableWriteMergeNode>(
361-
node)) {
362-
// TableWriteMerge accumulates state (row counts, fragments, stats)
363-
// and produces a single summary row. Must run single-threaded.
364-
return 1;
365311
} else if (
366312
auto tableWrite =
367313
std::dynamic_pointer_cast<const core::TableWriteNode>(node)) {
368-
const auto& connectorInsertHandle =
369-
tableWrite->insertTableHandle()->connectorInsertTableHandle();
370-
if (!connectorInsertHandle->supportsMultiThreading()) {
371-
return 1;
314+
if (tableWrite->hasPartitioningScheme()) {
315+
return queryConfig.taskPartitionedWriterCount();
372316
} else {
373-
if (tableWrite->hasPartitioningScheme()) {
374-
return queryConfig.taskPartitionedWriterCount();
375-
} else {
376-
return queryConfig.taskWriterCount();
377-
}
317+
return queryConfig.taskWriterCount();
378318
}
379319
} else {
380320
auto result = Operator::maxDrivers(node);

0 commit comments

Comments
 (0)