Skip to content

Commit 0c0232a

Browse files
committed
Closes #5155: Bug: MultiIndex .lookup() attempts illegal dtype cast for tuple keys
1 parent 695b732 commit 0c0232a

File tree

2 files changed

+45
-14
lines changed

2 files changed

+45
-14
lines changed

arkouda/pandas/index.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@
7272

7373
from arkouda.numpy.dtypes import bool_ as akbool
7474
from arkouda.numpy.dtypes import bool_scalars
75-
from arkouda.numpy.dtypes import int64 as akint64
7675
from arkouda.numpy.manipulation_functions import flip as ak_flip
7776
from arkouda.numpy.pdarrayclass import RegistrationError, create_pdarray, pdarray
7877
from arkouda.numpy.pdarraycreation import array, ones
@@ -1192,12 +1191,13 @@ def lookup(self, key):
11921191
Returns
11931192
-------
11941193
pdarray
1195-
A boolean array indicating which elements of `key` are present in the Index.
1194+
A boolean array of length ``len(self)``, indicating which entries of
1195+
the Index are present in `key`.
11961196
11971197
Raises
11981198
------
11991199
TypeError
1200-
If `key` is not a scalar or a pdarray.
1200+
If `key` cannot be converted to an arkouda array.
12011201
12021202
"""
12031203
if not isinstance(key, pdarray):
@@ -2133,8 +2133,15 @@ def lookup(self, key):
21332133
Parameters
21342134
----------
21352135
key : list or tuple
2136-
A sequence of values, one for each level of the MultiIndex. Values may be scalars
2137-
or pdarrays. If scalars, they are cast to the appropriate Arkouda array type.
2136+
A sequence of values, one for each level of the MultiIndex.
2137+
2138+
- If the elements are scalars (e.g., ``(1, "red")``), they are
2139+
treated as a single row key: the result is a boolean mask over
2140+
rows where all levels match the corresponding scalar.
2141+
- If the elements are arkouda arrays (e.g., list of pdarrays /
2142+
Strings), they must align one-to-one with the levels, and the
2143+
lookup is delegated to ``in1d(self.index, key)`` for multi-column
2144+
membership.
21382145
21392146
Returns
21402147
-------
@@ -2144,19 +2151,29 @@ def lookup(self, key):
21442151
Raises
21452152
------
21462153
TypeError
2147-
If `key` is not a list or tuple, or if its elements cannot be converted to pdarrays.
2154+
If `key` is not a list or tuple.
2155+
ValueError
2156+
If the length of `key` does not match the number of levels.
21482157
21492158
"""
2150-
from arkouda.numpy import cast as akcast
2159+
if not isinstance(key, (list, tuple)):
2160+
raise TypeError("MultiIndex.lookup expects a list or tuple of keys, one per level")
2161+
2162+
if len(key) != self.nlevels:
2163+
raise ValueError(
2164+
f"MultiIndex.lookup key length {len(key)} must match number of levels {self.nlevels}"
2165+
)
2166+
2167+
# Case 1: user passed per-level arkouda arrays.
2168+
# We assume they are already the correct types and lengths.
2169+
if isinstance(key[0], (pdarray, Strings)):
2170+
return in1d(self.index, key)
21512171

2152-
if not isinstance(key, list) and not isinstance(key, tuple):
2153-
raise TypeError("MultiIndex lookup failure")
2154-
# if individual vals convert to pdarrays
2155-
if not isinstance(key[0], pdarray):
2156-
dt = self.levels[0].dtype if isinstance(self.levels[0], pdarray) else akint64
2157-
key = [akcast(array([x]), dt) for x in key]
2172+
# Case 2: user passed scalars (e.g., (1, "red")).
2173+
# Convert each scalar to a length-1 arkouda array, preserving per-level dtypes.
2174+
scalar_key_arrays = [array([v]) for v in key]
21582175

2159-
return in1d(self.index, key)
2176+
return in1d(self.index, scalar_key_arrays)
21602177

21612178
def to_hdf(
21622179
self,

tests/pandas/index_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,3 +843,17 @@ def test_round_trip_conversion_categorical(self, size):
843843
i1 = ak.Index(ak.Categorical(strings1))
844844
i2 = ak.Index(i1.to_pandas())
845845
assert_index_equal(i1, i2)
846+
847+
def test_multiindex_lookup_tuple_mixed_dtypes(self):
848+
# Level 0: int
849+
lvl0 = ak.array([1, 1, 2, 3])
850+
# Level 1: strings
851+
lvl1 = ak.array(["red", "blue", "red", "red"])
852+
853+
midx = ak.MultiIndex([lvl0, lvl1], names=["num", "color"])
854+
855+
# Tuple key mixes int + str and should NOT trigger castStringsTo<int64> on "red"
856+
mask = midx.lookup((1, "red"))
857+
858+
# Expect exactly the first row to match
859+
assert mask.to_ndarray().tolist() == [True, False, False, False]

0 commit comments

Comments
 (0)