Skip to content

Commit 3dc96f3

Browse files
committed
fix join deduplicate and tests
handle deduplication for right/full joins by coalescing join keys refactor join preparation to lower complexity update tests to use supported sort API and full join keyword fix lint issues
1 parent 0bb81df commit 3dc96f3

File tree

2 files changed

+82
-70
lines changed

2 files changed

+82
-70
lines changed

python/datafusion/dataframe.py

Lines changed: 65 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from datafusion.plan import ExecutionPlan, LogicalPlan
4747
from datafusion.record_batch import RecordBatchStream
4848

49+
from .functions import coalesce, col
50+
4951
if TYPE_CHECKING:
5052
import pathlib
5153
from typing import Callable, Sequence
@@ -77,6 +79,31 @@ class JoinPreparation:
7779
drop_cols: list[str]
7880

7981

82+
def _deduplicate_right(
83+
right: DataFrame, columns: Sequence[str]
84+
) -> tuple[DataFrame, list[str]]:
85+
"""Rename join columns on the right DataFrame for deduplication."""
86+
existing_columns = set(right.schema().names)
87+
modified = right
88+
aliases: list[str] = []
89+
90+
for col_name in columns:
91+
base_alias = f"__right_{col_name}"
92+
alias = base_alias
93+
counter = 0
94+
while alias in existing_columns:
95+
counter += 1
96+
alias = f"{base_alias}_{counter}"
97+
if alias in existing_columns:
98+
alias = f"__temp_{uuid.uuid4().hex[:8]}_{col_name}"
99+
100+
modified = modified.with_column_renamed(col_name, alias)
101+
aliases.append(alias)
102+
existing_columns.add(alias)
103+
104+
return modified, aliases
105+
106+
80107
# excerpt from deltalake
81108
# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
82109
class Compression(Enum):
@@ -730,10 +757,23 @@ def join(
730757
join_preparation.join_keys.right_names,
731758
)
732759
)
733-
760+
761+
if (
762+
deduplicate
763+
and how in ("right", "full")
764+
and join_preparation.join_keys.on is not None
765+
):
766+
for left_name, right_alias in zip(
767+
join_preparation.join_keys.left_names,
768+
join_preparation.drop_cols,
769+
):
770+
result = result.with_column(
771+
left_name, coalesce(col(left_name), col(right_alias))
772+
)
773+
734774
if join_preparation.drop_cols:
735775
result = result.drop(*join_preparation.drop_cols)
736-
776+
737777
return result
738778

