Skip to content

Commit 664ae2f

Browse files
committed
feat: expose join_on method
1 parent cdec202 commit 664ae2f

File tree

3 files changed

+86
-1
lines changed

3 files changed

+86
-1
lines changed

python/datafusion/dataframe.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from __future__ import annotations
2323

24-
from typing import Any, List, TYPE_CHECKING
24+
from typing import Any, List, TYPE_CHECKING, Literal
2525
from datafusion.record_batch import RecordBatchStream
2626
from typing_extensions import deprecated
2727
from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -293,6 +293,29 @@ def join(
293293
"""
294294
return DataFrame(self.df.join(right.df, join_keys, how))
295295

296+
def join_on(
297+
self,
298+
right: DataFrame,
299+
*on_exprs: Expr,
300+
how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
301+
) -> DataFrame:
302+
"""Join two :py:class:`DataFrame`using the specified expressions.
303+
304+
On expressions are used to support in-equality predicates. Equality
305+
predicates are correctly optimized
306+
307+
Args:
308+
right: Other DataFrame to join with.
309+
on_exprs: single or multiple (in)-equality predicates.
310+
how: Type of join to perform. Supported types are "inner", "left",
311+
"right", "full", "semi", "anti".
312+
313+
Returns:
314+
DataFrame after join.
315+
"""
316+
exprs = [expr.expr for expr in on_exprs]
317+
return DataFrame(self.df.join_on(right.df, exprs, how))
318+
296319
def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame:
297320
"""Return a DataFrame with the explanation of its plan so far.
298321

python/tests/test_dataframe.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,43 @@ def test_join():
259259
assert table.to_pydict() == expected
260260

261261

262+
def test_join_on():
263+
ctx = SessionContext()
264+
265+
batch = pa.RecordBatch.from_arrays(
266+
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
267+
names=["a", "b"],
268+
)
269+
df = ctx.create_dataframe([[batch]], "l")
270+
271+
batch = pa.RecordBatch.from_arrays(
272+
[pa.array([1, 2]), pa.array([8, 10])],
273+
names=["a", "c"],
274+
)
275+
df1 = ctx.create_dataframe([[batch]], "r")
276+
277+
df2 = df.join_on(df1, column("l.a").__eq__(column("r.a")), how="inner")
278+
df2.show()
279+
df2 = df2.sort(column("l.a"))
280+
table = pa.Table.from_batches(df2.collect())
281+
282+
expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
283+
assert table.to_pydict() == expected
284+
285+
df3 = df.join_on(
286+
df1,
287+
column("l.a").__eq__(column("r.a")),
288+
column("l.a").__lt__(column("r.c")),
289+
how="inner",
290+
)
291+
df3.show()
292+
df3 = df3.sort(column("l.a"))
293+
table = pa.Table.from_batches(df3.collect())
294+
295+
expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
296+
assert table.to_pydict() == expected
297+
298+
262299
def test_distinct():
263300
ctx = SessionContext()
264301

src/dataframe.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,31 @@ impl PyDataFrame {
293293
Ok(Self::new(df))
294294
}
295295

296+
fn join_on(&self, right: PyDataFrame, on_exprs: Vec<PyExpr>, how: &str) -> PyResult<Self> {
297+
let join_type = match how {
298+
"inner" => JoinType::Inner,
299+
"left" => JoinType::Left,
300+
"right" => JoinType::Right,
301+
"full" => JoinType::Full,
302+
"semi" => JoinType::LeftSemi,
303+
"anti" => JoinType::LeftAnti,
304+
how => {
305+
return Err(DataFusionError::Common(format!(
306+
"The join type {how} does not exist or is not implemented"
307+
))
308+
.into());
309+
}
310+
};
311+
let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();
312+
313+
let df = self
314+
.df
315+
.as_ref()
316+
.clone()
317+
.join_on(right.df.as_ref().clone(), join_type, exprs)?;
318+
Ok(Self::new(df))
319+
}
320+
296321
/// Print the query plan
297322
#[pyo3(signature = (verbose=false, analyze=false))]
298323
fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyResult<()> {

0 commit comments

Comments
 (0)