Skip to content

Commit 086ce45

Browse files
committed
Closes #5155: Bug: MultiIndex .lookup() attempts illegal dtype cast for tuple keys
1 parent 538fb39 commit 086ce45

File tree

2 files changed

+86
-21
lines changed

2 files changed

+86
-21
lines changed

arkouda/pandas/index.py

Lines changed: 70 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,18 @@
6060
import builtins
6161
import json
6262

63-
from typing import TYPE_CHECKING, Hashable, Iterable, List, Literal, Optional, Tuple, TypeVar, Union
63+
from typing import (
64+
TYPE_CHECKING,
65+
Any,
66+
Hashable,
67+
Iterable,
68+
List,
69+
Literal,
70+
Optional,
71+
Tuple,
72+
TypeVar,
73+
Union,
74+
)
6475
from typing import cast as type_cast
6576

6677
import numpy as np
@@ -72,13 +83,12 @@
7283

7384
from arkouda.numpy.dtypes import bool_ as akbool
7485
from arkouda.numpy.dtypes import bool_scalars
75-
from arkouda.numpy.dtypes import int64 as akint64
7686
from arkouda.numpy.manipulation_functions import flip as ak_flip
7787
from arkouda.numpy.pdarrayclass import RegistrationError, pdarray
7888
from arkouda.numpy.pdarraysetops import argsort, in1d
7989
from arkouda.numpy.sorting import coargsort
8090
from arkouda.numpy.util import convert_if_categorical, generic_concat, get_callback
81-
from arkouda.pandas.groupbyclass import GroupBy, unique
91+
from arkouda.pandas.groupbyclass import GroupBy, groupable, unique
8292

8393

8494
__all__ = [
@@ -1202,12 +1212,13 @@ def lookup(self, key):
12021212
Returns
12031213
-------
12041214
pdarray
1205-
A boolean array indicating which elements of `key` are present in the Index.
1215+
A boolean array of length ``len(self)``, indicating which entries of
1216+
the Index are present in `key`.
12061217
12071218
Raises
12081219
------
12091220
TypeError
1210-
If `key` is not a scalar or a pdarray.
1221+
If `key` cannot be converted to an arkouda array.
12111222
12121223
"""
12131224
from arkouda.numpy.pdarrayclass import pdarray
@@ -2139,39 +2150,77 @@ def concat(self, other):
21392150
idx = [generic_concat([ix1, ix2], ordered=True) for ix1, ix2 in zip(self.index, other.index)]
21402151
return MultiIndex(idx)
21412152

2142-
def lookup(self, key):
2153+
def lookup(self, key: list[Any] | tuple[Any, ...]) -> groupable:
21432154
"""
21442155
Perform element-wise lookup on the MultiIndex.
21452156
21462157
Parameters
21472158
----------
21482159
key : list or tuple
2149-
A sequence of values, one for each level of the MultiIndex. Values may be scalars
2150-
or pdarrays. If scalars, they are cast to the appropriate Arkouda array type.
2160+
A sequence of values, one for each level of the MultiIndex.
2161+
2162+
- If the elements are scalars (e.g., ``(1, "red")``), they are
2163+
treated as a single row key: the result is a boolean mask over
2164+
rows where all levels match the corresponding scalar.
2165+
- If the elements are arkouda arrays (e.g., list of pdarrays /
2166+
Strings), they must align one-to-one with the levels, and the
2167+
lookup is delegated to ``in1d(self.index, key)`` for multi-column
2168+
membership.
21512169
21522170
Returns
21532171
-------
2154-
pdarray
2172+
groupable
21552173
A boolean array indicating which rows in the MultiIndex match the key.
21562174
21572175
Raises
21582176
------
21592177
TypeError
2160-
If `key` is not a list or tuple, or if its elements cannot be converted to pdarrays.
2161-
2178+
If `key` is not a list or tuple.
2179+
ValueError
2180+
If the length of `key` does not match the number of levels.
21622181
"""
2163-
from arkouda.numpy import cast as akcast
2164-
from arkouda.numpy.pdarrayclass import pdarray
2165-
from arkouda.numpy.pdarraycreation import array
2182+
from arkouda.numpy.pdarraycreation import array as ak_array
2183+
from arkouda.numpy.strings import Strings
2184+
2185+
if not isinstance(key, (list, tuple)):
2186+
types = [type(k).__name__ for k in key]
2187+
raise TypeError(
2188+
f"MultiIndex.lookup expects a list or tuple of keys, one per level. Received {types}."
2189+
)
2190+
2191+
if len(key) != self.nlevels:
2192+
raise ValueError(
2193+
f"MultiIndex.lookup key length {len(key)} must match number of levels {self.nlevels}"
2194+
)
2195+
2196+
# Case 1: user passed per-level arkouda arrays.
2197+
# We assume they are already the correct types and lengths.
2198+
is_array_mode = all(isinstance(k, (pdarray, Strings)) for k in key)
2199+
if is_array_mode:
2200+
return in1d(self.index, key)
2201+
2202+
# Don't allow mixed scalar/array keys.
2203+
is_any_array = any(isinstance(k, (pdarray, Strings)) for k in key)
2204+
if is_any_array and not is_array_mode:
2205+
raise TypeError(
2206+
"MultiIndex.lookup key must be all scalars (row key) or all arkouda arrays "
2207+
"(per-level membership). "
2208+
f"Received mixed types: {[type(k) for k in key]}"
2209+
)
2210+
2211+
# Case 2: user passed scalars (e.g., (1, "red")).
2212+
# Convert each scalar to a length-1 arkouda array, preserving per-level dtypes.
2213+
scalar_key_arrays = []
2214+
for i, v in enumerate(key):
2215+
lvl = self.levels[i]
2216+
2217+
# Determine the dtype for this level
2218+
dt = lvl.dtype
21662219

2167-
if not isinstance(key, list) and not isinstance(key, tuple):
2168-
raise TypeError("MultiIndex lookup failure")
2169-
# if individual vals convert to pdarrays
2170-
if not isinstance(key[0], pdarray):
2171-
dt = self.levels[0].dtype if isinstance(self.levels[0], pdarray) else akint64
2172-
key = [akcast(array([x]), dt) for x in key]
2220+
a = ak_array([v], dtype=dt) # make length-1 array
2221+
scalar_key_arrays.append(a)
21732222

2174-
return in1d(self.index, key)
2223+
return in1d(self.index, scalar_key_arrays)
21752224

21762225
def to_hdf(
21772226
self,

tests/pandas/index_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,3 +846,19 @@ def test_round_trip_conversion_categorical(self, size):
846846
i1 = ak.Index(ak.Categorical(strings1))
847847
i2 = ak.Index(i1.to_pandas())
848848
assert_index_equal(i1, i2)
849+
850+
def test_multiindex_lookup_tuple_mixed_dtypes(self):
851+
# Level 0: int
852+
lvl0 = ak.array([1, 1, 2, 3])
853+
# Level 1: strings
854+
lvl1 = ak.array(["red", "blue", "red", "red"])
855+
856+
midx = ak.MultiIndex([lvl0, lvl1], names=["num", "color"])
857+
858+
# Tuple key mixes int + str and should NOT trigger castStringsTo<int64> on "red"
859+
mask = midx.lookup((1, "red"))
860+
861+
assert mask.dtype == ak.bool_
862+
863+
# Expect exactly the first row to match
864+
assert mask.to_ndarray().tolist() == [True, False, False, False]

0 commit comments

Comments
 (0)