Skip to content

Commit cacd24a

Browse files
zhouyuanLakehouse Engine Bot
authored andcommitted
fix: Fix smj result mismatch issue in semi, anit and full outer join
Signed-off-by: Yuan <yuanzhou@apache.org> Alchemy-item: (ID = 1073) [OAP] [11771] Fix smj result mismatch issue commit 1/1 - 039d1ba
1 parent 247851a commit cacd24a

File tree

3 files changed

+152
-96
lines changed

3 files changed

+152
-96
lines changed

velox/exec/MergeJoin.cpp

Lines changed: 89 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ void MergeJoin::initialize() {
122122
isSemiFilterJoin(joinType_)) {
123123
joinTracker_ = JoinTracker(preferredOutputBatchRows_, pool());
124124
}
125-
} else if (joinNode_->isAntiJoin()) {
125+
} else if (joinNode_->isAntiJoin() || joinNode_->isFullJoin()) {
126126
// Anti join needs to track the left side rows that have no match on the
127127
// right.
128128
joinTracker_ = JoinTracker(preferredOutputBatchRows_, pool());
@@ -410,7 +410,8 @@ bool MergeJoin::tryAddOutputRow(
410410
const RowVectorPtr& leftBatch,
411411
vector_size_t leftRow,
412412
const RowVectorPtr& rightBatch,
413-
vector_size_t rightRow) {
413+
vector_size_t rightRow,
414+
bool isRightJoinForFullOuter) {
414415
if (outputSize_ == outputBatchSize_) {
415416
return false;
416417
}
@@ -444,12 +445,15 @@ bool MergeJoin::tryAddOutputRow(
444445
filterRightInputProjections_);
445446

446447
if (joinTracker_) {
447-
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
448+
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_) ||
449+
(isFullJoin(joinType_) && isRightJoinForFullOuter)) {
448450
// Record right-side row with a match on the left-side.
449-
joinTracker_->addMatch(rightBatch, rightRow, outputSize_);
451+
joinTracker_->addMatch(
452+
rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
450453
} else {
451454
// Record left-side row with a match on the right-side.
452-
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
455+
joinTracker_->addMatch(
456+
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
453457
}
454458
}
455459
}
@@ -459,7 +463,8 @@ bool MergeJoin::tryAddOutputRow(
459463
if (isAntiJoin(joinType_)) {
460464
VELOX_CHECK(joinTracker_.has_value());
461465
// Record left-side row with a match on the right-side.
462-
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
466+
joinTracker_->addMatch(
467+
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
463468
}
464469

465470
++outputSize_;
@@ -477,14 +482,15 @@ bool MergeJoin::prepareOutput(
477482
return true;
478483
}
479484

480-
if (isRightJoin(joinType_) && right != currentRight_) {
481-
return true;
482-
}
483-
484485
// If there is a new right, we need to flatten the dictionary.
485486
if (!isRightFlattened_ && right && currentRight_ != right) {
486487
flattenRightProjections();
487488
}
489+
490+
if (right != currentRight_) {
491+
return true;
492+
}
493+
488494
return false;
489495
}
490496

@@ -507,11 +513,15 @@ bool MergeJoin::prepareOutput(
507513
}
508514
} else {
509515
for (const auto& projection : leftProjections_) {
516+
auto column = left->childAt(projection.inputChannel);
517+
// Flatten the left column if the column already is DictionaryVector.
518+
if (column->wrappedVector()->encoding() ==
519+
VectorEncoding::Simple::DICTIONARY) {
520+
BaseVector::flattenVector(column);
521+
}
522+
column->clearContainingLazyAndWrapped();
510523
localColumns[projection.outputChannel] = BaseVector::wrapInDictionary(
511-
{},
512-
leftOutputIndices_,
513-
outputBatchSize_,
514-
left->childAt(projection.inputChannel));
524+
{}, leftOutputIndices_, outputBatchSize_, column);
515525
}
516526
}
517527
currentLeft_ = left;
@@ -527,11 +537,10 @@ bool MergeJoin::prepareOutput(
527537
isRightFlattened_ = true;
528538
} else {
529539
for (const auto& projection : rightProjections_) {
540+
auto column = right->childAt(projection.inputChannel);
541+
column->clearContainingLazyAndWrapped();
530542
localColumns[projection.outputChannel] = BaseVector::wrapInDictionary(
531-
{},
532-
rightOutputIndices_,
533-
outputBatchSize_,
534-
right->childAt(projection.inputChannel));
543+
{}, rightOutputIndices_, outputBatchSize_, column);
535544
}
536545
isRightFlattened_ = false;
537546
}
@@ -595,6 +604,39 @@ bool MergeJoin::prepareOutput(
595604
bool MergeJoin::addToOutput() {
596605
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
597606
return addToOutputForRightJoin();
607+
} else if (isFullJoin(joinType_) && filter_) {
608+
if (!leftForRightJoinMatch_) {
609+
leftForRightJoinMatch_ = leftMatch_;
610+
rightForRightJoinMatch_ = rightMatch_;
611+
}
612+
613+
if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
614+
auto left = addToOutputForLeftJoin();
615+
if (!leftMatch_) {
616+
leftJoinForFullFinished_ = true;
617+
}
618+
if (left) {
619+
if (!leftMatch_) {
620+
leftMatch_ = leftForRightJoinMatch_;
621+
rightMatch_ = rightForRightJoinMatch_;
622+
}
623+
624+
return true;
625+
}
626+
}
627+
628+
if (!leftMatch_ && !rightJoinForFullFinished_) {
629+
leftMatch_ = leftForRightJoinMatch_;
630+
rightMatch_ = rightForRightJoinMatch_;
631+
rightJoinForFullFinished_ = true;
632+
}
633+
634+
auto right = addToOutputForRightJoin();
635+
636+
leftForRightJoinMatch_ = leftMatch_;
637+
rightForRightJoinMatch_ = rightMatch_;
638+
639+
return right;
598640
} else {
599641
return addToOutputForLeftJoin();
600642
}
@@ -687,7 +729,13 @@ bool MergeJoin::addToOutputImpl() {
687729
} else {
688730
for (auto innerRow = innerStartRow; innerRow < innerEndRow;
689731
++innerRow) {
690-
if (!tryAddOutputRow(leftBatch, innerRow, rightBatch, outerRow)) {
732+
const auto isRightJoinForFullOuter = isFullJoin(joinType_);
733+
if (!tryAddOutputRow(
734+
leftBatch,
735+
innerRow,
736+
rightBatch,
737+
outerRow,
738+
isRightJoinForFullOuter)) {
691739
outerMatch->setCursor(outerBatchIndex, outerRow);
692740
innerMatch->setCursor(innerBatchIndex, innerRow);
693741
return true;
@@ -959,7 +1007,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
9591007
isFullJoin(joinType_)) {
9601008
// If output_ is currently wrapping a different buffer, return it
9611009
// first.
962-
if (prepareOutput(input_, nullptr)) {
1010+
if (prepareOutput(input_, rightInput_)) {
9631011
output_->resize(outputSize_);
9641012
return std::move(output_);
9651013
}
@@ -984,7 +1032,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
9841032
if (isRightJoin(joinType_) || isFullJoin(joinType_)) {
9851033
// If output_ is currently wrapping a different buffer, return it
9861034
// first.
987-
if (prepareOutput(nullptr, rightInput_)) {
1035+
if (prepareOutput(input_, rightInput_)) {
9881036
output_->resize(outputSize_);
9891037
return std::move(output_);
9901038
}
@@ -1034,6 +1082,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
10341082
matchedLeftRows_ += leftEndRow - leftMatch_->startRowIndex;
10351083
matchedRightRows_ += rightEndRow - rightMatch_->startRowIndex;
10361084

1085+
leftJoinForFullFinished_ = false;
1086+
rightJoinForFullFinished_ = false;
10371087
if (!leftMatch_->complete || !rightMatch_->complete) {
10381088
if (!leftMatch_->complete) {
10391089
// Need to continue looking for the end of match.
@@ -1323,8 +1373,6 @@ void MergeJoin::updateOutputBatchSize(const RowVectorPtr& output) {
13231373
RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13241374
const auto numRows = output->size();
13251375

1326-
RowVectorPtr fullOuterOutput = nullptr;
1327-
13281376
BufferPtr indices = allocateIndices(numRows, pool());
13291377
auto* rawIndices = indices->asMutable<vector_size_t>();
13301378
vector_size_t numPassed = 0;
@@ -1341,84 +1389,41 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13411389

13421390
// If all matches for a given left-side row fail the filter, add a row to
13431391
// the output with nulls for the right-side columns.
1344-
const auto onMiss = [&](auto row) {
1392+
const auto onMiss = [&](auto row, bool isRightJoinForFullOuter) {
13451393
if (isSemiFilterJoin(joinType_)) {
13461394
return;
13471395
}
13481396
rawIndices[numPassed++] = row;
13491397

1350-
if (isFullJoin(joinType_)) {
1351-
// For filtered rows, it is necessary to insert additional data
1352-
// to ensure the result set is complete. Specifically, we
1353-
// need to generate two records: one record containing the
1354-
// columns from the left table along with nulls for the
1355-
// right table, and another record containing the columns
1356-
// from the right table along with nulls for the left table.
1357-
// For instance, the current output is filtered based on the condition
1358-
// t > 1.
1359-
1360-
// 1, 1
1361-
// 2, 2
1362-
// 3, 3
1363-
1364-
// In this scenario, we need to additionally insert a record 1, 1.
1365-
// Subsequently, we will set the values of the columns on the left to
1366-
// null and the values of the columns on the right to null as well. By
1367-
// doing so, we will obtain the final result set.
1368-
1369-
// 1, null
1370-
// null, 1
1371-
// 2, 2
1372-
// 3, 3
1373-
fullOuterOutput = BaseVector::create<RowVector>(
1374-
output->type(), output->size() + 1, pool());
1375-
1376-
for (auto i = 0; i < row + 1; ++i) {
1377-
for (auto j = 0; j < output->type()->size(); ++j) {
1378-
fullOuterOutput->childAt(j)->copy(
1379-
output->childAt(j).get(), i, i, 1);
1398+
if (!isRightJoin(joinType_)) {
1399+
if (isFullJoin(joinType_) && isRightJoinForFullOuter) {
1400+
for (auto& projection : leftProjections_) {
1401+
auto target = output->childAt(projection.outputChannel);
1402+
target->setNull(row, true);
13801403
}
1381-
}
1382-
1383-
for (auto j = 0; j < output->type()->size(); ++j) {
1384-
fullOuterOutput->childAt(j)->copy(
1385-
output->childAt(j).get(), row + 1, row, 1);
1386-
}
1387-
1388-
for (auto i = row + 1; i < output->size(); ++i) {
1389-
for (auto j = 0; j < output->type()->size(); ++j) {
1390-
fullOuterOutput->childAt(j)->copy(
1391-
output->childAt(j).get(), i + 1, i, 1);
1404+
} else {
1405+
for (auto& projection : rightProjections_) {
1406+
auto target = output->childAt(projection.outputChannel);
1407+
target->setNull(row, true);
13921408
}
13931409
}
1394-
1395-
for (auto& projection : leftProjections_) {
1396-
auto& target = fullOuterOutput->childAt(projection.outputChannel);
1397-
target->setNull(row, true);
1398-
}
1399-
1400-
for (auto& projection : rightProjections_) {
1401-
auto& target = fullOuterOutput->childAt(projection.outputChannel);
1402-
target->setNull(row + 1, true);
1403-
}
1404-
} else if (!isRightJoin(joinType_)) {
1405-
for (auto& projection : rightProjections_) {
1406-
auto& target = output->childAt(projection.outputChannel);
1407-
target->setNull(row, true);
1408-
}
14091410
} else {
14101411
for (auto& projection : leftProjections_) {
1411-
auto& target = output->childAt(projection.outputChannel);
1412+
auto target = output->childAt(projection.outputChannel);
14121413
target->setNull(row, true);
14131414
}
14141415
}
14151416
};
14161417

14171418
auto onMatch = [&](auto row, bool firstMatch) {
1418-
const bool isNonSemiAntiJoin =
1419-
!isSemiFilterJoin(joinType_) && !isAntiJoin(joinType_);
1419+
const bool isFullLeftJoin =
1420+
isFullJoin(joinType_) && !joinTracker_->isRightJoinForFullOuter(row);
1421+
1422+
const bool isNonSemiAntiFullJoin = !isSemiFilterJoin(joinType_) &&
1423+
!isAntiJoin(joinType_) && !isFullJoin(joinType_);
14201424

1421-
if ((isSemiFilterJoin(joinType_) && firstMatch) || isNonSemiAntiJoin) {
1425+
if ((isSemiFilterJoin(joinType_) && firstMatch) ||
1426+
isNonSemiAntiFullJoin || isFullLeftJoin) {
14221427
rawIndices[numPassed++] = row;
14231428
}
14241429
};
@@ -1479,17 +1484,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
14791484

14801485
if (numPassed == numRows) {
14811486
// All rows passed.
1482-
if (fullOuterOutput) {
1483-
return fullOuterOutput;
1484-
}
14851487
return output;
14861488
}
14871489

14881490
// Some, but not all rows passed.
1489-
if (fullOuterOutput) {
1490-
return wrap(numPassed, indices, fullOuterOutput);
1491-
}
1492-
14931491
return wrap(numPassed, indices, output);
14941492
}
14951493

0 commit comments

Comments
 (0)