Skip to content

Commit 4162b53

Browse files
authored
feat: Add trace replay support for PartitionAndSerialize operator to presto_cpp (#26500)
Add trace replay support for the PartitionAndSerialize operator to presto_cpp.

```
== NO RELEASE NOTE ==
```
1 parent f492204 commit 4162b53

File tree

10 files changed

+587
-0
lines changed

10 files changed

+587
-0
lines changed

presto-native-execution/presto_cpp/main/PrestoServer.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
#include "velox/dwio/parquet/RegisterParquetWriter.h"
6464
#include "velox/dwio/text/RegisterTextWriter.h"
6565
#include "velox/exec/OutputBufferManager.h"
66+
#include "velox/exec/TraceUtil.h"
6667
#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
6768
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
6869
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"
@@ -437,6 +438,7 @@ void PrestoServer::run() {
437438
registerRemoteFunctions();
438439
registerVectorSerdes();
439440
registerPrestoPlanNodeSerDe();
441+
registerTraceNodeFactories();
440442
registerDynamicFunctions();
441443

442444
facebook::velox::exec::ExchangeSource::registerFactory(
@@ -1805,4 +1807,29 @@ void PrestoServer::reportNodeStats(proxygen::ResponseHandler* downstream) {
18051807

18061808
http::sendOkResponse(downstream, json(nodeStats));
18071809
}
1810+
1811+
void PrestoServer::registerTraceNodeFactories() {
1812+
// Register trace node factory for PartitionAndSerialize operator
1813+
velox::exec::trace::registerTraceNodeFactory(
1814+
"PartitionAndSerialize",
1815+
[](const velox::core::PlanNode* traceNode,
1816+
const velox::core::PlanNodeId& nodeId) -> velox::core::PlanNodePtr {
1817+
if (const auto* partitionAndSerializeNode =
1818+
dynamic_cast<const operators::PartitionAndSerializeNode*>(
1819+
traceNode)) {
1820+
return std::make_shared<operators::PartitionAndSerializeNode>(
1821+
nodeId,
1822+
partitionAndSerializeNode->keys(),
1823+
partitionAndSerializeNode->numPartitions(),
1824+
partitionAndSerializeNode->serializedRowType(),
1825+
std::make_shared<velox::exec::trace::DummySourceNode>(
1826+
partitionAndSerializeNode->sources().front()->outputType()),
1827+
partitionAndSerializeNode->isReplicateNullsAndAny(),
1828+
partitionAndSerializeNode->partitionFunctionFactory(),
1829+
partitionAndSerializeNode->sortingOrders(),
1830+
partitionAndSerializeNode->sortingKeys());
1831+
}
1832+
return nullptr;
1833+
});
1834+
}
18081835
} // namespace facebook::presto

presto-native-execution/presto_cpp/main/PrestoServer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ class PrestoServer {
165165

166166
virtual void registerMemoryArbitrators();
167167

168+
/// Registers the trace node factories for Presto-specific plan nodes.
/// Invoked from run() together with the other registration calls.
virtual void registerTraceNodeFactories();
169+
168170
/// Invoked after creating global (singleton) config objects (SystemConfig and
169171
/// NodeConfig) and before loading their properties from the file.
170172
/// In the implementation any extra config properties can be registered.

presto-native-execution/presto_cpp/main/operators/tests/ShuffleTest.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,98 @@ class ShuffleTest : public exec::test::OperatorTestBase {
838838
fileSystem->remove(file);
839839
}
840840
}
841+
842+
void runPartitionAndSerializeSerdeTest(
843+
const RowVectorPtr& data,
844+
size_t numPartitions,
845+
const std::optional<std::vector<std::string>>& serdeLayout =
846+
std::nullopt) {
847+
TestShuffleWriter::reset();
848+
849+
auto shuffleInfo = testShuffleInfo(numPartitions, 1 << 20);
850+
TestShuffleWriter::createWriter(shuffleInfo, pool());
851+
852+
auto plan = exec::test::PlanBuilder()
853+
.values({data}, true)
854+
.addNode(addPartitionAndSerializeNode(
855+
numPartitions,
856+
false,
857+
serdeLayout.value_or(std::vector<std::string>{})))
858+
.localPartition(std::vector<std::string>{})
859+
.addNode(addShuffleWriteNode(
860+
numPartitions,
861+
std::string(TestShuffleFactory::kShuffleName),
862+
shuffleInfo))
863+
.planNode();
864+
865+
exec::CursorParameters params;
866+
params.planNode = plan;
867+
params.maxDrivers = 1;
868+
869+
auto [taskCursor, results] = exec::test::readCursor(params);
870+
ASSERT_EQ(results.size(), 0);
871+
872+
auto shuffleWriter = TestShuffleWriter::getInstance();
873+
ASSERT_NE(shuffleWriter, nullptr);
874+
875+
auto readyPartitions = shuffleWriter->readyPartitions();
876+
ASSERT_NE(readyPartitions, nullptr);
877+
878+
size_t totalRows = 0;
879+
for (size_t partitionIdx = 0; partitionIdx < numPartitions;
880+
++partitionIdx) {
881+
for (const auto& batch : (*readyPartitions)[partitionIdx]) {
882+
totalRows += batch->rows.size();
883+
}
884+
}
885+
ASSERT_EQ(totalRows, data->size());
886+
887+
auto expectedType = serdeLayout.has_value()
888+
? createSerdeLayoutType(asRowType(data->type()), serdeLayout.value())
889+
: asRowType(data->type());
890+
891+
std::vector<RowVectorPtr> deserializedData;
892+
for (size_t partitionIdx = 0; partitionIdx < numPartitions;
893+
++partitionIdx) {
894+
for (const auto& batch : (*readyPartitions)[partitionIdx]) {
895+
auto deserialized = std::dynamic_pointer_cast<RowVector>(
896+
row::CompactRow::deserialize(batch->rows, expectedType, pool()));
897+
if (deserialized != nullptr && deserialized->size() > 0) {
898+
deserializedData.push_back(deserialized);
899+
}
900+
}
901+
}
902+
903+
auto expected = serdeLayout.has_value()
904+
? reorderColumns(data, serdeLayout.value())
905+
: data;
906+
velox::exec::test::assertEqualResults({expected}, deserializedData);
907+
}
908+
909+
private:
910+
/// Builds the row type obtained by picking, from 'originalType', the columns
/// named in 'layout', in the order they are listed.
RowTypePtr createSerdeLayoutType(
    const RowTypePtr& originalType,
    const std::vector<std::string>& layout) {
  std::vector<std::string> fieldNames;
  std::vector<TypePtr> fieldTypes;
  for (size_t i = 0; i < layout.size(); ++i) {
    const auto& fieldName = layout[i];
    const auto childIndex = originalType->getChildIdx(fieldName);
    fieldNames.push_back(fieldName);
    fieldTypes.push_back(originalType->childAt(childIndex));
  }
  return ROW(std::move(fieldNames), std::move(fieldTypes));
}
922+
923+
/// Returns a row vector made of the columns of 'data' named in 'newLayout',
/// in the listed order. Columns are looked up by name in the row type of
/// 'data'.
RowVectorPtr reorderColumns(
    const RowVectorPtr& data,
    const std::vector<std::string>& newLayout) {
  const auto inputType = asRowType(data->type());
  std::vector<VectorPtr> reordered;
  for (size_t i = 0; i < newLayout.size(); ++i) {
    const auto sourceIndex = inputType->getChildIdx(newLayout[i]);
    reordered.push_back(data->childAt(sourceIndex));
  }
  return makeRowVector(reordered);
}
841933
};
842934

843935
TEST_F(ShuffleTest, operators) {
@@ -1583,6 +1675,23 @@ TEST_F(ShuffleTest, shuffleReadRuntimeStats) {
15831675
ASSERT_EQ(velox::RuntimeCounter::Unit::kNone, numBatchesStat.unit);
15841676
}
15851677
}
1678+
1679+
// Exercises PartitionAndSerialize end to end: once without a serde layout
// and once with a layout that selects and reorders a subset of the columns.
TEST_F(ShuffleTest, partitionAndSerializeEndToEnd) {
  {
    // Two columns, four partitions, no explicit serde layout.
    const auto twoColumnData = makeRowVector({
        makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6}),
        makeFlatVector<int64_t>({10, 20, 30, 40, 50, 60}),
    });
    runPartitionAndSerializeSerdeTest(twoColumnData, 4);
  }

  {
    // Three columns, two partitions; the layout picks "c2" then "c0"
    // (presumably the default names makeRowVector assigns to the third and
    // first columns).
    const auto threeColumnData = makeRowVector({
        makeFlatVector<int32_t>({1, 2, 3, 4}),
        makeFlatVector<int64_t>({10, 20, 30, 40}),
        makeFlatVector<std::string>({"a", "b", "c", "d"}),
    });
    const std::vector<std::string> serdeLayout{"c2", "c0"};
    runPartitionAndSerializeSerdeTest(threeColumnData, 2, serdeLayout);
  }
}
1694+
15861695
} // namespace facebook::presto::operators::test
15871696

