|
60 | 60 | import builtins |
61 | 61 | import json |
62 | 62 |
|
63 | | -from typing import TYPE_CHECKING, Hashable, Iterable, List, Literal, Optional, Tuple, TypeVar, Union |
| 63 | +from typing import ( |
| 64 | + TYPE_CHECKING, |
| 65 | + Any, |
| 66 | + Hashable, |
| 67 | + Iterable, |
| 68 | + List, |
| 69 | + Literal, |
| 70 | + Optional, |
| 71 | + Tuple, |
| 72 | + TypeVar, |
| 73 | + Union, |
| 74 | +) |
64 | 75 | from typing import cast as type_cast |
65 | 76 |
|
66 | 77 | import numpy as np |
|
72 | 83 |
|
73 | 84 | from arkouda.numpy.dtypes import bool_ as akbool |
74 | 85 | from arkouda.numpy.dtypes import bool_scalars |
75 | | -from arkouda.numpy.dtypes import int64 as akint64 |
76 | 86 | from arkouda.numpy.manipulation_functions import flip as ak_flip |
77 | 87 | from arkouda.numpy.pdarrayclass import RegistrationError, pdarray |
78 | 88 | from arkouda.numpy.pdarraysetops import argsort, in1d |
79 | 89 | from arkouda.numpy.sorting import coargsort |
80 | 90 | from arkouda.numpy.util import convert_if_categorical, generic_concat, get_callback |
81 | | -from arkouda.pandas.groupbyclass import GroupBy, unique |
| 91 | +from arkouda.pandas.groupbyclass import GroupBy, groupable, unique |
82 | 92 |
|
83 | 93 |
|
84 | 94 | __all__ = [ |
@@ -1202,12 +1212,13 @@ def lookup(self, key): |
1202 | 1212 | Returns |
1203 | 1213 | ------- |
1204 | 1214 | pdarray |
1205 | | - A boolean array indicating which elements of `key` are present in the Index. |
| 1215 | + A boolean array of length ``len(self)``, indicating which entries of |
| 1216 | + the Index are present in `key`. |
1206 | 1217 |
|
1207 | 1218 | Raises |
1208 | 1219 | ------ |
1209 | 1220 | TypeError |
1210 | | - If `key` is not a scalar or a pdarray. |
| 1221 | + If `key` cannot be converted to an arkouda array. |
1211 | 1222 |
|
1212 | 1223 | """ |
1213 | 1224 | from arkouda.numpy.pdarrayclass import pdarray |
@@ -2139,39 +2150,77 @@ def concat(self, other): |
2139 | 2150 | idx = [generic_concat([ix1, ix2], ordered=True) for ix1, ix2 in zip(self.index, other.index)] |
2140 | 2151 | return MultiIndex(idx) |
2141 | 2152 |
|
def lookup(self, key: list[Any] | tuple[Any, ...]) -> groupable:
    """
    Perform element-wise lookup on the MultiIndex.

    Parameters
    ----------
    key : list or tuple
        A sequence of values, one for each level of the MultiIndex.

        - If the elements are scalars (e.g., ``(1, "red")``), they are
          treated as a single row key: the result is a boolean mask over
          rows where all levels match the corresponding scalar.
        - If the elements are arkouda arrays (e.g., list of pdarrays /
          Strings), they must align one-to-one with the levels, and the
          lookup is delegated to ``in1d(self.index, key)`` for multi-column
          membership.

    Returns
    -------
    groupable
        A boolean array indicating which rows in the MultiIndex match the key.

    Raises
    ------
    TypeError
        If `key` is not a list or tuple, or if it mixes scalars with
        arkouda arrays.
    ValueError
        If the length of `key` does not match the number of levels.
    """
    # Validate the container first. The message reports the type of `key`
    # itself: iterating a rejected key would either fail (non-iterable
    # scalar) or describe its elements (e.g. the characters of a str).
    if not isinstance(key, (list, tuple)):
        raise TypeError(
            "MultiIndex.lookup expects a list or tuple of keys, one per level. "
            f"Received {type(key).__name__}."
        )

    if len(key) != self.nlevels:
        raise ValueError(
            f"MultiIndex.lookup key length {len(key)} must match number of levels {self.nlevels}"
        )

    # Function-scope imports, as in the original.
    # NOTE(review): presumably kept local to avoid a circular import - confirm.
    from arkouda.numpy.pdarraycreation import array as ak_array
    from arkouda.numpy.strings import Strings

    array_flags = [isinstance(k, (pdarray, Strings)) for k in key]

    # Case 1: per-level arkouda arrays. They are passed through unchanged;
    # dtypes and lengths are not checked here.
    if all(array_flags):
        return in1d(self.index, key)

    # Mixed scalar/array keys are ambiguous, so they are rejected.
    if any(array_flags):
        raise TypeError(
            "MultiIndex.lookup key must be all scalars (row key) or all arkouda arrays "
            "(per-level membership). "
            f"Received mixed types: {[type(k) for k in key]}"
        )

    # Case 2: scalars (e.g., (1, "red")). Each scalar becomes a length-1
    # arkouda array created with the dtype of its level.
    scalar_key_arrays = [ak_array([v], dtype=lvl.dtype) for v, lvl in zip(key, self.levels)]

    return in1d(self.index, scalar_key_arrays)
2175 | 2224 |
|
2176 | 2225 | def to_hdf( |
2177 | 2226 | self, |
|
0 commit comments