@@ -122,7 +122,7 @@ void MergeJoin::initialize() {
122122 isSemiFilterJoin (joinType_)) {
123123 joinTracker_ = JoinTracker (preferredOutputBatchRows_, pool ());
124124 }
125- } else if (joinNode_->isAntiJoin ()) {
125+ } else if (joinNode_->isAntiJoin () || joinNode_-> isFullJoin () ) {
126126 // Anti join needs to track the left side rows that have no match on the
127127 // right.
128128 joinTracker_ = JoinTracker (preferredOutputBatchRows_, pool ());
@@ -410,7 +410,8 @@ bool MergeJoin::tryAddOutputRow(
410410 const RowVectorPtr& leftBatch,
411411 vector_size_t leftRow,
412412 const RowVectorPtr& rightBatch,
413- vector_size_t rightRow) {
413+ vector_size_t rightRow,
414+ bool isRightJoinForFullOuter) {
414415 if (outputSize_ == outputBatchSize_) {
415416 return false ;
416417 }
@@ -444,12 +445,15 @@ bool MergeJoin::tryAddOutputRow(
444445 filterRightInputProjections_);
445446
446447 if (joinTracker_) {
447- if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
448+ if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_) ||
449+ (isFullJoin (joinType_) && isRightJoinForFullOuter)) {
448450 // Record right-side row with a match on the left-side.
449- joinTracker_->addMatch (rightBatch, rightRow, outputSize_);
451+ joinTracker_->addMatch (
452+ rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
450453 } else {
451454 // Record left-side row with a match on the right-side.
452- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
455+ joinTracker_->addMatch (
456+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
453457 }
454458 }
455459 }
@@ -459,7 +463,8 @@ bool MergeJoin::tryAddOutputRow(
459463 if (isAntiJoin (joinType_)) {
460464 VELOX_CHECK (joinTracker_.has_value ());
461465 // Record left-side row with a match on the right-side.
462- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
466+ joinTracker_->addMatch (
467+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
463468 }
464469
465470 ++outputSize_;
@@ -477,14 +482,15 @@ bool MergeJoin::prepareOutput(
477482 return true ;
478483 }
479484
480- if (isRightJoin (joinType_) && right != currentRight_) {
481- return true ;
482- }
483-
484485 // If there is a new right, we need to flatten the dictionary.
485486 if (!isRightFlattened_ && right && currentRight_ != right) {
486487 flattenRightProjections ();
487488 }
489+
490+ if (right != currentRight_) {
491+ return true ;
492+ }
493+
488494 return false ;
489495 }
490496
@@ -507,11 +513,15 @@ bool MergeJoin::prepareOutput(
507513 }
508514 } else {
509515 for (const auto & projection : leftProjections_) {
516+ auto column = left->childAt (projection.inputChannel );
517+ // Flatten the left column if the column already is DictionaryVector.
518+ if (column->wrappedVector ()->encoding () ==
519+ VectorEncoding::Simple::DICTIONARY) {
520+ BaseVector::flattenVector (column);
521+ }
522+ column->clearContainingLazyAndWrapped ();
510523 localColumns[projection.outputChannel ] = BaseVector::wrapInDictionary (
511- {},
512- leftOutputIndices_,
513- outputBatchSize_,
514- left->childAt (projection.inputChannel ));
524+ {}, leftOutputIndices_, outputBatchSize_, column);
515525 }
516526 }
517527 currentLeft_ = left;
@@ -527,11 +537,10 @@ bool MergeJoin::prepareOutput(
527537 isRightFlattened_ = true ;
528538 } else {
529539 for (const auto & projection : rightProjections_) {
540+ auto column = right->childAt (projection.inputChannel );
541+ column->clearContainingLazyAndWrapped ();
530542 localColumns[projection.outputChannel ] = BaseVector::wrapInDictionary (
531- {},
532- rightOutputIndices_,
533- outputBatchSize_,
534- right->childAt (projection.inputChannel ));
543+ {}, rightOutputIndices_, outputBatchSize_, column);
535544 }
536545 isRightFlattened_ = false ;
537546 }
@@ -595,6 +604,39 @@ bool MergeJoin::prepareOutput(
595604bool MergeJoin::addToOutput () {
596605 if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
597606 return addToOutputForRightJoin ();
607+ } else if (isFullJoin (joinType_) && filter_) {
608+ if (!leftForRightJoinMatch_) {
609+ leftForRightJoinMatch_ = leftMatch_;
610+ rightForRightJoinMatch_ = rightMatch_;
611+ }
612+
613+ if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
614+ auto left = addToOutputForLeftJoin ();
615+ if (!leftMatch_) {
616+ leftJoinForFullFinished_ = true ;
617+ }
618+ if (left) {
619+ if (!leftMatch_) {
620+ leftMatch_ = leftForRightJoinMatch_;
621+ rightMatch_ = rightForRightJoinMatch_;
622+ }
623+
624+ return true ;
625+ }
626+ }
627+
628+ if (!leftMatch_ && !rightJoinForFullFinished_) {
629+ leftMatch_ = leftForRightJoinMatch_;
630+ rightMatch_ = rightForRightJoinMatch_;
631+ rightJoinForFullFinished_ = true ;
632+ }
633+
634+ auto right = addToOutputForRightJoin ();
635+
636+ leftForRightJoinMatch_ = leftMatch_;
637+ rightForRightJoinMatch_ = rightMatch_;
638+
639+ return right;
598640 } else {
599641 return addToOutputForLeftJoin ();
600642 }
@@ -687,7 +729,13 @@ bool MergeJoin::addToOutputImpl() {
687729 } else {
688730 for (auto innerRow = innerStartRow; innerRow < innerEndRow;
689731 ++innerRow) {
690- if (!tryAddOutputRow (leftBatch, innerRow, rightBatch, outerRow)) {
732+ const auto isRightJoinForFullOuter = isFullJoin (joinType_);
733+ if (!tryAddOutputRow (
734+ leftBatch,
735+ innerRow,
736+ rightBatch,
737+ outerRow,
738+ isRightJoinForFullOuter)) {
691739 outerMatch->setCursor (outerBatchIndex, outerRow);
692740 innerMatch->setCursor (innerBatchIndex, innerRow);
693741 return true ;
@@ -959,7 +1007,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
9591007 isFullJoin (joinType_)) {
9601008 // If output_ is currently wrapping a different buffer, return it
9611009 // first.
962- if (prepareOutput (input_, nullptr )) {
1010+ if (prepareOutput (input_, rightInput_ )) {
9631011 output_->resize (outputSize_);
9641012 return std::move (output_);
9651013 }
@@ -984,7 +1032,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
9841032 if (isRightJoin (joinType_) || isFullJoin (joinType_)) {
9851033 // If output_ is currently wrapping a different buffer, return it
9861034 // first.
987- if (prepareOutput (nullptr , rightInput_)) {
1035+ if (prepareOutput (input_ , rightInput_)) {
9881036 output_->resize (outputSize_);
9891037 return std::move (output_);
9901038 }
@@ -1034,6 +1082,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
10341082 matchedLeftRows_ += leftEndRow - leftMatch_->startRowIndex ;
10351083 matchedRightRows_ += rightEndRow - rightMatch_->startRowIndex ;
10361084
1085+ leftJoinForFullFinished_ = false ;
1086+ rightJoinForFullFinished_ = false ;
10371087 if (!leftMatch_->complete || !rightMatch_->complete ) {
10381088 if (!leftMatch_->complete ) {
10391089 // Need to continue looking for the end of match.
@@ -1323,8 +1373,6 @@ void MergeJoin::updateOutputBatchSize(const RowVectorPtr& output) {
13231373RowVectorPtr MergeJoin::applyFilter (const RowVectorPtr& output) {
13241374 const auto numRows = output->size ();
13251375
1326- RowVectorPtr fullOuterOutput = nullptr ;
1327-
13281376 BufferPtr indices = allocateIndices (numRows, pool ());
13291377 auto * rawIndices = indices->asMutable <vector_size_t >();
13301378 vector_size_t numPassed = 0 ;
@@ -1341,84 +1389,41 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13411389
13421390 // If all matches for a given left-side row fail the filter, add a row to
13431391 // the output with nulls for the right-side columns.
1344- const auto onMiss = [&](auto row) {
1392+ const auto onMiss = [&](auto row, bool isRightJoinForFullOuter ) {
13451393 if (isSemiFilterJoin (joinType_)) {
13461394 return ;
13471395 }
13481396 rawIndices[numPassed++] = row;
13491397
1350- if (isFullJoin (joinType_)) {
1351- // For filtered rows, it is necessary to insert additional data
1352- // to ensure the result set is complete. Specifically, we
1353- // need to generate two records: one record containing the
1354- // columns from the left table along with nulls for the
1355- // right table, and another record containing the columns
1356- // from the right table along with nulls for the left table.
1357- // For instance, the current output is filtered based on the condition
1358- // t > 1.
1359-
1360- // 1, 1
1361- // 2, 2
1362- // 3, 3
1363-
1364- // In this scenario, we need to additionally insert a record 1, 1.
1365- // Subsequently, we will set the values of the columns on the left to
1366- // null and the values of the columns on the right to null as well. By
1367- // doing so, we will obtain the final result set.
1368-
1369- // 1, null
1370- // null, 1
1371- // 2, 2
1372- // 3, 3
1373- fullOuterOutput = BaseVector::create<RowVector>(
1374- output->type (), output->size () + 1 , pool ());
1375-
1376- for (auto i = 0 ; i < row + 1 ; ++i) {
1377- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1378- fullOuterOutput->childAt (j)->copy (
1379- output->childAt (j).get (), i, i, 1 );
1398+ if (!isRightJoin (joinType_)) {
1399+ if (isFullJoin (joinType_) && isRightJoinForFullOuter) {
1400+ for (auto & projection : leftProjections_) {
1401+ auto target = output->childAt (projection.outputChannel );
1402+ target->setNull (row, true );
13801403 }
1381- }
1382-
1383- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1384- fullOuterOutput->childAt (j)->copy (
1385- output->childAt (j).get (), row + 1 , row, 1 );
1386- }
1387-
1388- for (auto i = row + 1 ; i < output->size (); ++i) {
1389- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1390- fullOuterOutput->childAt (j)->copy (
1391- output->childAt (j).get (), i + 1 , i, 1 );
1404+ } else {
1405+ for (auto & projection : rightProjections_) {
1406+ auto target = output->childAt (projection.outputChannel );
1407+ target->setNull (row, true );
13921408 }
13931409 }
1394-
1395- for (auto & projection : leftProjections_) {
1396- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1397- target->setNull (row, true );
1398- }
1399-
1400- for (auto & projection : rightProjections_) {
1401- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1402- target->setNull (row + 1 , true );
1403- }
1404- } else if (!isRightJoin (joinType_)) {
1405- for (auto & projection : rightProjections_) {
1406- auto & target = output->childAt (projection.outputChannel );
1407- target->setNull (row, true );
1408- }
14091410 } else {
14101411 for (auto & projection : leftProjections_) {
1411- auto & target = output->childAt (projection.outputChannel );
1412+ auto target = output->childAt (projection.outputChannel );
14121413 target->setNull (row, true );
14131414 }
14141415 }
14151416 };
14161417
14171418 auto onMatch = [&](auto row, bool firstMatch) {
1418- const bool isNonSemiAntiJoin =
1419- !isSemiFilterJoin (joinType_) && !isAntiJoin (joinType_);
1419+ const bool isFullLeftJoin =
1420+ isFullJoin (joinType_) && !joinTracker_->isRightJoinForFullOuter (row);
1421+
1422+ const bool isNonSemiAntiFullJoin = !isSemiFilterJoin (joinType_) &&
1423+ !isAntiJoin (joinType_) && !isFullJoin (joinType_);
14201424
1421- if ((isSemiFilterJoin (joinType_) && firstMatch) || isNonSemiAntiJoin) {
1425+ if ((isSemiFilterJoin (joinType_) && firstMatch) ||
1426+ isNonSemiAntiFullJoin || isFullLeftJoin) {
14221427 rawIndices[numPassed++] = row;
14231428 }
14241429 };
@@ -1479,17 +1484,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
14791484
14801485 if (numPassed == numRows) {
14811486 // All rows passed.
1482- if (fullOuterOutput) {
1483- return fullOuterOutput;
1484- }
14851487 return output;
14861488 }
14871489
14881490 // Some, but not all rows passed.
1489- if (fullOuterOutput) {
1490- return wrap (numPassed, indices, fullOuterOutput);
1491- }
1492-
14931491 return wrap (numPassed, indices, output);
14941492}
14951493
0 commit comments