15881697
int main(int argc, char** argv) {
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
/*
15+
* Copyright (c) Facebook, Inc. and its affiliates.
16+
*/
17+
#include "presto_cpp/main/tool/trace/PartitionAndSerializeReplayer.h"
18+
19+
#include "presto_cpp/main/operators/PartitionAndSerialize.h"
20+
#include "velox/tool/trace/TraceReplayTaskRunner.h"
21+
22+
using namespace facebook::velox;
23+
using namespace facebook::velox::exec;
24+
using namespace facebook::velox::exec::test;
25+
using namespace facebook::presto;
26+
27+
namespace facebook::velox::tool::trace {
28+
29+
// Constructs the replayer. It holds no state of its own: every argument is
// forwarded unchanged, in the same order, to OperatorReplayerBase.
PartitionAndSerializeReplayer::PartitionAndSerializeReplayer(
30+
const std::string& traceDir,
31+
const std::string& queryId,
32+
const std::string& taskId,
33+
const std::string& nodeId,
34+
const std::string& nodeName,
35+
const std::string& driverIds,
36+
uint64_t queryCapacity,
37+
folly::Executor* executor)
38+
: OperatorReplayerBase(
39+
traceDir,
40+
queryId,
41+
taskId,
42+
nodeId,
43+
nodeName,
44+
driverIds,
45+
queryCapacity,
46+
executor) {}
47+
48+
// Replays the traced operator: builds a task runner from createPlan() and
// createQueryCtx(), runs it with one driver per entry in driverIds_, prints
// the task statistics and returns the result produced by the runner.
RowVectorPtr PartitionAndSerializeReplayer::run(bool copyResults) {
  TraceReplayTaskRunner replayRunner(createPlan(), createQueryCtx());
  auto [replayedTask, replayOutput] =
      replayRunner.maxDrivers(driverIds_.size()).run(copyResults);
  printStats(replayedTask);
  return replayOutput;
}
55+
56+
// Creates the PartitionAndSerializeNode to replay. 'node' must be a
// PartitionAndSerializeNode (checked); the new node copies all of its
// properties but uses 'nodeId' as its id and 'source' as its input.
core::PlanNodePtr PartitionAndSerializeReplayer::createPlanNode(
    const core::PlanNode* node,
    const core::PlanNodeId& nodeId,
    const core::PlanNodePtr& source) const {
  using presto::operators::PartitionAndSerializeNode;

  const auto* partitionAndSerializeNode =
      dynamic_cast<const PartitionAndSerializeNode*>(node);
  VELOX_CHECK_NOT_NULL(partitionAndSerializeNode);

  const auto& traced = *partitionAndSerializeNode;
  return std::make_shared<PartitionAndSerializeNode>(
      nodeId,
      traced.keys(),
      traced.numPartitions(),
      traced.serializedRowType(),
      source,
      traced.isReplicateNullsAndAny(),
      traced.partitionFunctionFactory(),
      traced.sortingOrders(),
      traced.sortingKeys());
}
75+
76+
} // namespace facebook::velox::tool::trace
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
/*
15+
* Copyright (c) Facebook, Inc. and its affiliates.
16+
*/
17+
#pragma once
18+
19+
#include "velox/core/PlanNode.h"
20+
#include "velox/tool/trace/OperatorReplayerBase.h"
21+
22+
namespace facebook::velox::tool::trace {
23+
24+
/// Replayer for traced 'PartitionAndSerialize' operators. Derives from
/// OperatorReplayerBase and overrides run() and createPlanNode().
25+
class PartitionAndSerializeReplayer final : public OperatorReplayerBase {
26+
public:
27+
/// All arguments are forwarded unchanged to OperatorReplayerBase.
PartitionAndSerializeReplayer(
28+
const std::string& traceDir,
29+
const std::string& queryId,
30+
const std::string& taskId,
31+
const std::string& nodeId,
32+
const std::string& nodeName,
33+
const std::string& driverIds,
34+
uint64_t queryCapacity,
35+
folly::Executor* executor);
36+
37+
/// Runs the replay and returns the result; 'copyResults' is passed through
/// to the underlying task runner.
RowVectorPtr run(bool copyResults = true) override;
38+
39+
private:
40+
/// Builds the PartitionAndSerialize plan node to replay from the traced
/// 'node', using 'nodeId' as its id and 'source' as its input.
core::PlanNodePtr createPlanNode(
41+
const core::PlanNode* node,
42+
const core::PlanNodeId& nodeId,
43+
const core::PlanNodePtr& source) const override;
44+
};
45+
46+
} // namespace facebook::velox::tool::trace
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
15+
#include "velox/tool/trace/TraceReplayRunner.h"
16+
17+
#include <folly/init/Init.h>
18+
#include "presto_cpp/main/operators/PartitionAndSerialize.h"
19+
#include "presto_cpp/main/tool/trace/PartitionAndSerializeReplayer.h"
20+
#include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h"
21+
22+
using namespace facebook::velox;
23+
using namespace facebook::presto;
24+
25+
namespace {
26+
/// Custom trace replay runner for Presto operators.
27+
/// This runner extends the base Velox TraceReplayRunner to support:
28+
/// - Presto-specific operators (e.g., PartitionAndSerialize)
29+
/// - Presto plan node serialization/deserialization
30+
class PrestoTraceReplayRunner
31+
: public facebook::velox::tool::trace::TraceReplayRunner {
32+
public:
33+
void init() override {
34+
// Register Presto plan node SerDe for reading traced plan nodes
35+
registerPrestoPlanNodeSerDe();
36+
37+
// Register custom Presto operators to execute during replay
38+
exec::Operator::registerOperator(
39+
std::make_unique<operators::PartitionAndSerializeTranslator>());
40+
41+
// Call base init to complete initialization
42+
TraceReplayRunner::init();
43+
}
44+
45+
private:
46+
std::unique_ptr<tool::trace::OperatorReplayerBase> createReplayer()
47+
const override {
48+
const auto nodeName = taskTraceMetadataReader_->nodeName(FLAGS_node_id);
49+
const auto queryCapacityBytes = (1ULL * FLAGS_query_memory_capacity_mb)
50+
<< 20;
51+
52+
if (nodeName == "PartitionAndSerialize") {
53+
return std::make_unique<tool::trace::PartitionAndSerializeReplayer>(
54+
FLAGS_root_dir,
55+
FLAGS_query_id,
56+
FLAGS_task_id,
57+
FLAGS_node_id,
58+
nodeName,
59+
FLAGS_driver_ids,
60+
queryCapacityBytes,
61+
cpuExecutor_.get());
62+
}
63+
64+
// Fall back to base class for standard Velox operators
65+
return TraceReplayRunner::createReplayer();
66+
}
67+
};
68+
} // namespace
69+
70+
int main(int argc, char** argv) {
71+
folly::Init init(&argc, &argv);
72+
PrestoTraceReplayRunner runner;
73+
runner.init();
74+
runner.run();
75+
return 0;
76+
}

0 commit comments

Comments
 (0)