From d36f4249fc6be5220697a7965ad881371ede434e Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sun, 13 Oct 2024 11:15:43 +0200 Subject: [PATCH 1/8] refactor: dataframe join params --- python/datafusion/dataframe.py | 79 +++++++++++++++++++++++++++++++--- python/tests/test_dataframe.py | 56 ++++++++++++++++++++++-- src/dataframe.rs | 9 ++-- 3 files changed, 129 insertions(+), 15 deletions(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index c5ac0bb89..2494420b7 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -20,8 +20,8 @@ """ from __future__ import annotations - -from typing import Any, List, TYPE_CHECKING +import warnings +from typing import Any, List, TYPE_CHECKING, Literal, overload from datafusion.record_batch import RecordBatchStream from typing_extensions import deprecated from datafusion.plan import LogicalPlan, ExecutionPlan @@ -31,7 +31,7 @@ import pandas as pd import polars as pl import pathlib - from typing import Callable + from typing import Callable, Sequence from datafusion._internal import DataFrame as DataFrameInternal from datafusion.expr import Expr, SortExpr, sort_or_default @@ -271,11 +271,51 @@ def distinct(self) -> DataFrame: """ return DataFrame(self.df.distinct()) + @overload + def join( + self, + right: DataFrame, + on: str | Sequence[str], + how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner", + *, + left_on: None = None, + right_on: None = None, + join_keys: None = None, + ) -> DataFrame: ... + + @overload def join( self, right: DataFrame, + on: None = None, + how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner", + *, + left_on: str | Sequence[str], + right_on: str | Sequence[str], + join_keys: tuple[list[str], list[str]] | None = None, + ) -> DataFrame: ... + + @overload + def join( + self, + right: DataFrame, + on: None = None, + how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner", + *, join_keys: tuple[list[str], list[str]], - how: str, + left_on: None = None, + right_on: None = None, + ) -> DataFrame: ... + + def join( + self, + right: DataFrame, + on: str | Sequence[str] | None = None, + how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner", + *, + left_on: str | Sequence[str] | None = None, + right_on: str | Sequence[str] | None = None, + join_keys: tuple[list[str], list[str]] | None = None, ) -> DataFrame: """Join this :py:class:`DataFrame` with another :py:class:`DataFrame`. @@ -284,14 +324,41 @@ def join( Args: right: Other DataFrame to join with. - join_keys: Tuple of two lists of column names to join on. + on: Column names to join on in both dataframes. how: Type of join to perform. Supported types are "inner", "left", "right", "full", "semi", "anti". + left_on: Join column of the left dataframe. + right_on: Join column of the right dataframe. + join_keys: Tuple of two lists of column names to join on. [Deprecated] Returns: DataFrame after join. """ - return DataFrame(self.df.join(right.df, join_keys, how)) + if join_keys is not None: + warnings.warn( + "`join_keys` is deprecated, use `on` or `left_on` with `right_on`", + category=DeprecationWarning, + stacklevel=2, + ) + left_on = join_keys[0] + right_on = join_keys[1] + + if on: + if left_on or right_on: + raise ValueError( + "`left_on` or `right_on` should not provided with `on`" + ) + left_on = on + right_on = on + elif left_on or right_on: + if left_on is None or right_on is None: + raise ValueError("`left_on` and `right_on` should both be provided.") + else: + raise ValueError( + "either `on` or `left_on` and `right_on` should be provided." + ) + + return DataFrame(self.df.join(right.df, how, left_on, right_on)) def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame: """Return a DataFrame with the explanation of its plan so far. diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index e89c57159..535656a89 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -250,15 +250,63 @@ def test_join(): ) df1 = ctx.create_dataframe([[batch]], "r") - df = df.join(df1, join_keys=(["a"], ["a"]), how="inner") - df.show() - df = df.sort(column("l.a")) - table = pa.Table.from_batches(df.collect()) + df2 = df.join(df1, on="a", how="inner") + df2.show() + df2 = df2.sort(column("l.a")) + table = pa.Table.from_batches(df2.collect()) + + expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} + assert table.to_pydict() == expected + + df2 = df.join(df1, left_on="a", right_on="a", how="inner") + df2.show() + df2 = df2.sort(column("l.a")) + table = pa.Table.from_batches(df2.collect()) expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} assert table.to_pydict() == expected +def test_join_invalid_params(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df = ctx.create_dataframe([[batch]], "l") + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array([8, 10])], + names=["a", "c"], + ) + df1 = ctx.create_dataframe([[batch]], "r") + + with pytest.deprecated_call(): + df2 = df.join(df1, join_keys=(["a"], ["a"]), how="inner") + df2.show() + df2 = df2.sort(column("l.a")) + table = pa.Table.from_batches(df2.collect()) + + expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} + assert table.to_pydict() == expected + + with pytest.raises( + ValueError, match=r"`left_on` or `right_on` should not provided with `on`" + ): + df2 = df.join(df1, on="a", how="inner", right_on="test") # type: ignore + + with pytest.raises( + ValueError, match=r"`left_on` and `right_on` should both be provided." + ): + df2 = df.join(df1, left_on="a", how="inner") # type: ignore + + with pytest.raises( + ValueError, match=r"either `on` or `left_on` and `right_on` should be provided." + ): + df2 = df.join(df1, how="inner") # type: ignore + + def test_distinct(): ctx = SessionContext() diff --git a/src/dataframe.rs b/src/dataframe.rs index e77ca8425..dfffaa9e5 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -254,8 +254,9 @@ impl PyDataFrame { fn join( &self, right: PyDataFrame, - join_keys: (Vec, Vec), how: &str, + left_on: Vec, + right_on: Vec, ) -> PyResult { let join_type = match how { "inner" => JoinType::Inner, @@ -272,13 +273,11 @@ impl PyDataFrame { } }; - let left_keys = join_keys - .0 + let left_keys = left_on .iter() .map(|s| s.as_ref()) .collect::>(); - let right_keys = join_keys - .1 + let right_keys = right_on .iter() .map(|s| s.as_ref()) .collect::>(); From bb83a74234e8b20c8c8062f4df732d00be80121c Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sun, 13 Oct 2024 16:41:42 +0200 Subject: [PATCH 2/8] chore: add description for on params --- python/datafusion/dataframe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 2494420b7..ae509e09f 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -319,8 +319,7 @@ def join( ) -> DataFrame: """Join this :py:class:`DataFrame` with another :py:class:`DataFrame`. - Join keys are a pair of lists of column names in the left and right - dataframes, respectively. These lists must have the same length. + `on` has to be provided or both `left_on` and `right_on` in conjunction. Args: right: Other DataFrame to join with. From 37d1f738726027c54c2a81b24d73e7ba97280650 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Fri, 1 Nov 2024 13:31:11 +0100 Subject: [PATCH 3/8] fix type --- python/datafusion/dataframe.py | 4 ++++ src/dataframe.rs | 10 ++-------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 9b87f967e..3e768611c 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -452,6 +452,10 @@ def join( raise ValueError( "either `on` or `left_on` and `right_on` should be provided." ) + if isinstance(left_on, str): + left_on = [left_on] + if isinstance(right_on, str): + right_on = [right_on] return DataFrame(self.df.join(right.df, how, left_on, right_on)) diff --git a/src/dataframe.rs b/src/dataframe.rs index 648d3cf33..ee8fbbf9d 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -290,14 +290,8 @@ impl PyDataFrame { } }; - let left_keys = left_on - .iter() - .map(|s| s.as_ref()) - .collect::>(); - let right_keys = right_on - .iter() - .map(|s| s.as_ref()) - .collect::>(); + let left_keys = left_on.iter().map(|s| s.as_ref()).collect::>(); + let right_keys = right_on.iter().map(|s| s.as_ref()).collect::>(); let df = self.df.as_ref().clone().join( right.df.as_ref().clone(), From 3396ba5335e8a1a48c5c5dc5e492ec37026a0a59 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Fri, 1 Nov 2024 13:47:01 +0100 Subject: [PATCH 4/8] chore: change join param --- docs/source/user-guide/common-operations/joins.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/user-guide/common-operations/joins.rst b/docs/source/user-guide/common-operations/joins.rst index 09fa145a7..40d922150 100644 --- a/docs/source/user-guide/common-operations/joins.rst +++ b/docs/source/user-guide/common-operations/joins.rst @@ -56,7 +56,7 @@ will be included in the resulting DataFrame. .. ipython:: python - left.join(right, join_keys=(["customer_id"], ["id"]), how="inner") + left.join(right, left_on="customer_id", right_on="id", how="inner") The parameter ``join_keys`` specifies the columns from the left DataFrame and right DataFrame that contains the values that should match. @@ -70,7 +70,7 @@ values for the corresponding columns. .. ipython:: python - left.join(right, join_keys=(["customer_id"], ["id"]), how="left") + left.join(right, left_on="customer_id", right_on="id", how="left") Full Join --------- @@ -80,7 +80,7 @@ is no match. Unmatched rows will have null values. .. ipython:: python - left.join(right, join_keys=(["customer_id"], ["id"]), how="full") + left.join(right, left_on="customer_id", right_on="id", how="full") Left Semi Join -------------- @@ -90,7 +90,7 @@ omitting duplicates with multiple matches in the right table. .. ipython:: python - left.join(right, join_keys=(["customer_id"], ["id"]), how="semi") + left.join(right, left_on="customer_id", right_on="id", how="semi") Left Anti Join -------------- @@ -101,4 +101,4 @@ the right table. .. ipython:: python - left.join(right, join_keys=(["customer_id"], ["id"]), how="anti") \ No newline at end of file + left.join(right, left_on="customer_id", right_on="id", how="anti") \ No newline at end of file From ab36082da2af849121fce1688b47132ec5ce3fac Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sat, 2 Nov 2024 10:49:49 +0100 Subject: [PATCH 5/8] chore: update join params in tpch --- examples/tpch/_tests.py | 6 +++--- examples/tpch/q02_minimum_cost_supplier.py | 12 ++++++++---- examples/tpch/q03_shipping_priority.py | 6 +++--- examples/tpch/q04_order_priority_checking.py | 4 +++- examples/tpch/q05_local_supplier_volume.py | 13 ++++++++----- examples/tpch/q07_volume_shipping.py | 12 +++++++----- examples/tpch/q08_market_share.py | 14 +++++++------- examples/tpch/q09_product_type_profit_measure.py | 13 ++++++++----- examples/tpch/q10_returned_item_reporting.py | 6 +++--- .../tpch/q11_important_stock_identification.py | 6 ++++-- examples/tpch/q12_ship_mode_order_priority.py | 2 +- examples/tpch/q13_customer_distribution.py | 4 +++- examples/tpch/q14_promotion_effect.py | 4 +++- examples/tpch/q15_top_supplier.py | 2 +- examples/tpch/q16_part_supplier_relationship.py | 2 +- examples/tpch/q17_small_quantity_order.py | 2 +- examples/tpch/q18_large_volume_customer.py | 4 ++-- examples/tpch/q19_discounted_revenue.py | 2 +- examples/tpch/q20_potential_part_promotion.py | 11 +++++++---- examples/tpch/q21_suppliers_kept_orders_waiting.py | 4 ++-- examples/tpch/q22_global_sales_opportunity.py | 2 +- 21 files changed, 77 insertions(+), 54 deletions(-) diff --git a/examples/tpch/_tests.py b/examples/tpch/_tests.py index 903b53548..13144ae9d 100644 --- a/examples/tpch/_tests.py +++ b/examples/tpch/_tests.py @@ -18,7 +18,7 @@ import pytest from importlib import import_module import pyarrow as pa -from datafusion import col, lit, functions as F +from datafusion import DataFrame, col, lit, functions as F from util import get_answer_file @@ -94,7 +94,7 @@ def check_q17(df): ) def test_tpch_query_vs_answer_file(query_code: str, answer_file: str): module = import_module(query_code) - df = module.df + df: DataFrame = module.df # Treat q17 as a special case. The answer file does not match the spec. # Running at scale factor 1, we have manually verified this result does @@ -121,5 +121,5 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str): cols = list(read_schema.names) - assert df.join(df_expected, (cols, cols), "anti").count() == 0 + assert df.join(df_expected, on=cols, how="anti").count() == 0 assert df.count() == df_expected.count() diff --git a/examples/tpch/q02_minimum_cost_supplier.py b/examples/tpch/q02_minimum_cost_supplier.py index 2440fdad6..c4ccf8ad3 100644 --- a/examples/tpch/q02_minimum_cost_supplier.py +++ b/examples/tpch/q02_minimum_cost_supplier.py @@ -80,16 +80,20 @@ # Now that we have the region, find suppliers in that region. Suppliers are tied to their nation # and nations are tied to the region. -df_nation = df_nation.join(df_region, (["n_regionkey"], ["r_regionkey"]), how="inner") +df_nation = df_nation.join( + df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner" +) df_supplier = df_supplier.join( - df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner" + df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" ) # Now that we know who the potential suppliers are for the part, we can limit out part # supplies table down. We can further join down to the specific parts we've identified # as matching the request -df = df_partsupp.join(df_supplier, (["ps_suppkey"], ["s_suppkey"]), how="inner") +df = df_partsupp.join( + df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner" +) # Locate the minimum cost across all suppliers. There are multiple ways you could do this, # but one way is to create a window function across all suppliers, find the minimum, and @@ -111,7 +115,7 @@ df = df.filter(col("min_cost") == col("ps_supplycost")) -df = df.join(df_part, (["ps_partkey"], ["p_partkey"]), how="inner") +df = df.join(df_part, left_on=["ps_partkey"], right_on=["p_partkey"], how="inner") # From the problem statement, these are the values we wish to output diff --git a/examples/tpch/q03_shipping_priority.py b/examples/tpch/q03_shipping_priority.py index c4e8f461a..5ebab13c0 100644 --- a/examples/tpch/q03_shipping_priority.py +++ b/examples/tpch/q03_shipping_priority.py @@ -55,9 +55,9 @@ # Join all 3 dataframes -df = df_customer.join(df_orders, (["c_custkey"], ["o_custkey"]), how="inner").join( - df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner" -) +df = df_customer.join( + df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" +).join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") # Compute the revenue diff --git a/examples/tpch/q04_order_priority_checking.py b/examples/tpch/q04_order_priority_checking.py index f10b74d91..8bf02cb83 100644 --- a/examples/tpch/q04_order_priority_checking.py +++ b/examples/tpch/q04_order_priority_checking.py @@ -66,7 +66,9 @@ ) # Perform the join to find only orders for which there are lineitems outside of expected range -df = df_orders.join(df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner") +df = df_orders.join( + df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" +) # Based on priority, find the number of entries df = df.aggregate( diff --git a/examples/tpch/q05_local_supplier_volume.py b/examples/tpch/q05_local_supplier_volume.py index 2a83d2d1a..413a4acb9 100644 --- a/examples/tpch/q05_local_supplier_volume.py +++ b/examples/tpch/q05_local_supplier_volume.py @@ -76,15 +76,18 @@ # Join all the dataframes df = ( - df_customer.join(df_orders, (["c_custkey"], ["o_custkey"]), how="inner") - .join(df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner") + df_customer.join( + df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" + ) + .join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") .join( df_supplier, - (["l_suppkey", "c_nationkey"], ["s_suppkey", "s_nationkey"]), + left_on=["l_suppkey", "c_nationkey"], + right_on=["s_suppkey", "s_nationkey"], how="inner", ) - .join(df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner") - .join(df_region, (["n_regionkey"], ["r_regionkey"]), how="inner") + .join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") + .join(df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner") ) # Compute the final result diff --git a/examples/tpch/q07_volume_shipping.py b/examples/tpch/q07_volume_shipping.py index a1d7d81ad..18c290d9c 100644 --- a/examples/tpch/q07_volume_shipping.py +++ b/examples/tpch/q07_volume_shipping.py @@ -90,20 +90,22 @@ # Limit suppliers to either nation df_supplier = df_supplier.join( - df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner" + df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" ).select(col("s_suppkey"), col("n_name").alias("supp_nation")) # Limit customers to either nation df_customer = df_customer.join( - df_nation, (["c_nationkey"], ["n_nationkey"]), how="inner" + df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner" ).select(col("c_custkey"), col("n_name").alias("cust_nation")) # Join up all the data frames from line items, and make sure the supplier and customer are in # different nations. df = ( - df_lineitem.join(df_orders, (["l_orderkey"], ["o_orderkey"]), how="inner") - .join(df_customer, (["o_custkey"], ["c_custkey"]), how="inner") - .join(df_supplier, (["l_suppkey"], ["s_suppkey"]), how="inner") + df_lineitem.join( + df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner" + ) + .join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") + .join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") .filter(col("cust_nation") != col("supp_nation")) ) diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py index 95fc0a871..7138ab65a 100644 --- a/examples/tpch/q08_market_share.py +++ b/examples/tpch/q08_market_share.py @@ -89,27 +89,27 @@ # After this join we have all of the possible sales nations df_regional_customers = df_regional_customers.join( - df_nation, (["r_regionkey"], ["n_regionkey"]), how="inner" + df_nation, left_on=["r_regionkey"], right_on=["n_regionkey"], how="inner" ) # Now find the possible customers df_regional_customers = df_regional_customers.join( - df_customer, (["n_nationkey"], ["c_nationkey"]), how="inner" + df_customer, left_on=["n_nationkey"], right_on=["c_nationkey"], how="inner" ) # Next find orders for these customers df_regional_customers = df_regional_customers.join( - df_orders, (["c_custkey"], ["o_custkey"]), how="inner" + df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" ) # Find all line items from these orders df_regional_customers = df_regional_customers.join( - df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner" + df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" ) # Limit to the part of interest df_regional_customers = df_regional_customers.join( - df_part, (["l_partkey"], ["p_partkey"]), how="inner" + df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner" ) # Compute the volume for each line item @@ -126,7 +126,7 @@ # Determine the suppliers by the limited nation key we have in our single row df above df_national_suppliers = df_national_suppliers.join( - df_supplier, (["n_nationkey"], ["s_nationkey"]), how="inner" + df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" ) # When we join to the customer dataframe, we don't want to confuse other columns, so only @@ -141,7 +141,7 @@ # column only from suppliers in the nation we are evaluating. df = df_regional_customers.join( - df_national_suppliers, (["l_suppkey"], ["s_suppkey"]), how="left" + df_national_suppliers, left_on=["l_suppkey"], right_on=["s_suppkey"], how="left" ) # Use a case statement to compute the volume sold by suppliers in the nation of interest diff --git a/examples/tpch/q09_product_type_profit_measure.py b/examples/tpch/q09_product_type_profit_measure.py index 0295d3025..aa47d76c0 100644 --- a/examples/tpch/q09_product_type_profit_measure.py +++ b/examples/tpch/q09_product_type_profit_measure.py @@ -65,13 +65,16 @@ df = df_part.filter(F.strpos(col("p_name"), part_color) > lit(0)) # We have a series of joins that get us to limit down to the line items we need -df = df.join(df_lineitem, (["p_partkey"], ["l_partkey"]), how="inner") -df = df.join(df_supplier, (["l_suppkey"], ["s_suppkey"]), how="inner") -df = df.join(df_orders, (["l_orderkey"], ["o_orderkey"]), how="inner") +df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") +df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") +df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") df = df.join( - df_partsupp, (["l_suppkey", "l_partkey"], ["ps_suppkey", "ps_partkey"]), how="inner" + df_partsupp, + left_on=["l_suppkey", "l_partkey"], + right_on=["ps_suppkey", "ps_partkey"], + how="inner", ) -df = df.join(df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner") +df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") # Compute the intermediate values and limit down to the expressions we need df = df.select( diff --git a/examples/tpch/q10_returned_item_reporting.py b/examples/tpch/q10_returned_item_reporting.py index 25f81b2ff..94b398c1d 100644 --- a/examples/tpch/q10_returned_item_reporting.py +++ b/examples/tpch/q10_returned_item_reporting.py @@ -74,7 +74,7 @@ col("o_orderdate") < date_start_of_quarter + interval_one_quarter ) -df = df.join(df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner") +df = df.join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") # Compute the revenue df = df.aggregate( @@ -83,8 +83,8 @@ ) # Now join in the customer data -df = df.join(df_customer, (["o_custkey"], ["c_custkey"]), how="inner") -df = df.join(df_nation, (["c_nationkey"], ["n_nationkey"]), how="inner") +df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") +df = df.join(df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner") # These are the columns the problem statement requires df = df.select( diff --git a/examples/tpch/q11_important_stock_identification.py b/examples/tpch/q11_important_stock_identification.py index 86ff2296b..707265e16 100644 --- a/examples/tpch/q11_important_stock_identification.py +++ b/examples/tpch/q11_important_stock_identification.py @@ -52,9 +52,11 @@ # Find part supplies of within this target nation -df = df_nation.join(df_supplier, (["n_nationkey"], ["s_nationkey"]), how="inner") +df = df_nation.join( + df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" +) -df = df.join(df_partsupp, (["s_suppkey"], ["ps_suppkey"]), how="inner") +df = df.join(df_partsupp, left_on=["s_suppkey"], right_on=["ps_suppkey"], how="inner") # Compute the value of individual parts diff --git a/examples/tpch/q12_ship_mode_order_priority.py b/examples/tpch/q12_ship_mode_order_priority.py index c3fc0d2e9..def2a6c30 100644 --- a/examples/tpch/q12_ship_mode_order_priority.py +++ b/examples/tpch/q12_ship_mode_order_priority.py @@ -75,7 +75,7 @@ # We need order priority, so join order df to line item -df = df.join(df_orders, (["l_orderkey"], ["o_orderkey"]), how="inner") +df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") # Restrict to line items we care about based on the problem statement. df = df.filter(col("l_commitdate") < col("l_receiptdate")) diff --git a/examples/tpch/q13_customer_distribution.py b/examples/tpch/q13_customer_distribution.py index f8b6c139d..67365a96a 100644 --- a/examples/tpch/q13_customer_distribution.py +++ b/examples/tpch/q13_customer_distribution.py @@ -49,7 +49,9 @@ ) # Since we may have customers with no orders we must do a left join -df = df_customer.join(df_orders, (["c_custkey"], ["o_custkey"]), how="left") +df = df_customer.join( + df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="left" +) # Find the number of orders for each customer df = df.aggregate([col("c_custkey")], [F.count(col("o_custkey")).alias("c_count")]) diff --git a/examples/tpch/q14_promotion_effect.py b/examples/tpch/q14_promotion_effect.py index 8224136ad..cd26ee2bd 100644 --- a/examples/tpch/q14_promotion_effect.py +++ b/examples/tpch/q14_promotion_effect.py @@ -57,7 +57,9 @@ ) # Left join so we can sum up the promo parts different from other parts -df = df_lineitem.join(df_part, (["l_partkey"], ["p_partkey"]), "left") +df = df_lineitem.join( + df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="left" +) # Make a factor of 1.0 if it is a promotion, 0.0 otherwise df = df.with_column("promo_factor", F.coalesce(col("promo_factor"), lit(0.0))) diff --git a/examples/tpch/q15_top_supplier.py b/examples/tpch/q15_top_supplier.py index 44d5dd997..0bc316f7a 100644 --- a/examples/tpch/q15_top_supplier.py +++ b/examples/tpch/q15_top_supplier.py @@ -76,7 +76,7 @@ # Now that we know the supplier(s) with maximum revenue, get the rest of their information # from the supplier table -df = df.join(df_supplier, (["l_suppkey"], ["s_suppkey"]), "inner") +df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") # Return only the columns requested df = df.select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py index cbdd9989a..dabebaedf 100644 --- a/examples/tpch/q16_part_supplier_relationship.py +++ b/examples/tpch/q16_part_supplier_relationship.py @@ -56,7 +56,7 @@ # Remove unwanted suppliers df_partsupp = df_partsupp.join( - df_unwanted_suppliers, (["ps_suppkey"], ["s_suppkey"]), "anti" + df_unwanted_suppliers, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="anti" ) # Select the parts we are interested in diff --git a/examples/tpch/q17_small_quantity_order.py b/examples/tpch/q17_small_quantity_order.py index ff494279b..d7b43d498 100644 --- a/examples/tpch/q17_small_quantity_order.py +++ b/examples/tpch/q17_small_quantity_order.py @@ -51,7 +51,7 @@ ) # Combine data -df = df.join(df_lineitem, (["p_partkey"], ["l_partkey"]), "inner") +df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") # Find the average quantity window_frame = WindowFrame("rows", None, None) diff --git a/examples/tpch/q18_large_volume_customer.py b/examples/tpch/q18_large_volume_customer.py index 497615499..165fce033 100644 --- a/examples/tpch/q18_large_volume_customer.py +++ b/examples/tpch/q18_large_volume_customer.py @@ -54,8 +54,8 @@ # We've identified the orders of interest, now join the additional data # we are required to report on -df = df.join(df_orders, (["l_orderkey"], ["o_orderkey"]), "inner") -df = df.join(df_customer, (["o_custkey"], ["c_custkey"]), "inner") +df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") +df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") df = df.select( "c_name", "c_custkey", "o_orderkey", "o_orderdate", "o_totalprice", "total_quantity" diff --git a/examples/tpch/q19_discounted_revenue.py b/examples/tpch/q19_discounted_revenue.py index c2fe2570d..4aed0cbae 100644 --- a/examples/tpch/q19_discounted_revenue.py +++ b/examples/tpch/q19_discounted_revenue.py @@ -72,7 +72,7 @@ (col("l_shipmode") == lit("AIR")) | (col("l_shipmode") == lit("AIR REG")) ) -df = df.join(df_part, (["l_partkey"], ["p_partkey"]), "inner") +df = df.join(df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner") # Create the user defined function (UDF) definition that does the work diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py index 3a0edb1ec..d720cdce6 100644 --- a/examples/tpch/q20_potential_part_promotion.py +++ b/examples/tpch/q20_potential_part_promotion.py @@ -70,7 +70,7 @@ ) # This will filter down the line items to the parts of interest -df = df.join(df_part, (["l_partkey"], ["p_partkey"]), "inner") +df = df.join(df_part, left_on="l_partkey", right_on="p_partkey", how="inner") # Compute the total sold and limit ourselves to individual supplier/part combinations df = df.aggregate( @@ -78,15 +78,18 @@ ) df = df.join( - df_partsupp, (["l_partkey", "l_suppkey"], ["ps_partkey", "ps_suppkey"]), "inner" + df_partsupp, + left_on=["l_partkey", "l_suppkey"], + right_on=["ps_partkey", "ps_suppkey"], + how="inner", ) # Find cases of excess quantity df.filter(col("ps_availqty") > lit(0.5) * col("total_sold")) # We could do these joins earlier, but now limit to the nation of interest suppliers -df = df.join(df_supplier, (["ps_suppkey"], ["s_suppkey"]), "inner") -df = df.join(df_nation, (["s_nationkey"], ["n_nationkey"]), "inner") +df = df.join(df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner") +df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") # Restrict to the requested data per the problem statement df = df.select("s_name", "s_address").distinct() diff --git a/examples/tpch/q21_suppliers_kept_orders_waiting.py b/examples/tpch/q21_suppliers_kept_orders_waiting.py index d3d57acee..991b88eb3 100644 --- a/examples/tpch/q21_suppliers_kept_orders_waiting.py +++ b/examples/tpch/q21_suppliers_kept_orders_waiting.py @@ -52,13 +52,13 @@ df_suppliers_of_interest = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST)) df_suppliers_of_interest = df_suppliers_of_interest.join( - df_supplier, (["n_nationkey"], ["s_nationkey"]), "inner" + df_supplier, left_on="n_nationkey", right_on="s_nationkey", how="inner" ) # Find the failed orders and all their line items df = df_orders.filter(col("o_orderstatus") == lit("F")) -df = df_lineitem.join(df, (["l_orderkey"], ["o_orderkey"]), "inner") +df = df_lineitem.join(df, left_on="l_orderkey", right_on="o_orderkey", how="inner") # Identify the line items for which the order is failed due to. df = df.with_column( diff --git a/examples/tpch/q22_global_sales_opportunity.py b/examples/tpch/q22_global_sales_opportunity.py index e6660e60c..72dce5289 100644 --- a/examples/tpch/q22_global_sales_opportunity.py +++ b/examples/tpch/q22_global_sales_opportunity.py @@ -62,7 +62,7 @@ df = df.filter(col("c_acctbal") > col("avg_balance")) # Limit results to customers with no orders -df = df.join(df_orders, (["c_custkey"], ["o_custkey"]), "anti") +df = df.join(df_orders, left_on="c_custkey", right_on="o_custkey", how="anti") # Count up the customers and the balances df = df.aggregate( From ef5aeb3e633573e70980e92662c28ee33f2520ce Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sat, 2 Nov 2024 11:10:56 +0100 Subject: [PATCH 6/8] oops --- examples/tpch/q21_suppliers_kept_orders_waiting.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/tpch/q21_suppliers_kept_orders_waiting.py b/examples/tpch/q21_suppliers_kept_orders_waiting.py index 991b88eb3..27cf816fa 100644 --- a/examples/tpch/q21_suppliers_kept_orders_waiting.py +++ b/examples/tpch/q21_suppliers_kept_orders_waiting.py @@ -102,7 +102,9 @@ ) # Join to the supplier of interest list for the nation of interest -df = df.join(df_suppliers_of_interest, (["suppkey"], ["s_suppkey"]), "inner") +df = df.join( + df_suppliers_of_interest, left_on=["suppkey"], right_on=["s_suppkey"], how="inner" +) # Count how many orders that supplier is the only failed supplier for df = df.aggregate([col("s_name")], [F.count(col("o_orderkey")).alias("numwait")]) From 99aba0b6245cf999a7b124ca9d66bcb55270c662 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sat, 2 Nov 2024 11:27:29 +0100 Subject: [PATCH 7/8] chore: final change --- examples/tpch/q16_part_supplier_relationship.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py index dabebaedf..a6a0c43eb 100644 --- a/examples/tpch/q16_part_supplier_relationship.py +++ b/examples/tpch/q16_part_supplier_relationship.py @@ -73,7 +73,9 @@ p_sizes = F.make_array(*[lit(s).cast(pa.int32()) for s in SIZES_OF_INTEREST]) df_part = df_part.filter(~F.array_position(p_sizes, col("p_size")).is_null()) -df = df_part.join(df_partsupp, (["p_partkey"], ["ps_partkey"]), "inner") +df = df_part.join( + df_partsupp, left_on=["p_partkey"], right_on=["ps_partkey"], how="inner" +) df = df.select("p_brand", "p_type", "p_size", "ps_suppkey").distinct() From 8cc7bd6269e99320156d1e9ec9455ab0249ce3eb Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 8 Nov 2024 07:29:14 -0500 Subject: [PATCH 8/8] Add support for join_keys as a positional argument --- python/datafusion/dataframe.py | 10 +++++++++- python/tests/test_dataframe.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 3e768611c..efd4038ae 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -406,7 +406,7 @@ def join( def join( self, right: DataFrame, - on: str | Sequence[str] | None = None, + on: str | Sequence[str] | tuple[list[str], list[str]] | None = None, how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner", *, left_on: str | Sequence[str] | None = None, @@ -429,6 +429,14 @@ def join( Returns: DataFrame after join. """ + # This check is to prevent breaking API changes where users prior to + # DF 43.0.0 would pass the join_keys as a positional argument instead + # of a keyword argument. + if isinstance(on, tuple) and len(on) == 2: + if isinstance(on[0], list) and isinstance(on[1], list): + join_keys = on # type: ignore + on = None + if join_keys is not None: warnings.warn( "`join_keys` is deprecated, use `on` or `left_on` with `right_on`", diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 77c8a141d..330475302 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -337,6 +337,16 @@ def test_join(): expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} assert table.to_pydict() == expected + # Verify we don't make a breaking change to pre-43.0.0 + # where users would pass join_keys as a positional argument + df2 = df.join(df1, (["a"], ["a"]), how="inner") # type: ignore + df2.show() + df2 = df2.sort(column("l.a")) + table = pa.Table.from_batches(df2.collect()) + + expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} + assert table.to_pydict() == expected + def test_join_invalid_params(): ctx = SessionContext()