Skip to content

Commit d36f424

Browse files
committed
refactor: dataframe join params
1 parent cdec202 commit d36f424

File tree

3 files changed

+129
-15
lines changed

3 files changed

+129
-15
lines changed

python/datafusion/dataframe.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
"""
2121

2222
from __future__ import annotations
23-
24-
from typing import Any, List, TYPE_CHECKING
23+
import warnings
24+
from typing import Any, List, TYPE_CHECKING, Literal, overload
2525
from datafusion.record_batch import RecordBatchStream
2626
from typing_extensions import deprecated
2727
from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -31,7 +31,7 @@
3131
import pandas as pd
3232
import polars as pl
3333
import pathlib
34-
from typing import Callable
34+
from typing import Callable, Sequence
3535

3636
from datafusion._internal import DataFrame as DataFrameInternal
3737
from datafusion.expr import Expr, SortExpr, sort_or_default
@@ -271,11 +271,51 @@ def distinct(self) -> DataFrame:
271271
"""
272272
return DataFrame(self.df.distinct())
273273

274+
@overload
def join(
    self,
    right: DataFrame,
    on: str | Sequence[str],
    how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
    *,
    left_on: None = None,
    right_on: None = None,
    join_keys: None = None,
) -> DataFrame: ...


@overload
def join(
    self,
    right: DataFrame,
    on: None = None,
    how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
    *,
    left_on: str | Sequence[str],
    right_on: str | Sequence[str],
    join_keys: tuple[list[str], list[str]] | None = None,
) -> DataFrame: ...


@overload
def join(
    self,
    right: DataFrame,
    on: None = None,
    how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
    *,
    join_keys: tuple[list[str], list[str]],
    left_on: None = None,
    right_on: None = None,
) -> DataFrame: ...


def join(
    self,
    right: DataFrame,
    on: str | Sequence[str] | None = None,
    how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
    *,
    left_on: str | Sequence[str] | None = None,
    right_on: str | Sequence[str] | None = None,
    join_keys: tuple[list[str], list[str]] | None = None,
) -> DataFrame:
    """Join this :py:class:`DataFrame` with another :py:class:`DataFrame`.

    The join columns are given either once through ``on`` (same column
    names on both sides) or separately through ``left_on`` and
    ``right_on``. A single column may be passed as a plain string.

    Args:
        right: Other DataFrame to join with.
        on: Column names to join on in both dataframes.
        how: Type of join to perform. Supported types are "inner", "left",
            "right", "full", "semi", "anti".
        left_on: Join column of the left dataframe.
        right_on: Join column of the right dataframe.
        join_keys: Tuple of two lists of column names to join on. [Deprecated]
            Use ``on`` or ``left_on`` with ``right_on`` instead. When given,
            it replaces ``left_on`` and ``right_on``.

    Returns:
        DataFrame after join.

    Raises:
        ValueError: If ``on`` is combined with ``left_on``/``right_on`` (or
            ``join_keys``), if only one of ``left_on``/``right_on`` is given,
            or if no join columns are given at all.
    """
    # Before `on` existed the second positional parameter was `join_keys`,
    # so a legacy call such as ``df.join(other, (["a"], ["b"]), "inner")``
    # now lands in `on`. Recognise that exact shape and route it through the
    # deprecated path instead of treating the two lists as column names.
    if (
        join_keys is None
        and isinstance(on, tuple)
        and len(on) == 2
        and isinstance(on[0], list)
        and isinstance(on[1], list)
    ):
        join_keys = on  # type: ignore[assignment]
        on = None

    if join_keys is not None:
        warnings.warn(
            "`join_keys` is deprecated, use `on` or `left_on` with `right_on`",
            category=DeprecationWarning,
            stacklevel=2,
        )
        left_on = join_keys[0]
        right_on = join_keys[1]

    # Compare against None rather than relying on truthiness so that an
    # explicitly passed (even if empty) argument is never silently ignored.
    if on is not None:
        if left_on is not None or right_on is not None:
            raise ValueError(
                "`left_on` or `right_on` should not provided with `on`"
            )
        left_on = on
        right_on = on
    elif left_on is None and right_on is None:
        raise ValueError(
            "either `on` or `left_on` and `right_on` should be provided."
        )
    elif left_on is None or right_on is None:
        raise ValueError("`left_on` and `right_on` should both be provided.")

    # The native `join` takes two lists of column names; a bare string must
    # be wrapped, otherwise it is not accepted as a list of names.
    left_keys = [left_on] if isinstance(left_on, str) else list(left_on)
    right_keys = [right_on] if isinstance(right_on, str) else list(right_on)

    # An empty key list names no join column, so treat it as "not provided".
    if not left_keys or not right_keys:
        raise ValueError(
            "either `on` or `left_on` and `right_on` should be provided."
        )

    joined = self.df.join(right.df, how, left_keys, right_keys)
    return DataFrame(joined)
