Skip to content

Commit 0bb81df

Browse files
committed
feat: enhance join operation preparation with JoinPreparation class
1 parent 19d69ca commit 0bb81df

File tree

1 file changed

+82
-61
lines changed

1 file changed

+82
-61
lines changed

python/datafusion/dataframe.py

Lines changed: 82 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ class JoinKeys:
6868
right_names: list[str]
6969

7070

71+
@dataclass
72+
class JoinPreparation:
73+
"""Represents the complete preparation for a DataFrame join operation."""
74+
75+
join_keys: JoinKeys
76+
modified_right: DataFrame
77+
drop_cols: list[str]
78+
79+
7180
# excerpt from deltalake
7281
# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
7382
class Compression(Enum):
@@ -709,32 +718,51 @@ def join(
709718
Returns:
710719
DataFrame after join.
711720
"""
712-
join_keys_resolved = self._resolve_join_keys(
713-
on, left_on, right_on, join_keys
721+
join_preparation = self._prepare_join(
722+
right, on, left_on, right_on, join_keys, deduplicate
714723
)
715724

716-
drop_cols: list[str] | None = None
717-
if deduplicate and join_keys_resolved.on is not None:
718-
right, drop_cols, left_on_final, right_on_final = self._prepare_deduplicate(
719-
right, join_keys_resolved.on
725+
result = DataFrame(
726+
self.df.join(
727+
join_preparation.modified_right.df,
728+
how,
729+
join_preparation.join_keys.left_names,
730+
join_preparation.join_keys.right_names,
720731
)
721-
else:
722-
left_on_final = join_keys_resolved.left_names
723-
right_on_final = join_keys_resolved.right_names
724-
725-
result = DataFrame(self.df.join(right.df, how, left_on_final, right_on_final))
726-
if drop_cols:
727-
result = result.drop(*drop_cols)
732+
)
733+
734+
if join_preparation.drop_cols:
735+
result = result.drop(*join_preparation.drop_cols)
736+
728737
return result
729738

730-
def _resolve_join_keys(
739+
def _prepare_join(
731740
self,
741+
right: DataFrame,
732742
on: str | Sequence[str] | tuple[list[str], list[str]] | None,
733743
left_on: str | Sequence[str] | None,
734744
right_on: str | Sequence[str] | None,
735745
join_keys: tuple[list[str], list[str]] | None,
736-
) -> JoinKeys:
737-
"""Normalize join key arguments and validate them."""
746+
deduplicate: bool,
747+
) -> JoinPreparation:
748+
"""Prepare join keys and handle deduplication if requested.
749+
750+
This method combines join key resolution and deduplication preparation
751+
to avoid parameter handling duplication and provide a unified interface.
752+
753+
Args:
754+
right: The right DataFrame to join with.
755+
on: Column names to join on in both dataframes.
756+
left_on: Join column of the left dataframe.
757+
right_on: Join column of the right dataframe.
758+
join_keys: Tuple of two lists of column names to join on. [Deprecated]
759+
deduplicate: If True, prepare right DataFrame for column deduplication.
760+
761+
Returns:
762+
JoinPreparation containing resolved join keys, modified right DataFrame,
763+
and columns to drop after joining.
764+
"""
765+
# Step 1: Resolve join keys
738766
# Handle the special case where on is a tuple of lists (legacy format)
739767
resolved_on: str | Sequence[str] | None
740768
if (
@@ -785,57 +813,50 @@ def _resolve_join_keys(
785813

786814
left_names = [left_on] if isinstance(left_on, str) else list(left_on)
787815
right_names = [right_on] if isinstance(right_on, str) else list(right_on)
788-
789-
return JoinKeys(on=resolved_on, left_names=left_names, right_names=right_names)
790-
791-
def _prepare_deduplicate(
792-
self, right: DataFrame, on: str | Sequence[str]
793-
) -> tuple[DataFrame, list[str], list[str], list[str]]:
794-
"""Rename join columns to drop them after joining.
795816

796-
Uses collision-safe temporary aliases to avoid conflicts with existing column names.
817+
join_keys_resolved = JoinKeys(
818+
on=resolved_on, left_names=left_names, right_names=right_names
819+
)
797820

798-
Args:
799-
right: The right DataFrame to modify.
800-
on: The join column name(s).
801-
802-
Returns:
803-
A tuple containing:
804-
- modified_right: DataFrame with renamed join columns
805-
- drop_cols: List of column names to drop after joining
806-
- left_cols: List of original left DataFrame column names
807-
- right_aliases: List of renamed right DataFrame column names
808-
"""
821+
# Step 2: Handle deduplication if requested
809822
drop_cols: list[str] = []
810-
right_aliases: list[str] = []
811-
on_cols = [on] if isinstance(on, str) else list(on)
812-
813-
# Get existing column names to avoid collisions
814-
existing_columns = set(right.schema().names)
815-
816823
modified_right = right
817-
for col_name in on_cols:
818-
# Generate a collision-safe temporary alias
819-
base_alias = f"__right_{col_name}"
820-
alias = base_alias
821-
counter = 0
822-
823-
# Keep trying until we find a unique name
824-
while alias in existing_columns:
825-
counter += 1
826-
alias = f"{base_alias}_{counter}"
827-
828-
# If even that fails (very unlikely), use UUID
829-
if alias in existing_columns:
830-
alias = f"__temp_{uuid.uuid4().hex[:8]}_{col_name}"
824+
825+
if deduplicate and resolved_on is not None:
826+
# Prepare deduplication by renaming columns in the right DataFrame
827+
on_cols = [resolved_on] if isinstance(resolved_on, str) else list(resolved_on)
831828

832-
modified_right = modified_right.with_column_renamed(col_name, alias)
833-
right_aliases.append(alias)
834-
drop_cols.append(alias)
835-
# Add the new alias to existing columns to avoid future collisions
836-
existing_columns.add(alias)
829+
# Get existing column names to avoid collisions
830+
existing_columns = set(right.schema().names)
837831

838-
return modified_right, drop_cols, on_cols, right_aliases
832+
for col_name in on_cols:
833+
# Generate a collision-safe temporary alias
834+
base_alias = f"__right_{col_name}"
835+
alias = base_alias
836+
counter = 0
837+
838+
# Keep trying until we find a unique name
839+
while alias in existing_columns:
840+
counter += 1
841+
alias = f"{base_alias}_{counter}"
842+
843+
# If even that fails (very unlikely), use UUID
844+
if alias in existing_columns:
845+
alias = f"__temp_{uuid.uuid4().hex[:8]}_{col_name}"
846+
847+
modified_right = modified_right.with_column_renamed(col_name, alias)
848+
drop_cols.append(alias)
849+
# Add the new alias to existing columns to avoid future collisions
850+
existing_columns.add(alias)
851+
852+
# Update right_names to use the new aliases
853+
right_names = drop_cols.copy()
854+
855+
return JoinPreparation(
856+
join_keys=join_keys_resolved,
857+
modified_right=modified_right,
858+
drop_cols=drop_cols,
859+
)
839860

840861
def join_on(
841862
self,

0 commit comments

Comments
 (0)