diff --git a/arkouda/pandas/index.py b/arkouda/pandas/index.py index e60ae087fea..c0cfe2b457f 100644 --- a/arkouda/pandas/index.py +++ b/arkouda/pandas/index.py @@ -72,13 +72,12 @@ from arkouda.numpy.dtypes import bool_ as akbool from arkouda.numpy.dtypes import bool_scalars -from arkouda.numpy.dtypes import int64 as akint64 from arkouda.numpy.manipulation_functions import flip as ak_flip from arkouda.numpy.pdarrayclass import RegistrationError, pdarray from arkouda.numpy.pdarraysetops import argsort, in1d from arkouda.numpy.sorting import coargsort from arkouda.numpy.util import convert_if_categorical, generic_concat, get_callback -from arkouda.pandas.groupbyclass import GroupBy, unique +from arkouda.pandas.groupbyclass import GroupBy, groupable, unique __all__ = [ @@ -1202,12 +1201,13 @@ def lookup(self, key): Returns ------- pdarray - A boolean array indicating which elements of `key` are present in the Index. + A boolean array of length ``len(self)``, indicating which entries of + the Index are present in `key`. Raises ------ TypeError - If `key` is not a scalar or a pdarray. + If `key` cannot be converted to an arkouda array. """ from arkouda.numpy.pdarrayclass import pdarray @@ -2139,15 +2139,22 @@ def concat(self, other): idx = [generic_concat([ix1, ix2], ordered=True) for ix1, ix2 in zip(self.index, other.index)] return MultiIndex(idx) - def lookup(self, key): + def lookup(self, key: Union[List, Tuple]) -> groupable: """ Perform element-wise lookup on the MultiIndex. Parameters ---------- key : list or tuple - A sequence of values, one for each level of the MultiIndex. Values may be scalars - or pdarrays. If scalars, they are cast to the appropriate Arkouda array type. + A sequence of values, one for each level of the MultiIndex. + + - If the elements are scalars (e.g., ``(1, "red")``), they are + treated as a single row key: the result is a boolean mask over + rows where all levels match the corresponding scalar. + - If the elements are arkouda arrays (e.g., list of pdarrays / + Strings), they must align one-to-one with the levels, and the + lookup is delegated to ``in1d(self.index, key)`` for multi-column + membership. Returns ------- @@ -2157,21 +2164,41 @@ def lookup(self, key): Raises ------ TypeError - If `key` is not a list or tuple, or if its elements cannot be converted to pdarrays. + If `key` is not a list or tuple. + ValueError + If the length of `key` does not match the number of levels. """ - from arkouda.numpy import cast as akcast from arkouda.numpy.pdarrayclass import pdarray from arkouda.numpy.pdarraycreation import array + from arkouda.numpy.strings import Strings + + if not isinstance(key, (list, tuple)): + raise TypeError("MultiIndex.lookup expects a list or tuple of keys, one per level") + + if len(key) != self.nlevels: + raise ValueError( + f"MultiIndex.lookup key length {len(key)} must match number of levels {self.nlevels}" + ) + + # Case 1: user passed per-level arkouda arrays. + # We assume they are already the correct types and lengths. + if isinstance(key[0], (pdarray, Strings)): + return in1d(self.index, key) + + # Case 2: user passed scalars (e.g., (1, "red")). + # Convert each scalar to a length-1 arkouda array, preserving per-level dtypes. + scalar_key_arrays = [] + for i, v in enumerate(key): + lvl = self.levels[i] + + # Determine the dtype for this level + dt = lvl.dtype - if not isinstance(key, list) and not isinstance(key, tuple): - raise TypeError("MultiIndex lookup failure") - # if individual vals convert to pdarrays - if not isinstance(key[0], pdarray): - dt = self.levels[0].dtype if isinstance(self.levels[0], pdarray) else akint64 - key = [akcast(array([x]), dt) for x in key] + a = array([v], dtype=dt) # make length-1 array + scalar_key_arrays.append(a) - return in1d(self.index, key) + return in1d(self.index, scalar_key_arrays) def to_hdf( self, diff --git a/tests/pandas/index_test.py b/tests/pandas/index_test.py index aaa6a65f584..ff84a9f657c 100644 --- a/tests/pandas/index_test.py +++ b/tests/pandas/index_test.py @@ -846,3 +846,17 @@ def test_round_trip_conversion_categorical(self, size): i1 = ak.Index(ak.Categorical(strings1)) i2 = ak.Index(i1.to_pandas()) assert_index_equal(i1, i2) + + def test_multiindex_lookup_tuple_mixed_dtypes(self): + # Level 0: int + lvl0 = ak.array([1, 1, 2, 3]) + # Level 1: strings + lvl1 = ak.array(["red", "blue", "red", "red"]) + + midx = ak.MultiIndex([lvl0, lvl1], names=["num", "color"]) + + # Tuple key mixes int + str and should NOT trigger castStringsTo on "red" + mask = midx.lookup((1, "red")) + + # Expect exactly the first row to match + assert mask.to_ndarray().tolist() == [True, False, False, False]