Skip to content

Commit a23a2e0

Browse files
committed
ENH: Add leftsemi merge
1 parent 669ddfb commit a23a2e0

File tree

9 files changed

+166
-12
lines changed

9 files changed

+166
-12
lines changed

asv_bench/benchmarks/join_merge.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,9 @@ def time_merge_dataframe_empty_left(self, sort):
272272
def time_merge_dataframes_cross(self, sort):
273273
merge(self.left.loc[:2000], self.right.loc[:2000], how="cross", sort=sort)
274274

275+
def time_merge_semi(self, sort):
276+
merge(self.df, self.df2, on="key1", how="leftsemi")
277+
275278

276279
class MergeEA:
277280
params = [
@@ -380,6 +383,9 @@ def setup(self, units, tz, monotonic):
380383
def time_merge(self, units, tz, monotonic):
381384
merge(self.left, self.right)
382385

386+
def time_merge_semi(self, units, tz, monotonic):
387+
merge(self.left, self.right, how="leftsemi")
388+
383389

384390
class MergeCategoricals:
385391
def setup(self):

doc/source/user_guide/merging.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ either the left or right tables, the values in the joined table will be
407407
``right``, ``RIGHT OUTER JOIN``, Use keys from right frame only
408408
``outer``, ``FULL OUTER JOIN``, Use union of keys from both frames
409409
``inner``, ``INNER JOIN``, Use intersection of keys from both frames
410+
``leftsemi``, ``SEMIJOIN``, Filter rows on left based on occurrences in right.
410411
``cross``, ``CROSS JOIN``, Create the cartesian product of rows of both frames
411412

412413
.. ipython:: python

doc/source/whatsnew/v3.0.0.rst

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,20 @@ including other versions of pandas.
1414
Enhancements
1515
~~~~~~~~~~~~
1616

17-
.. _whatsnew_300.enhancements.enhancement1:
17+
.. _whatsnew_300.enhancements.semi_merge:
1818

19-
enhancement1
20-
^^^^^^^^^^^^
19+
New merge method ``leftsemi``
20+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
21+
22+
A new merge method ``leftsemi`` has been added to :func:`merge` and
23+
:meth:`DataFrame.merge` that returns only the rows from the left DataFrame that have
24+
a match in the right DataFrame. This is equivalent to a SQL ``LEFT SEMI JOIN``. (:issue:`42784`)
25+
26+
.. ipython:: python
27+
28+
df1 = pd.DataFrame({"key": ["A", "B", "C"], "value": [1, 2, 3]})
29+
df2 = pd.DataFrame({"key": ["A", "B"], "value": [1, 2]})
30+
df1.merge(df2, how="leftsemi")
2131
2232
.. _whatsnew_300.enhancements.enhancement2:
2333

pandas/_libs/hashtable.pyx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,6 @@ cdef class ObjectFactorizer(Factorizer):
123123
self.count, na_sentinel, na_value)
124124
self.count = len(self.uniques)
125125
return labels
126+
127+
def hash_inner_join(self, values, mask=None):
128+
return self.table.hash_inner_join(values, mask)

pandas/_libs/hashtable_class_helper.pxi.in

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,6 +1385,33 @@ cdef class PyObjectHashTable(HashTable):
13851385
k = kh_put_pymap(self.table, <PyObject*>val, &ret)
13861386
self.table.vals[k] = i
13871387

1388+
@cython.wraparound(False)
1389+
@cython.boundscheck(False)
1390+
def hash_inner_join(self, ndarray[object] values, object mask = None) -> tuple[ndarray, ndarray]:
1391+
cdef:
1392+
Py_ssize_t i, n = len(values)
1393+
object val
1394+
khiter_t k
1395+
Int64Vector locs = Int64Vector()
1396+
Int64Vector self_locs = Int64Vector()
1397+
Int64VectorData *l
1398+
Int64VectorData *sl
1399+
# mask not implemented
1400+
1401+
l = &locs.data
1402+
sl = &self_locs.data
1403+
1404+
for i in range(n):
1405+
val = values[i]
1406+
hash(val)
1407+
1408+
k = kh_get_pymap(self.table, <PyObject*>val)
1409+
if k != self.table.n_buckets:
1410+
append_data_int64(l, i)
1411+
append_data_int64(sl, self.table.vals[k])
1412+
1413+
return self_locs.to_array(), locs.to_array()
1414+
13881415
def lookup(self, ndarray[object] values, object mask = None) -> ndarray:
13891416
# -> np.ndarray[np.intp]
13901417
# mask not yet implemented

pandas/_typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def closed(self) -> bool:
447447
AnyAll = Literal["any", "all"]
448448