739779
def _prepare_join(
@@ -746,18 +786,18 @@ def _prepare_join(
746786
deduplicate: bool,
747787
) -> JoinPreparation:
748788
"""Prepare join keys and handle deduplication if requested.
749-
789+
750790
This method combines join key resolution and deduplication preparation
751791
to avoid parameter handling duplication and provide a unified interface.
752-
792+
753793
Args:
754794
right: The right DataFrame to join with.
755795
on: Column names to join on in both dataframes.
756796
left_on: Join column of the left dataframe.
757797
right_on: Join column of the right dataframe.
758798
join_keys: Tuple of two lists of column names to join on. [Deprecated]
759799
deduplicate: If True, prepare right DataFrame for column deduplication.
760-
800+
761801
Returns:
762802
JoinPreparation containing resolved join keys, modified right DataFrame,
763803
and columns to drop after joining.
@@ -787,71 +827,41 @@ def _prepare_join(
787827

788828
if resolved_on is not None:
789829
if left_on is not None or right_on is not None:
790-
error_msg = (
791-
"`left_on` or `right_on` should not be provided with `on`. "
792-
"Note: `deduplicate` must be specified as a keyword argument."
793-
)
830+
error_msg = "`left_on` or `right_on` should not provided with `on`"
794831
raise ValueError(error_msg)
795832
left_on = resolved_on
796833
right_on = resolved_on
797834
elif left_on is not None or right_on is not None:
798835
if left_on is None or right_on is None:
799-
error_msg = (
800-
"`left_on` and `right_on` should both be provided. "
801-
"Note: `deduplicate` must be specified as a keyword argument."
802-
)
836+
error_msg = "`left_on` and `right_on` should both be provided."
803837
raise ValueError(error_msg)
804838
else:
805-
error_msg = (
806-
"Either `on` or both `left_on` and `right_on` should be provided. "
807-
"Note: `deduplicate` must be specified as a keyword argument."
808-
)
839+
error_msg = "either `on` or `left_on` and `right_on` should be provided."
809840
raise ValueError(error_msg)
810841

811842
# At this point, left_on and right_on are guaranteed to be non-None
812-
assert left_on is not None and right_on is not None
813-
843+
if left_on is None or right_on is None: # pragma: no cover - sanity check
844+
msg = "join keys resolved to None"
845+
raise ValueError(msg)
846+
814847
left_names = [left_on] if isinstance(left_on, str) else list(left_on)
815848
right_names = [right_on] if isinstance(right_on, str) else list(right_on)
816-
817-
join_keys_resolved = JoinKeys(
818-
on=resolved_on, left_names=left_names, right_names=right_names
819-
)
820-
821-
# Step 2: Handle deduplication if requested
849+
822850
drop_cols: list[str] = []
823851
modified_right = right
824-
852+
825853
if deduplicate and resolved_on is not None:
826-
# Prepare deduplication by renaming columns in the right DataFrame
827-
on_cols = [resolved_on] if isinstance(resolved_on, str) else list(resolved_on)
828-
829-
# Get existing column names to avoid collisions
830-
existing_columns = set(right.schema().names)
831-
832-
for col_name in on_cols:
833-
# Generate a collision-safe temporary alias
834-
base_alias = f"__right_{col_name}"
835-
alias = base_alias
836-
counter = 0
837-
838-
# Keep trying until we find a unique name
839-
while alias in existing_columns:
840-
counter += 1
841-
alias = f"{base_alias}_{counter}"
842-
843-
# If even that fails (very unlikely), use UUID
844-
if alias in existing_columns:
845-
alias = f"__temp_{uuid.uuid4().hex[:8]}_{col_name}"
846-
847-
modified_right = modified_right.with_column_renamed(col_name, alias)
848-
drop_cols.append(alias)
849-
# Add the new alias to existing columns to avoid future collisions
850-
existing_columns.add(alias)
851-
852-
# Update right_names to use the new aliases
853-
right_names = drop_cols.copy()
854-
854+
on_cols = (
855+
[resolved_on] if isinstance(resolved_on, str) else list(resolved_on)
856+
)
857+
modified_right, aliases = _deduplicate_right(right, on_cols)
858+
drop_cols.extend(aliases)
859+
right_names = aliases.copy()
860+
861+
join_keys_resolved = JoinKeys(
862+
on=resolved_on, left_names=left_names, right_names=right_names
863+
)
864+
855865
return JoinPreparation(
856866
join_keys=join_keys_resolved,
857867
modified_right=modified_right,

python/tests/test_dataframe.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def test_join_deduplicate_multi():
558558
right = ctx.create_dataframe([[batch]], "r")
559559

560560
joined = left.join(right, on=["a", "b"], deduplicate=True)
561-
joined = joined.sort([column("a"), column("b")])
561+
joined = joined.sort(column("a"), column("b"))
562562
table = pa.Table.from_batches(joined.collect())
563563

564564
expected = {"a": [1, 2], "b": [3, 4], "r": ["u", "v"], "l": ["x", "y"]}
@@ -2678,7 +2678,9 @@ def test_join_deduplicate_select():
26782678

26792679
# Ensure no internal alias names like "__right_id" appear in the schema
26802680
for col_name in column_names:
2681-
assert not col_name.startswith("__"), f"Internal alias '{col_name}' leaked into schema"
2681+
assert not col_name.startswith("__"), (
2682+
f"Internal alias '{col_name}' leaked into schema"
2683+
)
26822684

26832685
# Test selecting each column individually to ensure they all work
26842686
for col_name in expected_columns:
@@ -2693,13 +2695,13 @@ def test_join_deduplicate_select():
26932695
assert all_result.schema.names == expected_columns
26942696

26952697
# Verify that attempting to select a potential internal alias fails appropriately
2696-
with pytest.raises(Exception): # Should raise an error for non-existent column
2698+
with pytest.raises(Exception): # noqa: B017 - generic exception from FFI
26972699
joined_df.select(column("__right_id")).collect()
26982700

26992701

27002702
def test_join_deduplicate_all_types():
27012703
"""Test deduplication behavior across different join types (left, right, outer).
2702-
2704+
27032705
Note: This test may show linting errors due to method signature overloads,
27042706
but the functionality should work correctly at runtime.
27052707
"""
@@ -2721,8 +2723,8 @@ def test_join_deduplicate_all_types():
27212723

27222724
# Test inner join with deduplication (default behavior)
27232725
inner_joined = left_df.join(right_df, on="id", how="inner", deduplicate=True)
2724-
inner_result = inner_joined.sort([column("id")]).collect()[0]
2725-
2726+
inner_result = inner_joined.sort(column("id")).collect()[0]
2727+
27262728
# Should only have matching rows (2, 3)
27272729
expected_inner = {
27282730
"id": [2, 3],
@@ -2733,8 +2735,8 @@ def test_join_deduplicate_all_types():
27332735

27342736
# Test left join with deduplication
27352737
left_joined = left_df.join(right_df, on="id", how="left", deduplicate=True)
2736-
left_result = left_joined.sort([column("id")]).collect()[0]
2737-
2738+
left_result = left_joined.sort(column("id")).collect()[0]
2739+
27382740
# Should have all left rows, with nulls for unmatched right rows
27392741
expected_left = {
27402742
"id": [1, 2, 3, 4],
@@ -2745,8 +2747,8 @@ def test_join_deduplicate_all_types():
27452747

27462748
# Test right join with deduplication
27472749
right_joined = left_df.join(right_df, on="id", how="right", deduplicate=True)
2748-
right_result = right_joined.sort([column("id")]).collect()[0]
2749-
2750+
right_result = right_joined.sort(column("id")).collect()[0]
2751+
27502752
# Should have all right rows, with nulls for unmatched left rows
27512753
expected_right = {
27522754
"id": [2, 3, 5, 6],
@@ -2756,9 +2758,9 @@ def test_join_deduplicate_all_types():
27562758
assert right_result.to_pydict() == expected_right
27572759

27582760
# Test full outer join with deduplication
2759-
outer_joined = left_df.join(right_df, on="id", how="outer", deduplicate=True)
2760-
outer_result = outer_joined.sort([column("id")]).collect()[0]
2761-
2761+
outer_joined = left_df.join(right_df, on="id", how="full", deduplicate=True)
2762+
outer_result = outer_joined.sort(column("id")).collect()[0]
2763+
27622764
# Should have all rows from both sides, with nulls for unmatched rows
27632765
expected_outer = {
27642766
"id": [1, 2, 3, 4, 5, 6],
@@ -2768,8 +2770,8 @@ def test_join_deduplicate_all_types():
27682770
assert outer_result.to_pydict() == expected_outer
27692771

27702772
# Verify that we can still select the deduplicated column without issues
2771-
for join_type in ["inner", "left", "right", "outer"]:
2773+
for join_type in ["inner", "left", "right", "full"]:
27722774
joined = left_df.join(right_df, on="id", how=join_type, deduplicate=True)
27732775
selected = joined.select(column("id"))
27742776
# Should not raise an error and should have the same number of rows
2775-
assert len(selected.collect()[0]) == len(joined.collect()[0])
2777+
assert len(selected.collect()[0]) == len(joined.collect()[0])

0 commit comments

Comments
 (0)