295362

296363
def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame:
297364
"""Return a DataFrame with the explanation of its plan so far.

python/tests/test_dataframe.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,15 +250,63 @@ def test_join():
250250
)
251251
df1 = ctx.create_dataframe([[batch]], "r")
252252

253-
df = df.join(df1, join_keys=(["a"], ["a"]), how="inner")
254-
df.show()
255-
df = df.sort(column("l.a"))
256-
table = pa.Table.from_batches(df.collect())
253+
df2 = df.join(df1, on="a", how="inner")
254+
df2.show()
255+
df2 = df2.sort(column("l.a"))
256+
table = pa.Table.from_batches(df2.collect())
257+
258+
expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
259+
assert table.to_pydict() == expected
260+
261+
df2 = df.join(df1, left_on="a", right_on="a", how="inner")
262+
df2.show()
263+
df2 = df2.sort(column("l.a"))
264+
table = pa.Table.from_batches(df2.collect())
257265

258266
expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
259267
assert table.to_pydict() == expected
260268

261269

270+
def test_join_invalid_params():
    """Check the deprecated ``join_keys`` path and ``join`` argument validation."""
    ctx = SessionContext()

    left_batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    df = ctx.create_dataframe([[left_batch]], "l")

    right_batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2]), pa.array([8, 10])],
        names=["a", "c"],
    )
    df1 = ctx.create_dataframe([[right_batch]], "r")

    # The legacy keyword must still produce the joined result while emitting
    # a deprecation warning.
    with pytest.deprecated_call():
        joined = df.join(df1, join_keys=(["a"], ["a"]), how="inner")
        joined.show()
        joined = joined.sort(column("l.a"))
        table = pa.Table.from_batches(joined.collect())

        assert table.to_pydict() == {"a": [1, 2], "c": [8, 10], "b": [4, 5]}

    # Each entry pairs an invalid keyword combination with the error message
    # pattern that the call is expected to raise.
    invalid_calls = [
        (
            {"on": "a", "how": "inner", "right_on": "test"},
            r"`left_on` or `right_on` should not provided with `on`",
        ),
        (
            {"left_on": "a", "how": "inner"},
            r"`left_on` and `right_on` should both be provided.",
        ),
        (
            {"how": "inner"},
            r"either `on` or `left_on` and `right_on` should be provided.",
        ),
    ]
    for kwargs, message in invalid_calls:
        with pytest.raises(ValueError, match=message):
            df.join(df1, **kwargs)  # type: ignore
308+
309+
262310
def test_distinct():
263311
ctx = SessionContext()
264312

src/dataframe.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,9 @@ impl PyDataFrame {
254254
fn join(
255255
&self,
256256
right: PyDataFrame,
257-
join_keys: (Vec<PyBackedStr>, Vec<PyBackedStr>),
258257
how: &str,
258+
left_on: Vec<PyBackedStr>,
259+
right_on: Vec<PyBackedStr>,
259260
) -> PyResult<Self> {
260261
let join_type = match how {
261262
"inner" => JoinType::Inner,
@@ -272,13 +273,11 @@ impl PyDataFrame {
272273
}
273274
};
274275

275-
let left_keys = join_keys
276-
.0
276+
let left_keys = left_on
277277
.iter()
278278
.map(|s| s.as_ref())
279279
.collect::<Vec<&str>>();
280-
let right_keys = join_keys
281-
.1
280+
let right_keys = right_on
282281
.iter()
283282
.map(|s| s.as_ref())
284283
.collect::<Vec<&str>>();

0 commit comments

Comments
 (0)