Commit 634093f

[ENH] Minor fix for conditional_join (#1207)
1 parent 3fe9f9c commit 634093f

2 files changed: +37 additions, -23 deletions

janitor/functions/_numba.py

Lines changed: 5 additions & 2 deletions
@@ -163,7 +163,7 @@ def _numba_pair_le_lt(df: pd.DataFrame, right: pd.DataFrame, pair: list):
     # 6 has no match in pair2 of value_2A/2B, so we discard
     # our final matching indices for the left and right pairs
     #########################################################
-    # left_index right_indes
+    # left_index right_index
     # 0          7
     # 4          5
     # 5          1

@@ -261,6 +261,9 @@ def _realign(indices, regions):
     # this function ensures the regions are properly aligned
     arr1, arr2 = indices
     region1, region2 = regions
+    # arr2 is used as the reference point
+    # because we are certain that at the very least
+    # it has the same items as arr1, but not more
    indexer = pd.Index(arr2).get_indexer(arr1)
     mask = indexer == -1
     if mask.any():
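
The added comment explains why arr2 is the reference passed to pd.Index before get_indexer is called. As a rough sketch with made-up arrays (not the function's real inputs), the alignment step behaves like this:

import numpy as np
import pandas as pd

arr1 = np.array([3, 1, 5])
arr2 = np.array([1, 3, 5])        # same items as arr1, possibly in a different order

# position of each arr1 element within arr2; -1 would mean "not found"
indexer = pd.Index(arr2).get_indexer(arr1)
print(indexer)                    # [1 0 2]

mask = indexer == -1
print(mask.any())                 # False, so no realignment fix-up is needed here
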
@@ -724,7 +727,7 @@ def _get_regions(
     # are present ---> l1 < r1 & l2 > r2
     # For two non equi conditions, the matches are where
     # the regions from group A (l1 < r1)
-    # are also lower than the regions from group B (l2 > r2)
+    # are also lower than the regions from group B (l2 < r2)
     # This implementation is based on the algorithm outlined here:
     # https://www.scitepress.org/papers/2018/68268/68268.pdf
     indices = _search_indices(left_c, right_c, strict, op_code)
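
Read literally, the corrected comment says a candidate pair is kept only when its region from the first condition is lower than its region from the second. A loose, made-up illustration of that wording (region_a and region_b are not the function's actual variables, and the real comparison may differ in strictness):

import numpy as np

region_a = np.array([0, 2, 1, 3])   # regions derived from the l1 < r1 condition
region_b = np.array([1, 1, 2, 3])   # regions derived from the other condition

matches = region_a < region_b       # "group A regions lower than group B regions"
print(matches)                      # [ True False  True False]
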

janitor/functions/conditional_join.py

Lines changed: 32 additions & 21 deletions
@@ -354,9 +354,10 @@ def _conditional_join_type_check(
             f"'{right_column.name}' has {right_column.dtype} type."
         )

-    if (op in less_than_join_types.union(greater_than_join_types)) & (
-        (is_string_dtype(left_column) | is_categorical_dtype(left_column))
-    ):
+    number_or_date = is_numeric_dtype(left_column) or is_datetime64_dtype(
+        left_column
+    )
+    if (op != _JoinOperator.STRICTLY_EQUAL.value) & (not number_or_date):
         raise ValueError(
             "non-equi joins are supported "
             "only for datetime and numeric dtypes. "
@@ -490,12 +491,12 @@ def _less_than_indices(
     if left.min() > right.max():
         return None

-    any_nulls = pd.isna(left)
+    any_nulls = left.isna()
     if any_nulls.all():
         return None
     if any_nulls.any():
         left = left[~any_nulls]
-    any_nulls = pd.isna(right)
+    any_nulls = right.isna()
     if any_nulls.all():
         return None
     if any_nulls.any():
@@ -597,12 +598,12 @@ def _greater_than_indices(
     if left.max() < right.min():
         return None

-    any_nulls = pd.isna(left)
+    any_nulls = left.isna()
     if any_nulls.all():
         return None
     if any_nulls.any():
         left = left[~any_nulls]
-    any_nulls = pd.isna(right)
+    any_nulls = right.isna()
     if any_nulls.all():
         return None
     if any_nulls.any():
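
Both hunks above (and the one below in _range_indices) swap pd.isna(...) for the equivalent Series.isna() method before null rows are dropped. A minimal sketch of that pattern on a throwaway Series:

import numpy as np
import pandas as pd

left = pd.Series([2.0, np.nan, 7.0])

any_nulls = left.isna()           # same boolean mask as pd.isna(left)
if any_nulls.all():
    print("no usable values")     # the join helpers return None in this case
elif any_nulls.any():
    left = left[~any_nulls]       # search only the non-null values

print(left.tolist())              # [2.0, 7.0]
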
@@ -1129,10 +1130,10 @@ def _range_indices(
     # get rid of any nulls
     # this is helpful as we can convert extension arrays to numpy arrays safely
     # and simplify the search logic below
-    any_nulls = pd.isna(df[left_on])
+    any_nulls = df[left_on].isna()
     if any_nulls.any():
         left_c = left_c[~any_nulls]
-    any_nulls = pd.isna(right[right_on])
+    any_nulls = right[right_on].isna()
     if any_nulls.any():
         right_c = right_c[~any_nulls]

@@ -1160,16 +1161,26 @@ def _range_indices(
         right_c = right_c._values
     left_c, right_c = _convert_to_numpy_array(left_c, right_c)
     op = operator_map[op]
-    pos = np.empty(left_c.size, dtype=np.intp)
-
-    # better served in a compiled environment
-    # where we can break early
-    # parallelise the operation, as well as
-    # avoid the restrictive fixed size approach of numpy
-    # which isnt particularly helpful in a for loop
-    for ind in range(left_c.size):
-        out = op(left_c[ind], right_c)
-        pos[ind] = np.argmax(out)
+    pos = np.copy(search_indices)
+    counter = np.arange(left_c.size)
+
+    # better than np.outer memory wise?
+    # using this for loop instead of np.outer
+    # allows us to break early and reduce the
+    # number of cartesian checks
+    # since as we iterate, we reduce the size of left_c
+    # speed wise, np.outer will be faster
+    # alternatively, the user can just use the numba option
+    # for more performance
+    for ind in range(right_c.size):
+        if not counter.size:
+            break
+        keep_rows = op(left_c, right_c[ind])
+        if not keep_rows.any():
+            continue
+        pos[counter[keep_rows]] = ind
+        counter = counter[~keep_rows]
+        left_c = left_c[~keep_rows]

     # no point searching within (a, b)
     # if a == b
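
The replacement loop walks over right_c instead of left_c: in the commit, pos starts from the search_indices computed earlier, and every pass removes the left rows that have already found a match, so the comparison set keeps shrinking. A self-contained toy version of the idea (the data, the operator, and the default fill value are made up):

import operator

import numpy as np

left_c = np.array([2, 5, 7])
right_c = np.array([1, 4, 6, 8])            # assumed sorted for the search
op = operator.lt                            # e.g. the "<" half of a range join

pos = np.full(left_c.size, right_c.size, dtype=np.intp)   # stand-in for np.copy(search_indices)
counter = np.arange(left_c.size)            # positions in left_c still unmatched

for ind in range(right_c.size):
    if not counter.size:                    # every left value already matched
        break
    keep_rows = op(left_c, right_c[ind])
    if not keep_rows.any():
        continue
    pos[counter[keep_rows]] = ind           # first right position that satisfies op
    counter = counter[~keep_rows]           # shrink the search set
    left_c = left_c[~keep_rows]

print(pos)   # [1 2 3]: index in right_c where each original left value first compares True
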
@@ -1261,10 +1272,10 @@ def _create_frame(
     """
     Create final dataframe
     """
-    if df_columns:
+    if df_columns is not None:
         df = _cond_join_select_columns(df_columns, df)

-    if right_columns:
+    if right_columns is not None:
         right = _cond_join_select_columns(right_columns, right)

     if set(df.columns).intersection(right.columns):
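
The switch to "is not None" matters because a column selection can be falsy yet still meaningful (an empty list, a 0 label, an empty string), and a plain truthiness test would silently skip _cond_join_select_columns for those. A quick, made-up comparison of the two checks:

for selection in (None, [], 0, "a"):
    old_check = bool(selection)             # behaviour of "if df_columns:"
    new_check = selection is not None       # behaviour of "if df_columns is not None:"
    print(f"{selection!r}: old={old_check}, new={new_check}")

# None     -> skipped by both checks
# [] and 0 -> skipped by the old check, passed to _cond_join_select_columns by the new one
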
