Skip to content

Commit 94e3b3e

Browse files
GH-46572: [Python] expose filter option to python for join (#46566)
### Rationale for this change The C++ implementation supports a filter while performing a hash join; however, this option was not exposed to Python. Exposing it is useful because it lets users avoid an explicit additional filter operation on their side. ### What changes are included in this PR? Support a filter expression in the Python binding. ### Are these changes tested? Yes, a new test `test_hash_join_with_filter` was added. ### Are there any user-facing changes? It exposes one more argument to users, i.e., `filter_expression` for `Table.join` and `Dataset.join`. * GitHub Issue: #46572 Lead-authored-by: Xingyu Long <[email protected]> Co-authored-by: Rossi Sun <[email protected]> Signed-off-by: AlenkaF <[email protected]>
1 parent acbad29 commit 94e3b3e

File tree

4 files changed

+103
-9
lines changed

4 files changed

+103
-9
lines changed

python/pyarrow/_acero.pyx

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,14 +273,15 @@ cdef class _HashJoinNodeOptions(ExecNodeOptions):
273273

274274
def _set_options(
275275
self, join_type, left_keys, right_keys, left_output=None, right_output=None,
276-
output_suffix_for_left="", output_suffix_for_right="",
276+
output_suffix_for_left="", output_suffix_for_right="", Expression filter_expression=None,
277277
):
278278
cdef:
279279
CJoinType c_join_type
280280
vector[CFieldRef] c_left_keys
281281
vector[CFieldRef] c_right_keys
282282
vector[CFieldRef] c_left_output
283283
vector[CFieldRef] c_right_output
284+
CExpression c_filter_expression
284285

285286
# join type
286287
if join_type == "left semi":
@@ -312,6 +313,11 @@ cdef class _HashJoinNodeOptions(ExecNodeOptions):
312313
for key in right_keys:
313314
c_right_keys.push_back(_ensure_field_ref(key))
314315

316+
if filter_expression is None:
317+
c_filter_expression = _true
318+
else:
319+
c_filter_expression = filter_expression.unwrap()
320+
315321
# left/right output fields
316322
if left_output is not None and right_output is not None:
317323
for colname in left_output:
@@ -323,7 +329,7 @@ cdef class _HashJoinNodeOptions(ExecNodeOptions):
323329
new CHashJoinNodeOptions(
324330
c_join_type, c_left_keys, c_right_keys,
325331
c_left_output, c_right_output,
326-
_true,
332+
c_filter_expression,
327333
<c_string>tobytes(output_suffix_for_left),
328334
<c_string>tobytes(output_suffix_for_right)
329335
)
@@ -332,7 +338,7 @@ cdef class _HashJoinNodeOptions(ExecNodeOptions):
332338
self.wrapped.reset(
333339
new CHashJoinNodeOptions(
334340
c_join_type, c_left_keys, c_right_keys,
335-
_true,
341+
c_filter_expression,
336342
<c_string>tobytes(output_suffix_for_left),
337343
<c_string>tobytes(output_suffix_for_right)
338344
)
@@ -373,15 +379,17 @@ class HashJoinNodeOptions(_HashJoinNodeOptions):
373379
output_suffix_for_right : str
374380
Suffix added to names of output fields coming from right input,
375381
see `output_suffix_for_left` for details.
382+
filter_expression : pyarrow.compute.Expression
383+
Residual filter which is applied to matching row.
376384
"""
377385

378386
def __init__(
379387
self, join_type, left_keys, right_keys, left_output=None, right_output=None,
380-
output_suffix_for_left="", output_suffix_for_right=""
388+
output_suffix_for_left="", output_suffix_for_right="", filter_expression=None,
381389
):
382390
self._set_options(
383391
join_type, left_keys, right_keys, left_output, right_output,
384-
output_suffix_for_left, output_suffix_for_right
392+
output_suffix_for_left, output_suffix_for_right, filter_expression
385393
)
386394

387395

python/pyarrow/acero.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def _perform_join(join_type, left_operand, left_keys,
8383
right_operand, right_keys,
8484
left_suffix=None, right_suffix=None,
8585
use_threads=True, coalesce_keys=False,
86-
output_type=Table):
86+
output_type=Table, filter_expression=None):
8787
"""
8888
Perform join of two tables or datasets.
8989
@@ -114,6 +114,8 @@ def _perform_join(join_type, left_operand, left_keys,
114114
in the join result.
115115
output_type: Table or InMemoryDataset
116116
The output type for the exec plan result.
117+
filter_expression : pyarrow.compute.Expression
118+
Residual filter which is applied to matching row.
117119
118120
Returns
119121
-------
@@ -183,12 +185,14 @@ def _perform_join(join_type, left_operand, left_keys,
183185
join_type, left_keys, right_keys, left_columns, right_columns,
184186
output_suffix_for_left=left_suffix or "",
185187
output_suffix_for_right=right_suffix or "",
188+
filter_expression=filter_expression,
186189
)
187190
else:
188191
join_opts = HashJoinNodeOptions(
189192
join_type, left_keys, right_keys,
190193
output_suffix_for_left=left_suffix or "",
191194
output_suffix_for_right=right_suffix or "",
195+
filter_expression=filter_expression,
192196
)
193197
decl = Declaration(
194198
"hashjoin", options=join_opts, inputs=[left_source, right_source]

python/pyarrow/table.pxi

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5634,7 +5634,7 @@ cdef class Table(_Tabular):
56345634

56355635
def join(self, right_table, keys, right_keys=None, join_type="left outer",
56365636
left_suffix=None, right_suffix=None, coalesce_keys=True,
5637-
use_threads=True):
5637+
use_threads=True, filter_expression=None):
56385638
"""
56395639
Perform a join between this table and another one.
56405640
@@ -5668,6 +5668,8 @@ cdef class Table(_Tabular):
56685668
in the join result.
56695669
use_threads : bool, default True
56705670
Whether to use multithreading or not.
5671+
filter_expression : pyarrow.compute.Expression
5672+
Residual filter which is applied to matching row.
56715673
56725674
Returns
56735675
-------
@@ -5677,6 +5679,7 @@ cdef class Table(_Tabular):
56775679
--------
56785680
>>> import pandas as pd
56795681
>>> import pyarrow as pa
5682+
>>> import pyarrow.compute as pc
56805683
>>> df1 = pd.DataFrame({'id': [1, 2, 3],
56815684
... 'year': [2020, 2022, 2019]})
56825685
>>> df2 = pd.DataFrame({'id': [3, 4],
@@ -5727,7 +5730,7 @@ cdef class Table(_Tabular):
57275730
n_legs: [[5,100]]
57285731
animal: [["Brittle stars","Centipede"]]
57295732
5730-
Right anti join
5733+
Right anti join:
57315734
57325735
>>> t1.join(t2, 'id', join_type="right anti")
57335736
pyarrow.Table
@@ -5738,6 +5741,20 @@ cdef class Table(_Tabular):
57385741
id: [[4]]
57395742
n_legs: [[100]]
57405743
animal: [["Centipede"]]
5744+
5745+
Inner join with intended mismatch filter expression:
5746+
5747+
>>> t1.join(t2, 'id', join_type="inner", filter_expression=pc.equal(pc.field("n_legs"), 100))
5748+
pyarrow.Table
5749+
id: int64
5750+
year: int64
5751+
n_legs: int64
5752+
animal: string
5753+
----
5754+
id: []
5755+
year: []
5756+
n_legs: []
5757+
animal: []
57415758
"""
57425759
self._assert_cpu()
57435760
if right_keys is None:
@@ -5746,7 +5763,8 @@ cdef class Table(_Tabular):
57465763
join_type, self, keys, right_table, right_keys,
57475764
left_suffix=left_suffix, right_suffix=right_suffix,
57485765
use_threads=use_threads, coalesce_keys=coalesce_keys,
5749-
output_type=Table
5766+
output_type=Table,
5767+
filter_expression=filter_expression,
57505768
)
57515769

57525770
def join_asof(self, right_table, on, by, tolerance, right_on=None, right_by=None):

python/pyarrow/tests/test_acero.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,70 @@ def test_hash_join():
362362
assert result.sort_by("a").equals(expected)
363363

364364

365+
def test_hash_join_with_residual_filter():
366+
left = pa.table({'key': [1, 2, 3], 'a': [4, 5, 6]})
367+
left_source = Declaration("table_source", options=TableSourceNodeOptions(left))
368+
right = pa.table({'key': [2, 3, 4], 'b': [4, 5, 6]})
369+
right_source = Declaration("table_source", options=TableSourceNodeOptions(right))
370+
371+
join_opts = HashJoinNodeOptions(
372+
"inner", left_keys="key", right_keys="key",
373+
filter_expression=pc.equal(pc.field('a'), 5))
374+
joined = Declaration(
375+
"hashjoin", options=join_opts, inputs=[left_source, right_source])
376+
result = joined.to_table()
377+
expected = pa.table(
378+
[[2], [5], [2], [4]],
379+
names=["key", "a", "key", "b"])
380+
assert result.equals(expected)
381+
382+
# test filter expression referencing columns from both side
383+
join_opts = HashJoinNodeOptions(
384+
"left outer", left_keys="key", right_keys="key",
385+
filter_expression=pc.equal(pc.field("a"), 5) | pc.equal(pc.field("b"), 10)
386+
)
387+
joined = Declaration(
388+
"hashjoin", options=join_opts, inputs=[left_source, right_source])
389+
result = joined.to_table()
390+
expected = pa.table(
391+
[[2, 1, 3], [5, 4, 6], [2, None, None], [4, None, None]],
392+
names=["key", "a", "key", "b"])
393+
assert result.equals(expected)
394+
395+
# test with always true
396+
always_true = pc.scalar(True)
397+
join_opts = HashJoinNodeOptions(
398+
"inner", left_keys="key", right_keys="key",
399+
filter_expression=always_true)
400+
joined = Declaration(
401+
"hashjoin", options=join_opts, inputs=[left_source, right_source])
402+
result = joined.to_table()
403+
expected = pa.table(
404+
[[2, 3], [5, 6], [2, 3], [4, 5]],
405+
names=["key", "a", "key", "b"]
406+
)
407+
assert result.equals(expected)
408+
409+
# test with always false
410+
always_false = pc.scalar(False)
411+
join_opts = HashJoinNodeOptions(
412+
"inner", left_keys="key", right_keys="key",
413+
filter_expression=always_false)
414+
joined = Declaration(
415+
"hashjoin", options=join_opts, inputs=[left_source, right_source])
416+
result = joined.to_table()
417+
expected = pa.table(
418+
[
419+
pa.array([], type=pa.int64()),
420+
pa.array([], type=pa.int64()),
421+
pa.array([], type=pa.int64()),
422+
pa.array([], type=pa.int64())
423+
],
424+
names=["key", "a", "key", "b"]
425+
)
426+
assert result.equals(expected)
427+
428+
365429
def test_asof_join():
366430
left = pa.table({'key': [1, 2, 3], 'ts': [1, 1, 1], 'a': [4, 5, 6]})
367431
left_source = Declaration("table_source", options=TableSourceNodeOptions(left))

0 commit comments

Comments
 (0)