449449
# merge
450-
MergeHow = Literal["left", "right", "inner", "outer", "cross"]
450+
MergeHow = Literal["left", "right", "inner", "outer", "cross", "leftsemi"]
451451
MergeValidate = Literal[
452452
"one_to_one",
453453
"1:1",

pandas/core/frame.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@
315315
----------%s
316316
right : DataFrame or named Series
317317
Object to merge with.
318-
how : {'left', 'right', 'outer', 'inner', 'cross'}, default 'inner'
318+
how : {'left', 'right', 'outer', 'inner', 'leftsemi', 'cross'}, default 'inner'
319319
Type of merge to be performed.
320320
321321
* left: use only keys from left frame, similar to a SQL left outer join;
@@ -326,6 +326,11 @@
326326
join; sort keys lexicographically.
327327
* inner: use intersection of keys from both frames, similar to a SQL inner
328328
join; preserve the order of the left keys.
329+
* leftsemi: Filter for rows in the left that have a match on the right;
330+
preserve the order of the left keys. Doesn't support `left_index`, `right_index`,
331+
`indicator` or `validate`.
332+
333+
.. versionadded:: 3.0
329334
* cross: creates the cartesian product from both frames, preserves the order
330335
of the left keys.
331336
on : label or list

pandas/core/reshape/merge.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ def merge(
166166
validate=validate,
167167
)
168168
else:
169-
op = _MergeOperation(
169+
klass = _MergeOperation if how != "leftsemi" else _SemiMergeOperation
170+
op = klass(
170171
left_df,
171172
right_df,
172173
how=how,
@@ -817,7 +818,6 @@ def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None:
817818
# Overridden by AsOfMerge
818819
pass
819820

820-
@final
821821
def _reindex_and_concat(
822822
self,
823823
join_index: Index,
@@ -945,7 +945,6 @@ def _indicator_post_merge(self, result: DataFrame) -> DataFrame:
945945
result = result.drop(labels=["_left_indicator", "_right_indicator"], axis=1)
946946
return result
947947

948-
@final
949948
def _maybe_restore_index_levels(self, result: DataFrame) -> None:
950949
"""
951950
Restore index levels specified as `on` parameters
@@ -989,7 +988,6 @@ def _maybe_restore_index_levels(self, result: DataFrame) -> None:
989988
if names_to_restore:
990989
result.set_index(names_to_restore, inplace=True)
991990

992-
@final
993991
def _maybe_add_join_keys(
994992
self,
995993
result: DataFrame,
@@ -1740,7 +1738,8 @@ def get_join_indexers(
17401738
right = Index(rkey)
17411739

17421740
if (
1743-
left.is_monotonic_increasing
1741+
how != "leftsemi"
1742+
and left.is_monotonic_increasing
17441743
and right.is_monotonic_increasing
17451744
and (left.is_unique or right.is_unique)
17461745
):
@@ -1883,6 +1882,48 @@ def _convert_to_multiindex(index: Index) -> MultiIndex:
18831882
return tuple(join_levels), tuple(join_codes), tuple(join_names)
18841883

18851884

1885+
class _SemiMergeOperation(_MergeOperation):
1886+
def __init__(self, *args, **kwargs):
1887+
if kwargs.get("validate", None):
1888+
raise NotImplementedError("validate is not supported for semi-join.")
1889+
1890+
super().__init__(*args, **kwargs)
1891+
if self.left_index or self.right_index:
1892+
raise NotImplementedError(
1893+
"left_index or right_index are not supported for semi-join."
1894+
)
1895+
elif self.indicator:
1896+
raise NotImplementedError("indicator is not supported for semi-join.")
1897+
elif self.sort:
1898+
raise NotImplementedError(
1899+
"sort is not supported for semi-join. Sort your DataFrame afterwards."
1900+
)
1901+
1902+
def _maybe_add_join_keys(
1903+
self,
1904+
result: DataFrame,
1905+
left_indexer: npt.NDArray[np.intp] | None,
1906+
right_indexer: npt.NDArray[np.intp] | None,
1907+
) -> None:
1908+
return
1909+
1910+
def _maybe_restore_index_levels(self, result: DataFrame) -> None:
1911+
return
1912+
1913+
def _reindex_and_concat(
1914+
self,
1915+
join_index: Index,
1916+
left_indexer: npt.NDArray[np.intp] | None,
1917+
right_indexer: npt.NDArray[np.intp] | None,
1918+
) -> DataFrame:
1919+
left = self.left[:]
1920+
1921+
if left_indexer is not None and not is_range_indexer(left_indexer, len(left)):
1922+
lmgr = left._mgr.take(left_indexer, axis=1, verify=False)
1923+
left = left._constructor_from_mgr(lmgr, axes=lmgr.axes)
1924+
return left
1925+
1926+
18861927
class _OrderedMerge(_MergeOperation):
18871928
_merge_type = "ordered_merge"
18881929

@@ -2470,7 +2511,7 @@ def _factorize_keys(
24702511
lk = ensure_int64(lk.codes)
24712512
rk = ensure_int64(rk.codes)
24722513

2473-
elif isinstance(lk, ExtensionArray) and lk.dtype == rk.dtype:
2514+
elif how != "leftsemi" and isinstance(lk, ExtensionArray) and lk.dtype == rk.dtype:
24742515
if (isinstance(lk.dtype, ArrowDtype) and is_string_dtype(lk.dtype)) or (
24752516
isinstance(lk.dtype, StringDtype)
24762517
and lk.dtype.storage in ["pyarrow", "pyarrow_numpy"]
@@ -2560,14 +2601,18 @@ def _factorize_keys(
25602601
lk_data, rk_data = lk, rk # type: ignore[assignment]
25612602
lk_mask, rk_mask = None, None
25622603

2563-
hash_join_available = how == "inner" and not sort and lk.dtype.kind in "iufb"
2604+
hash_join_available = how == "inner" and not sort
25642605
if hash_join_available:
25652606
rlab = rizer.factorize(rk_data, mask=rk_mask)
25662607
if rizer.get_count() == len(rlab):
25672608
ridx, lidx = rizer.hash_inner_join(lk_data, lk_mask)
25682609
return lidx, ridx, -1
25692610
else:
25702611
llab = rizer.factorize(lk_data, mask=lk_mask)
2612+
elif how == "leftsemi":
2613+
# populate hashtable for right and then do a hash join
2614+
rizer.factorize(rk_data, mask=rk_mask)
2615+
return rizer.hash_inner_join(lk_data, lk_mask)[1], None, -1
25712616
else:
25722617
llab = rizer.factorize(lk_data, mask=lk_mask)
25732618
rlab = rizer.factorize(rk_data, mask=rk_mask)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pytest
2+
3+
import pandas.util._test_decorators as td
4+
5+
import pandas as pd
6+
import pandas._testing as tm
7+
8+
9+
@pytest.mark.parametrize(
10+
"vals_left, vals_right",
11+
[
12+
([1, 2, 3], [1, 2]),
13+
(["a", "b", "c"], ["a", "b"]),
14+
pytest.param(
15+
pd.Series(["a", "b", "c"], dtype="string[pyarrow]"),
16+
pd.Series(["a", "b"], dtype="string[pyarrow]"),
17+
marks=td.skip_if_no("pyarrow"),
18+
),
19+
],
20+
)
21+
def test_leftsemi(vals_left, vals_right):
22+
left = pd.DataFrame({"a": vals_left, "b": [1, 2, 3]})
23+
right = pd.DataFrame({"a": vals_right, "c": 1})
24+
expected = pd.DataFrame({"a": vals_right, "b": [1, 2]})
25+
result = left.merge(right, how="leftsemi")
26+
tm.assert_frame_equal(result, expected)
27+
28+
right = pd.DataFrame({"d": vals_right, "c": 1})
29+
result = left.merge(right, how="leftsemi", left_on="a", right_on="d")
30+
tm.assert_frame_equal(result, expected)
31+
32+
right = pd.DataFrame({"d": vals_right, "c": 1})
33+
result = left.merge(right, how="leftsemi", left_on=["a", "b"], right_on=["d", "c"])
34+
tm.assert_frame_equal(result, expected.head(1))
35+
36+
37+
def test_leftsemi_invalid():
38+
left = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]})
39+
right = pd.DataFrame({"a": [1, 2], "c": 1})
40+
41+
msg = "left_index or right_index are not supported for semi-join."
42+
with pytest.raises(NotImplementedError, match=msg):
43+
left.merge(right, how="leftsemi", left_index=True, right_on="a")
44+
with pytest.raises(NotImplementedError, match=msg):
45+
left.merge(right, how="leftsemi", right_index=True, left_on="a")
46+
47+
msg = "validate is not supported for semi-join."
48+
with pytest.raises(NotImplementedError, match=msg):
49+
left.merge(right, how="leftsemi", validate="one_to_one")
50+
51+
msg = "indicator is not supported for semi-join."
52+
with pytest.raises(NotImplementedError, match=msg):
53+
left.merge(right, how="leftsemi", indicator=True)
54+
55+
msg = "sort is not supported for semi-join. Sort your DataFrame afterwards."
56+
with pytest.raises(NotImplementedError, match=msg):
57+
left.merge(right, how="leftsemi", sort=True)

0 commit comments

Comments
 (0)