|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import ( |
| 4 | + TYPE_CHECKING, |
| 5 | + ClassVar, |
| 6 | +) |
| 7 | + |
| 8 | +import numpy as np |
| 9 | + |
| 10 | +from pandas._libs import missing as libmissing |
| 11 | +from pandas.compat import HAS_PYARROW |
| 12 | +from pandas.util._decorators import set_module |
| 13 | + |
| 14 | +from pandas.core.dtypes.base import ( |
| 15 | + ExtensionDtype, |
| 16 | + register_extension_dtype, |
| 17 | +) |
| 18 | +from pandas.core.dtypes.common import ( |
| 19 | + is_object_dtype, |
| 20 | + is_string_dtype, |
| 21 | +) |
| 22 | + |
| 23 | +from pandas.core.arrays import ExtensionArray |
| 24 | + |
| 25 | +if TYPE_CHECKING: |
| 26 | + from pandas._typing import type_t |
| 27 | + |
| 28 | +import pyarrow as pa |
| 29 | + |
| 30 | + |
| 31 | +@register_extension_dtype |
| 32 | +@set_module("pandas") |
| 33 | +class ListDtype(ExtensionDtype): |
| 34 | + """ |
| 35 | + An ExtensionDtype suitable for storing homogeneous lists of data. |
| 36 | + """ |
| 37 | + |
| 38 | + type = list |
| 39 | + name: ClassVar[str] = "list" |
| 40 | + |
| 41 | + @property |
| 42 | + def na_value(self) -> libmissing.NAType: |
| 43 | + return libmissing.NA |
| 44 | + |
| 45 | + @property |
| 46 | + def kind(self) -> str: |
| 47 | + # TODO: our extension interface says this field should be the |
| 48 | + # NumPy type character, but no such thing exists for list |
| 49 | + # this assumes a PyArrow large list |
| 50 | + return "+L" |
| 51 | + |
| 52 | + @classmethod |
| 53 | + def construct_array_type(cls) -> type_t[ListArray]: |
| 54 | + """ |
| 55 | + Return the array type associated with this dtype. |
| 56 | +
|
| 57 | + Returns |
| 58 | + ------- |
| 59 | + type |
| 60 | + """ |
| 61 | + return ListArray |
| 62 | + |
| 63 | + |
class ListArray(ExtensionArray):
    """
    Extension array of list values backed by a ``pyarrow.ChunkedArray``.

    The wrapped data lives in ``self._pa_array``.

    Parameters
    ----------
    values : pyarrow.Array, pyarrow.ChunkedArray, list or ListArray
        Data to wrap. A ``ListArray`` shares its underlying chunked array, a
        ``ChunkedArray`` is stored as-is, a ``pyarrow.Array`` becomes the
        single chunk, and anything else is converted with ``pyarrow.array``.

    Raises
    ------
    NotImplementedError
        If pyarrow is not installed.
    """

    dtype = ListDtype()
    __array_priority__ = 1000

    def __init__(self, values: pa.Array | pa.ChunkedArray | list | ListArray) -> None:
        if not HAS_PYARROW:
            raise NotImplementedError("ListArray requires pyarrow to be installed")

        if isinstance(values, type(self)):
            self._pa_array = values._pa_array
        elif isinstance(values, pa.ChunkedArray):
            self._pa_array = values
        else:
            if isinstance(values, pa.Array):
                # Already Arrow data: no need to re-run Python object inference.
                arr = values
            else:
                # To support NA, we need to create an Array first;
                # from_pandas=True applies pandas' null semantics.
                arr = pa.array(values, from_pandas=True)
            # chunked_array takes a list of arrays; wrap the single chunk.
            self._pa_array = pa.chunked_array([arr])

    @classmethod
    def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False):
        """
        Build a ListArray from a sequence of list-like scalars.

        Parameters
        ----------
        scalars : sequence or ListArray
            Values to convert. A ``ListArray`` is re-wrapped directly.
        dtype : optional
            Accepted for interface compatibility; not used here.
        copy : bool, default False
            Accepted for interface compatibility; not used here.

        Returns
        -------
        ListArray
        """
        if isinstance(scalars, ListArray):
            return cls(scalars)

        values = pa.array(scalars, from_pandas=True)
        if values.type == "null":
            # TODO(wayd): this is a hack to get the tests to pass. Input with
            # no non-missing values is inferred as the pyarrow ``null`` type,
            # which is not a list type, so cast it to a list of nulls.
            # NOTE(review): the original note was cut off mid-sentence; it
            # attributed the problem to extension types not supporting
            # parametrization.
            values = values.cast(pa.list_(pa.null()))

        return cls(values)

    def __getitem__(self, item):
        """
        Select a single element, a slice, or the rows picked by a boolean mask.

        Parameters
        ----------
        item : int, numpy integer, slice or numpy.ndarray of bool
            Position, slice or mask to select.

        Returns
        -------
        pyarrow scalar or ListArray
            A pyarrow scalar for an integer position, otherwise a new
            ``ListArray``.
        """
        # PyArrow does not support NumPy's selection with an equal length
        # mask, so convert the mask to integral positions first.
        if isinstance(item, np.ndarray) and item.dtype == bool:
            positions = np.flatnonzero(item)
            return type(self)(self._pa_array.take(positions))

        # Scalar case: NumPy integer scalars are not instances of ``int``.
        if isinstance(item, (int, np.integer)):
            return self._pa_array[int(item)]

        return type(self)(self._pa_array[item])

    def __len__(self) -> int:
        """Return the number of elements in the array."""
        return len(self._pa_array)

    def isna(self):
        """
        Return a boolean NumPy array flagging the null elements.

        Returns
        -------
        numpy.ndarray
        """
        return np.array(self._pa_array.is_null())

    def take(self, indexer, allow_fill=False, fill_value=None):
        """
        Take elements by position.

        Parameters
        ----------
        indexer : sequence of int
            Positions to take.
        allow_fill : bool, default False
            If False, negative positions count from the end of the array.
            If True, ``-1`` marks a missing value in the result and any other
            negative position is an error.
        fill_value : optional
            Value used for ``-1`` positions when ``allow_fill`` is True. Only
            missing values (``None``, ``NA``, ``NaN``) are supported.

        Returns
        -------
        ListArray

        Raises
        ------
        ValueError
            If ``allow_fill`` is True and ``indexer`` contains a value
            below ``-1``.
        NotImplementedError
            If a non-missing ``fill_value`` would have to be inserted.
        IndexError
            Raised by pyarrow when a position is out of bounds.
        """
        positions = np.asarray(indexer, dtype=np.int64)
        length = len(self._pa_array)

        if not allow_fill:
            # pyarrow rejects negative positions, so resolve them the way
            # NumPy does; anything still negative is reported by pyarrow.
            positions = np.where(positions < 0, positions + length, positions)
            return type(self)(self._pa_array.take(positions))

        if (positions < -1).any():
            raise ValueError(
                "Invalid value in 'indexer'. Must be -1 or a non-negative "
                "integer when allow_fill=True"
            )

        fill_mask = positions == -1
        if not fill_mask.any():
            return type(self)(self._pa_array.take(positions))

        if not libmissing.checknull(fill_value):
            raise NotImplementedError(
                "ListArray.take only supports a missing fill_value"
            )

        # A null index makes pyarrow emit a null at that output position.
        masked_positions = pa.array(positions, mask=fill_mask)
        return type(self)(self._pa_array.take(masked_positions))

    def copy(self):
        """
        Return a new ListArray wrapping the same pyarrow data.

        Returns
        -------
        ListArray
        """
        # pyarrow arrays are immutable, so sharing the chunked array is safe
        # and avoids materialising every position just to copy the data.
        return type(self)(self._pa_array)

    def astype(self, dtype, copy=True):
        """
        Cast to another dtype.

        Parameters
        ----------
        dtype : dtype-like
            Target dtype.
        copy : bool, default True
            Whether a copy may be made. Only casting to this array's own
            dtype can skip the copy.

        Returns
        -------
        ListArray or numpy.ndarray
            A ``ListArray`` when ``dtype`` is this array's dtype, otherwise a
            NumPy array.

        Raises
        ------
        TypeError
            If ``copy`` is False and the conversion goes through Python
            lists.
        """
        if isinstance(dtype, type(self.dtype)) and dtype == self.dtype:
            return self.copy() if copy else self

        if is_string_dtype(dtype) and not is_object_dtype(dtype):
            # numpy has problems with astype(str) for nested elements
            # and pyarrow cannot cast from list[string] to string
            return np.array([str(x) for x in self._pa_array], dtype=dtype)

        if not copy:
            raise TypeError(f"astype from ListArray to {dtype} requires a copy")

        return np.array(self._pa_array.to_pylist(), dtype=dtype, copy=True)

    @classmethod
    def _concat_same_type(cls, to_concat):
        """
        Concatenate several ListArray objects into one.

        Parameters
        ----------
        to_concat : sequence of ListArray
            Arrays to join, in order. Assumed to be non-empty.

        Returns
        -------
        ListArray
        """
        # Flatten every input's chunks into one chunked array; handing the
        # constructor a Python list of ChunkedArrays would send them through
        # pa.array instead of concatenating them.
        chunks = [
            chunk for arr in to_concat for chunk in arr._pa_array.iterchunks()
        ]
        # Pass the type explicitly so inputs without any chunks still work.
        pa_type = to_concat[0]._pa_array.type
        return cls(pa.chunked_array(chunks, type=pa_type))
0 commit comments