Skip to content

Commit e8cbd29

Browse files
authored
Unified joins behavior (#1405)
1 parent 5705788 commit e8cbd29

File tree

10 files changed

+369
-31
lines changed

10 files changed

+369
-31
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
schema
22
|-- id: integer
3-
|-- name: string
3+
|-- name: ?string
44
|-- active: boolean

phpunit.xml.dist

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
colors="true"
77
displayDetailsOnTestsThatTriggerWarnings="true"
88
displayDetailsOnTestsThatTriggerErrors="true"
9+
cacheDirectory="var/phpunit/cache"
910
>
1011
<php>
1112
<env name="AZURITE_HOST" value="localhost"/>

src/core/etl/src/Flow/ETL/Join/Expression.php

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
use Flow\ETL\Exception\RuntimeException;
88
use Flow\ETL\Join\Comparison\{All, Equal};
9-
use Flow\ETL\Row;
109
use Flow\ETL\Row\Reference;
10+
use Flow\ETL\{Row};
1111

1212
final readonly class Expression
1313
{
@@ -20,7 +20,7 @@ public function __construct(
2020
/**
2121
* @param array<Comparison>|array<string, string>|Comparison $comparison
2222
*/
23-
public static function on(array|Comparison $comparison, string $joinPrefix = 'joined_') : self
23+
public static function on(array|Comparison $comparison, string $joinPrefix = '') : self
2424
{
2525
if (\is_array($comparison)) {
2626
/** @var array<Comparison> $comparisons */
@@ -50,6 +50,62 @@ public static function on(array|Comparison $comparison, string $joinPrefix = 'jo
5050
return new self($comparison, $joinPrefix);
5151
}
5252

53+
public function dropDuplicateLeftEntries(Row $left) : Row
54+
{
55+
if ($this->joinPrefix === '') {
56+
$leftEntries = [];
57+
$rightEntries = [];
58+
59+
foreach ($this->left() as $leftReference) {
60+
$leftEntries[] = $leftReference->name();
61+
}
62+
63+
foreach ($this->right() as $rightReference) {
64+
$rightEntries[] = $rightReference->name();
65+
}
66+
67+
$dropLeft = [];
68+
69+
foreach ($leftEntries as $leftEntry) {
70+
if (\in_array($leftEntry, $rightEntries, true)) {
71+
$dropLeft[] = $leftEntry;
72+
}
73+
}
74+
75+
return $left->remove(...$dropLeft);
76+
}
77+
78+
return $left;
79+
}
80+
81+
public function dropDuplicateRightEntries(Row $right) : Row
82+
{
83+
if ($this->joinPrefix === '') {
84+
$leftEntries = [];
85+
$rightEntries = [];
86+
87+
foreach ($this->left() as $leftReference) {
88+
$leftEntries[] = $leftReference->name();
89+
}
90+
91+
foreach ($this->right() as $rightReference) {
92+
$rightEntries[] = $rightReference->name();
93+
}
94+
95+
$dropRight = [];
96+
97+
foreach ($rightEntries as $rightEntry) {
98+
if (\in_array($rightEntry, $leftEntries, true)) {
99+
$dropRight[] = $rightEntry;
100+
}
101+
}
102+
103+
return $right->remove(...$dropRight);
104+
}
105+
106+
return $right;
107+
}
108+
53109
/**
54110
* @return array<Reference>
55111
*/

src/core/etl/src/Flow/ETL/Pipeline/HashJoinPipeline.php

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,13 @@ public function source() : Extractor
122122
private function createRows(Row $leftRow, Row $rightRow) : Rows
123123
{
124124
try {
125-
return rows($leftRow->merge($rightRow, $this->expression->prefix()));
125+
return match ($this->join) {
126+
Join::inner => rows($leftRow->merge($rightRow, $this->expression->prefix())),
127+
Join::left => rows($leftRow->merge($this->expression->dropDuplicateRightEntries($rightRow), $this->expression->prefix())),
128+
Join::right => rows($this->expression->dropDuplicateLeftEntries($leftRow)->merge($rightRow, $this->expression->prefix())),
129+
Join::left_anti => rows(),
130+
};
131+
126132
} catch (DuplicatedEntriesException $e) {
127133
throw new JoinException($e->getMessage() . ' try to use a different join prefix than: "' . $this->expression->prefix() . '"', $e->getCode(), $e);
128134
}

src/core/etl/src/Flow/ETL/Rows.php

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ public function joinInner(self $right, Expression $expression) : self
328328
foreach ($right as $rightRow) {
329329
if ($expression->meet($leftRow, $rightRow)) {
330330
try {
331-
$joinedRow = $leftRow->merge($rightRow, $expression->prefix());
331+
$joinedRow = $leftRow->merge($expression->dropDuplicateRightEntries($rightRow), $expression->prefix());
332332
} catch (DuplicatedEntriesException $e) {
333333
throw new DuplicatedEntriesException($e->getMessage() . ' try to use a different join prefix than: "' . $expression->prefix() . '"');
334334
}
@@ -364,7 +364,7 @@ public function joinLeft(self $right, Expression $expression) : self
364364
foreach ($right as $rightRow) {
365365
if ($expression->meet($leftRow, $rightRow)) {
366366
try {
367-
$joinedRow = $leftRow->merge($rightRow, $expression->prefix());
367+
$joinedRow = $leftRow->merge($expression->dropDuplicateRightEntries($rightRow), $expression->prefix());
368368
} catch (DuplicatedEntriesException $e) {
369369
throw new DuplicatedEntriesException($e->getMessage() . ' try to use a different join prefix than: "' . $expression->prefix() . '"');
370370
}
@@ -382,7 +382,7 @@ public function joinLeft(self $right, Expression $expression) : self
382382
$entries[] = $entryFactory->create($definition->entry()->name(), null, $definition->makeNullable());
383383
}
384384

385-
$joinedRow = $leftRow->merge(row(...$entries), $expression->prefix());
385+
$joinedRow = $leftRow->merge($expression->dropDuplicateRightEntries(row(...$entries)), $expression->prefix());
386386
}
387387

388388
$joined[] = $joinedRow;
@@ -438,7 +438,7 @@ public function joinRight(self $right, Expression $expression) : self
438438
foreach ($this->rows as $leftRow) {
439439
if ($expression->meet($leftRow, $rightRow)) {
440440
try {
441-
$joinedRow = $leftRow->merge($rightRow, $expression->prefix());
441+
$joinedRow = $expression->dropDuplicateLeftEntries($leftRow)->merge($rightRow, $expression->prefix());
442442
} catch (DuplicatedEntriesException $e) {
443443
throw new DuplicatedEntriesException($e->getMessage() . ' try to use a different join prefix than: "' . $expression->prefix() . '"');
444444
}
@@ -456,7 +456,7 @@ public function joinRight(self $right, Expression $expression) : self
456456
$entries[] = $entryFactory->create($definition->entry()->name(), null, $definition->makeNullable());
457457
}
458458

459-
$joined[] = row(...$entries)->merge($rightRow, $expression->prefix());
459+
$joined[] = $expression->dropDuplicateLeftEntries(row(...$entries))->merge($rightRow, $expression->prefix());
460460
}
461461
}
462462

src/core/etl/src/Flow/ETL/Transformer/JoinEachRowsTransformer.php

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
{
1212
private function __construct(
1313
private DataFrameFactory $factory,
14-
private Expression $condition,
14+
private Expression $expression,
1515
private Join $type,
1616
) {
1717
}
@@ -43,11 +43,13 @@ public static function right(DataFrameFactory $right, Expression $condition) : s
4343
*/
4444
public function transform(Rows $rows, FlowContext $context) : Rows
4545
{
46+
$rightRows = $this->factory->from($rows)->fetch();
47+
4648
return match ($this->type) {
47-
Join::left => $rows->joinLeft($this->factory->from($rows)->fetch(), $this->condition),
48-
Join::left_anti => $rows->joinLeftAnti($this->factory->from($rows)->fetch(), $this->condition),
49-
Join::right => $rows->joinRight($this->factory->from($rows)->fetch(), $this->condition),
50-
default => $rows->joinInner($this->factory->from($rows)->fetch(), $this->condition),
49+
Join::left => $rows->joinLeft($rightRows, $this->expression),
50+
Join::left_anti => $rows->joinLeftAnti($rightRows, $this->expression),
51+
Join::right => $rows->joinRight($rightRows, $this->expression),
52+
default => $rows->joinInner($rightRows, $this->expression),
5153
};
5254
}
5355
}

src/core/etl/tests/Flow/ETL/Tests/Integration/DataFrame/JoinEachTest.php

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
namespace Flow\ETL\Tests\Integration\DataFrame;
66

77
use function Flow\ETL\DSL\data_frame;
8-
use function Flow\ETL\DSL\{df, from_rows, int_entry, row, rows, str_entry};
8+
use function Flow\ETL\DSL\{df, from_rows, int_entry, join_on, row, rows, str_entry};
99
use Flow\ETL\Join\Expression;
1010
use Flow\ETL\{DataFrame, DataFrameFactory, Loader, Rows, Tests\FlowTestCase};
1111

@@ -43,7 +43,7 @@ public function from(Rows $rows) : DataFrame
4343
);
4444
}
4545
},
46-
Expression::on(['country' => 'code']),
46+
Expression::on(['country' => 'code'], 'joined_'),
4747
)
4848
->write($loader)
4949
->fetch();
@@ -62,4 +62,56 @@ public function from(Rows $rows) : DataFrame
6262
$rows->toArray()
6363
);
6464
}
65+
66+
public function test_join_each_without_prefix() : void
67+
{
68+
$loader = $this->createMock(Loader::class);
69+
$loader->expects(self::exactly(2))
70+
->method('load');
71+
72+
$rows = df()
73+
->read(from_rows(
74+
rows(
75+
row(int_entry('id', 1), str_entry('country_code', 'PL')),
76+
row(int_entry('id', 2), str_entry('country_code', 'PL')),
77+
row(int_entry('id', 3), str_entry('country_code', 'PL')),
78+
row(int_entry('id', 4), str_entry('country_code', 'PL')),
79+
row(int_entry('id', 5), str_entry('country_code', 'US')),
80+
row(int_entry('id', 6), str_entry('country_code', 'US')),
81+
row(int_entry('id', 7), str_entry('country_code', 'US')),
82+
row(int_entry('id', 9), str_entry('country_code', 'US')),
83+
)
84+
))
85+
->batchSize(4)
86+
->joinEach(
87+
new class implements DataFrameFactory {
88+
public function from(Rows $rows) : DataFrame
89+
{
90+
return data_frame()->process(
91+
rows(
92+
row(str_entry('country_code', 'PL'), str_entry('name', 'Poland')),
93+
row(str_entry('country_code', 'US'), str_entry('name', 'United States')),
94+
)
95+
);
96+
}
97+
},
98+
join_on(['country_code' => 'country_code']),
99+
)
100+
->write($loader)
101+
->fetch();
102+
103+
self::assertEquals(
104+
[
105+
['id' => 1, 'country_code' => 'PL', 'name' => 'Poland'],
106+
['id' => 2, 'country_code' => 'PL', 'name' => 'Poland'],
107+
['id' => 3, 'country_code' => 'PL', 'name' => 'Poland'],
108+
['id' => 4, 'country_code' => 'PL', 'name' => 'Poland'],
109+
['id' => 5, 'country_code' => 'US', 'name' => 'United States'],
110+
['id' => 6, 'country_code' => 'US', 'name' => 'United States'],
111+
['id' => 7, 'country_code' => 'US', 'name' => 'United States'],
112+
['id' => 9, 'country_code' => 'US', 'name' => 'United States'],
113+
],
114+
$rows->toArray()
115+
);
116+
}
65117
}

0 commit comments

Comments
 (0)