|
11 | 11 | ExtensionDtype, |
12 | 12 | register_extension_dtype, |
13 | 13 | ) |
14 | | -from pandas.core.dtypes.common import is_string_dtype |
| 14 | +from pandas.core.dtypes.common import ( |
| 15 | + is_bool_dtype, |
| 16 | + is_integer_dtype, |
| 17 | + is_string_dtype, |
| 18 | +) |
15 | 19 | from pandas.core.dtypes.dtypes import ArrowDtype |
16 | 20 |
|
17 | 21 | from pandas.core.arrays.arrow.array import ArrowExtensionArray |
| 22 | +from pandas.core.arrays.base import ExtensionArray |
18 | 23 |
|
19 | 24 | if TYPE_CHECKING: |
20 | 25 | from collections.abc import Sequence |
@@ -146,6 +151,15 @@ def __init__( |
146 | 151 | else: |
147 | 152 | if value_type is None: |
148 | 153 | if isinstance(values, (pa.Array, pa.ChunkedArray)): |
| 154 | + parent_type = values.type |
| 155 | + if not isinstance(parent_type, (pa.ListType, pa.LargeListType)): |
| 156 | + # Ideally we could cast here, but pyarrow does not appear to implement |
| 157 | + # many list casts |
| 158 | + new_values = [ |
| 159 | + [x.as_py()] if x.is_valid else None for x in values |
| 160 | + ] |
| 161 | + values = pa.array(new_values, type=pa.large_list(parent_type)) |
| 162 | + |
149 | 163 | value_type = values.type.value_type |
150 | 164 | else: |
151 | 165 | value_type = pa.array(values).type.value_type |
@@ -193,19 +207,89 @@ def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False): |
193 | 207 |
|
194 | 208 | return cls(values) |
195 | 209 |
|
| 210 | + @classmethod |
| 211 | + def _box_pa( |
| 212 | + cls, value, pa_type: pa.DataType | None = None |
| 213 | + ) -> pa.Array | pa.ChunkedArray | pa.Scalar: |
| 214 | + """ |
| 215 | + Box value into a pyarrow Array, ChunkedArray or Scalar. |
| 216 | +
|
| 217 | + Parameters |
| 218 | + ---------- |
| 219 | + value : any |
| 220 | + pa_type : pa.DataType | None |
| 221 | +
|
| 222 | + Returns |
| 223 | + ------- |
| 224 | + pa.Array or pa.ChunkedArray or pa.Scalar |
| 225 | + """ |
| 226 | + if ( |
| 227 | + isinstance(value, (pa.ListScalar, pa.LargeListScalar)) |
| 228 | + or isinstance(value, list) |
| 229 | + or value is None |
| 230 | + ): |
| 231 | + return cls._box_pa_scalar(value, pa_type) |
| 232 | + return cls._box_pa_array(value, pa_type) |
| 233 | + |
196 | 234 | def __getitem__(self, item): |
197 | 235 | # PyArrow does not support NumPy's selection with an equal length |
198 | 236 | # mask, so let's convert those to integral positions if needed |
199 | | - if isinstance(item, np.ndarray) and item.dtype == bool: |
200 | | - pos = np.array(range(len(item))) |
201 | | - mask = pos[item] |
202 | | - return type(self)(self._pa_array.take(mask)) |
| 237 | + if isinstance(item, (np.ndarray, ExtensionArray)): |
| 238 | + if is_bool_dtype(item.dtype): |
| 239 | + mask_len = len(item) |
| 240 | + if mask_len != len(self): |
| 241 | + raise IndexError( |
| 242 | + f"Boolean index has wrong length: {mask_len} " |
| 243 | + f"instead of {len(self)}" |
| 244 | + ) |
| 245 | + pos = np.array(range(len(item))) |
| 246 | + |
| 247 | + if isinstance(item, ExtensionArray): |
| 248 | + mask = pos[item.fillna(False)] |
| 249 | + else: |
| 250 | + mask = pos[item] |
| 251 | + return type(self)(self._pa_array.take(mask)) |
| 252 | + elif is_integer_dtype(item.dtype): |
| 253 | + if isinstance(item, ExtensionArray) and item.isna().any(): |
| 254 | + msg = "Cannot index with an integer indexer containing NA values" |
| 255 | + raise ValueError(msg) |
| 256 | + |
| 257 | + indexer = pa.array(item) |
| 258 | + return type(self)(self._pa_array.take(indexer)) |
203 | 259 | elif isinstance(item, int): |
204 | | - return self._pa_array[item] |
| 260 | + value = self._pa_array[item] |
| 261 | + if value.is_valid: |
| 262 | + return value.as_py() |
| 263 | + else: |
| 264 | + return self.dtype.na_value |
205 | 265 | elif isinstance(item, list): |
206 | | - return type(self)(self._pa_array.take(item)) |
| 266 | + # pyarrow does not yet support take with an empty list of indices |
| 267 | + # https://github.com/apache/arrow/issues/39917 |
| 268 | + if item: |
| 269 | + try: |
| 270 | + result = self._pa_array.take(item) |
| 271 | + except pa.lib.ArrowInvalid as e: |
| 272 | + if "Could not convert <NA>" in str(e): |
| 273 | + msg = ( |
| 274 | + "Cannot index with an integer indexer containing NA values" |
| 275 | + ) |
| 276 | + raise ValueError(msg) from e |
| 277 | + raise e |
| 278 | + else: |
| 279 | + result = pa.array([], type=self._pa_array.type) |
| 280 | + |
| 281 | + return type(self)(result) |
| 282 | + |
| 283 | + try: |
| 284 | + result = type(self)(self._pa_array[item]) |
| 285 | + except TypeError as e: |
| 286 | + msg = ( |
| 287 | + "only integers, slices (`:`), ellipsis (`...`), numpy.newaxis " |
| 288 | + "(`None`) and integer or boolean arrays are valid indices" |
| 289 | + ) |
| 290 | + raise IndexError(msg) from e |
207 | 291 |
|
208 | | - return type(self)(self._pa_array[item]) |
| 292 | + return result |
209 | 293 |
|
210 | 294 | def __setitem__(self, key, value) -> None: |
211 | 295 | msg = "ListArray does not support item assignment via setitem" |
@@ -241,7 +325,13 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike: |
241 | 325 | return super().astype(dtype, copy) |
242 | 326 |
|
243 | 327 | def __eq__(self, other): |
244 | | - if isinstance(other, (pa.ListScalar, pa.LargeListScalar)): |
| 328 | + if isinstance(other, list): |
| 329 | + from pandas.arrays import BooleanArray |
| 330 | + |
| 331 | + mask = np.array([False] * len(self)) |
| 332 | + values = np.array([x.as_py() == other for x in self._pa_array]) |
| 333 | + return BooleanArray(values, mask) |
| 334 | + elif isinstance(other, (pa.ListScalar, pa.LargeListScalar)): |
245 | 335 | from pandas.arrays import BooleanArray |
246 | 336 |
|
247 | 337 | # TODO: pyarrow.compute does not implement broadcasting equality |
|
0 commit comments