@@ -68,6 +68,15 @@ class JoinKeys:
6868 right_names : list [str ]
6969
7070
@dataclass
class JoinPreparation:
    """Everything a DataFrame join needs once its arguments are resolved.

    Produced by ``DataFrame._prepare_join`` and consumed by ``DataFrame.join``.
    """

    # Resolved key names for the left and right sides of the join.
    join_keys: JoinKeys
    # Right-hand DataFrame to join against; its join columns may have been
    # renamed to temporary aliases when deduplication was requested.
    modified_right: DataFrame
    # Columns to drop from the join result (empty when nothing was renamed).
    drop_cols: list[str]
78+
79+
7180# excerpt from deltalake
7281# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
7382class Compression (Enum ):
@@ -709,32 +718,51 @@ def join(
709718 Returns:
710719 DataFrame after join.
711720 """
712- join_keys_resolved = self ._resolve_join_keys (
713- on , left_on , right_on , join_keys
721+ join_preparation = self ._prepare_join (
722+ right , on , left_on , right_on , join_keys , deduplicate
714723 )
715724
716- drop_cols : list [str ] | None = None
717- if deduplicate and join_keys_resolved .on is not None :
718- right , drop_cols , left_on_final , right_on_final = self ._prepare_deduplicate (
719- right , join_keys_resolved .on
725+ result = DataFrame (
726+ self .df .join (
727+ join_preparation .modified_right .df ,
728+ how ,
729+ join_preparation .join_keys .left_names ,
730+ join_preparation .join_keys .right_names ,
720731 )
721- else :
722- left_on_final = join_keys_resolved .left_names
723- right_on_final = join_keys_resolved .right_names
724-
725- result = DataFrame (self .df .join (right .df , how , left_on_final , right_on_final ))
726- if drop_cols :
727- result = result .drop (* drop_cols )
732+ )
733+
734+ if join_preparation .drop_cols :
735+ result = result .drop (* join_preparation .drop_cols )
736+
728737 return result
729738
730- def _resolve_join_keys (
739+ def _prepare_join (
731740 self ,
741+ right : DataFrame ,
732742 on : str | Sequence [str ] | tuple [list [str ], list [str ]] | None ,
733743 left_on : str | Sequence [str ] | None ,
734744 right_on : str | Sequence [str ] | None ,
735745 join_keys : tuple [list [str ], list [str ]] | None ,
736- ) -> JoinKeys :
737- """Normalize join key arguments and validate them."""
746+ deduplicate : bool ,
747+ ) -> JoinPreparation :
748+ """Prepare join keys and handle deduplication if requested.
749+
750+ This method combines join key resolution and deduplication preparation
751+ to avoid parameter handling duplication and provide a unified interface.
752+
753+ Args:
754+ right: The right DataFrame to join with.
755+ on: Column names to join on in both dataframes.
756+ left_on: Join column of the left dataframe.
757+ right_on: Join column of the right dataframe.
758+ join_keys: Tuple of two lists of column names to join on. [Deprecated]
759+ deduplicate: If True, prepare right DataFrame for column deduplication.
760+
761+ Returns:
762+ JoinPreparation containing resolved join keys, modified right DataFrame,
763+ and columns to drop after joining.
764+ """
765+ # Step 1: Resolve join keys
738766 # Handle the special case where on is a tuple of lists (legacy format)
739767 resolved_on : str | Sequence [str ] | None
740768 if (
@@ -785,57 +813,50 @@ def _resolve_join_keys(
785813
786814 left_names = [left_on ] if isinstance (left_on , str ) else list (left_on )
787815 right_names = [right_on ] if isinstance (right_on , str ) else list (right_on )
788-
789- return JoinKeys (on = resolved_on , left_names = left_names , right_names = right_names )
790-
791- def _prepare_deduplicate (
792- self , right : DataFrame , on : str | Sequence [str ]
793- ) -> tuple [DataFrame , list [str ], list [str ], list [str ]]:
794- """Rename join columns to drop them after joining.
795816
796- Uses collision-safe temporary aliases to avoid conflicts with existing column names.
817+ join_keys_resolved = JoinKeys (
818+ on = resolved_on , left_names = left_names , right_names = right_names
819+ )
797820
798- Args:
799- right: The right DataFrame to modify.
800- on: The join column name(s).
801-
802- Returns:
803- A tuple containing:
804- - modified_right: DataFrame with renamed join columns
805- - drop_cols: List of column names to drop after joining
806- - left_cols: List of original left DataFrame column names
807- - right_aliases: List of renamed right DataFrame column names
808- """
821+ # Step 2: Handle deduplication if requested
809822 drop_cols : list [str ] = []
810- right_aliases : list [str ] = []
811- on_cols = [on ] if isinstance (on , str ) else list (on )
812-
813- # Get existing column names to avoid collisions
814- existing_columns = set (right .schema ().names )
815-
816823 modified_right = right
817- for col_name in on_cols :
818- # Generate a collision-safe temporary alias
819- base_alias = f"__right_{ col_name } "
820- alias = base_alias
821- counter = 0
822-
823- # Keep trying until we find a unique name
824- while alias in existing_columns :
825- counter += 1
826- alias = f"{ base_alias } _{ counter } "
827-
828- # If even that fails (very unlikely), use UUID
829- if alias in existing_columns :
830- alias = f"__temp_{ uuid .uuid4 ().hex [:8 ]} _{ col_name } "
824+
825+ if deduplicate and resolved_on is not None :
826+ # Prepare deduplication by renaming columns in the right DataFrame
827+ on_cols = [resolved_on ] if isinstance (resolved_on , str ) else list (resolved_on )
831828
832- modified_right = modified_right .with_column_renamed (col_name , alias )
833- right_aliases .append (alias )
834- drop_cols .append (alias )
835- # Add the new alias to existing columns to avoid future collisions
836- existing_columns .add (alias )
829+ # Get existing column names to avoid collisions
830+ existing_columns = set (right .schema ().names )
837831
838- return modified_right , drop_cols , on_cols , right_aliases
832+ for col_name in on_cols :
833+ # Generate a collision-safe temporary alias
834+ base_alias = f"__right_{ col_name } "
835+ alias = base_alias
836+ counter = 0
837+
838+ # Keep trying until we find a unique name
839+ while alias in existing_columns :
840+ counter += 1
841+ alias = f"{ base_alias } _{ counter } "
842+
843+ # If even that fails (very unlikely), use UUID
844+ if alias in existing_columns :
845+ alias = f"__temp_{ uuid .uuid4 ().hex [:8 ]} _{ col_name } "
846+
847+ modified_right = modified_right .with_column_renamed (col_name , alias )
848+ drop_cols .append (alias )
849+ # Add the new alias to existing columns to avoid future collisions
850+ existing_columns .add (alias )
851+
852+ # Update right_names to use the new aliases
853+ right_names = drop_cols .copy ()
854+
855+ return JoinPreparation (
856+ join_keys = join_keys_resolved ,
857+ modified_right = modified_right ,
858+ drop_cols = drop_cols ,
859+ )
839860
840861 def join_on (
841862 self ,
0 commit comments