|
26 | 26 | #include "velox/exec/tests/utils/PlanBuilder.h" |
27 | 27 | #include "velox/exec/tests/utils/TempDirectoryPath.h" |
28 | 28 | #include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" |
| 29 | +#include "velox/functions/sparksql/aggregates/Register.h" |
29 | 30 |
|
30 | 31 | using namespace facebook::velox::exec::test; |
31 | 32 |
|
@@ -1036,5 +1037,67 @@ DEBUG_ONLY_TEST_F(WindowTest, releaseWindowBuildInTime) { |
1036 | 1037 | "ORDER BY d"); |
1037 | 1038 | } |
1038 | 1039 |
|
| 1040 | +class SparkWindowTest : public WindowTest { |
| 1041 | + public: |
| 1042 | + void SetUp() override { |
| 1043 | + WindowTest::SetUp(); |
| 1044 | + functions::aggregate::sparksql::registerAggregateFunctions(""); |
| 1045 | + } |
| 1046 | +}; |
| 1047 | + |
| 1048 | +DEBUG_ONLY_TEST_F(SparkWindowTest, destroyPreviousAccumulator) { |
| 1049 | + const auto size = 100; |
| 1050 | + auto input = makeRowVector( |
| 1051 | + {"d", "p0", "s"}, |
| 1052 | + { |
| 1053 | + // Payload Data. |
| 1054 | + makeFlatVector<std::string>(size, [](auto row){ return std::string(1024, 'a'); }), |
| 1055 | + // Partition key. |
| 1056 | + makeFlatVector<int64_t>(size, [](auto row) { return row % 11; }), |
| 1057 | + // Sorting key. |
| 1058 | + makeFlatVector<int32_t>(size, [](auto row) { return row; }), |
| 1059 | + }); |
| 1060 | + |
| 1061 | + createDuckDbTable({input}); |
| 1062 | + |
| 1063 | + core::PlanNodeId windowId; |
| 1064 | + core::PlanNodeId orderById; |
| 1065 | + auto plan = PlanBuilder() |
| 1066 | + .values(split(input, 10)) |
| 1067 | + .window({"last(d) over (partition by p0 order by s)"}) |
| 1068 | + .capturePlanNodeId(windowId) |
| 1069 | + .orderBy({"d"}, false) |
| 1070 | + .capturePlanNodeId(orderById) |
| 1071 | + .planNode(); |
| 1072 | + |
| 1073 | + const HashStringAllocator* stringAllocator = nullptr; |
| 1074 | + uint64_t usedBytes = 0; |
| 1075 | + SCOPED_TESTVALUE_SET( |
| 1076 | + "facebook::velox::exec::Driver::runInternal::getOutput", |
| 1077 | + std::function<void(Operator*)>([&](exec::Operator* op) { |
| 1078 | + auto* windowOp = dynamic_cast<exec::Window*>(op); |
| 1079 | + if (windowOp == nullptr) { |
| 1080 | + return; |
| 1081 | + } |
| 1082 | + if (stringAllocator == nullptr) { |
| 1083 | + stringAllocator = windowOp->testingGetHashStringAllocator(); |
| 1084 | + } else if (usedBytes == 0) { |
| 1085 | + // Record how many bytes have been used. |
| 1086 | + usedBytes = stringAllocator->currentBytes(); |
| 1087 | + } else { |
| 1088 | + // Because we will destroy previous created accumulator and every string in input |
| 1089 | + // is of the same length, so here we check if the `usedBytes` is not changed. |
| 1090 | + ASSERT_EQ(usedBytes, stringAllocator->currentBytes()); |
| 1091 | + } |
| 1092 | + })); |
| 1093 | + |
| 1094 | + auto task = |
| 1095 | + AssertQueryBuilder(plan, duckDbQueryRunner_) |
| 1096 | + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") |
| 1097 | + .assertResults( |
| 1098 | + "SELECT *, last(d) over (partition by p0 order by s) " |
| 1099 | + "FROM tmp "); |
| 1100 | +} |
| 1101 | + |
1039 | 1102 | } // namespace |
1040 | 1103 | } // namespace facebook::velox::exec |
0 commit comments