Skip to content

Commit c55bc0a

Browse files
committed
Implement first-class List type
1 parent 2c7c6d6 commit c55bc0a

File tree

11 files changed

+195
-157
lines changed

11 files changed

+195
-157
lines changed

pandas/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
PeriodDtype,
6262
IntervalDtype,
6363
DatetimeTZDtype,
64+
ListDtype,
6465
StringDtype,
6566
BooleanDtype,
6667
# missing
@@ -261,6 +262,7 @@
261262
"Interval",
262263
"IntervalDtype",
263264
"IntervalIndex",
265+
"ListDtype",
264266
"MultiIndex",
265267
"NaT",
266268
"NamedAgg",

pandas/_testing/asserters.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
TimedeltaArray,
5555
)
5656
from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin
57+
from pandas.core.arrays.list_ import ListDtype
5758
from pandas.core.arrays.string_ import StringDtype
5859
from pandas.core.indexes.api import safe_sort_index
5960

@@ -824,6 +825,11 @@ def assert_extension_array_equal(
824825
[np.isnan(val) for val in right._ndarray[right_na]] # type: ignore[attr-defined]
825826
), "wrong missing value sentinels"
826827

828+
# TODO: not every array type may be convertible to NumPy; should catch here
829+
if isinstance(left.dtype, ListDtype) and isinstance(right.dtype, ListDtype):
830+
assert left._pa_array == right._pa_array
831+
return
832+
827833
left_valid = left[~left_na].to_numpy(dtype=object)
828834
right_valid = right[~right_na].to_numpy(dtype=object)
829835
if check_exact:

pandas/core/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
UInt32Dtype,
4141
UInt64Dtype,
4242
)
43+
from pandas.core.arrays.list_ import ListDtype
4344
from pandas.core.arrays.string_ import StringDtype
4445
from pandas.core.construction import array # noqa: ICN001
4546
from pandas.core.flags import Flags
@@ -103,6 +104,7 @@
103104
"Interval",
104105
"IntervalDtype",
105106
"IntervalIndex",
107+
"ListDtype",
106108
"MultiIndex",
107109
"NaT",
108110
"NamedAgg",

pandas/core/arrays/list_.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
from __future__ import annotations
2+
3+
from typing import (
4+
TYPE_CHECKING,
5+
ClassVar,
6+
)
7+
8+
import numpy as np
9+
10+
from pandas._libs import missing as libmissing
11+
from pandas.compat import HAS_PYARROW
12+
from pandas.util._decorators import set_module
13+
14+
from pandas.core.dtypes.base import (
15+
ExtensionDtype,
16+
register_extension_dtype,
17+
)
18+
from pandas.core.dtypes.common import (
19+
is_object_dtype,
20+
is_string_dtype,
21+
)
22+
23+
from pandas.core.arrays import ExtensionArray
24+
25+
if TYPE_CHECKING:
26+
from pandas._typing import type_t
27+
28+
import pyarrow as pa
29+
30+
31+
@register_extension_dtype
@set_module("pandas")
class ListDtype(ExtensionDtype):
    """
    Extension dtype for columns whose elements are homogeneous lists.

    Arrays of this dtype are instances of :class:`ListArray` (see
    :meth:`construct_array_type`).
    """

    name: ClassVar[str] = "list"
    type = list

    @classmethod
    def construct_array_type(cls) -> type_t[ListArray]:
        """
        Return the array class that holds data of this dtype.

        Returns
        -------
        type
            The :class:`ListArray` class.
        """
        return ListArray

    @property
    def kind(self) -> str:
        """
        Type-kind code reported for this dtype.

        Returns
        -------
        str
            Always ``"+L"``.
        """
        # TODO: the extension interface expects the NumPy type character
        # here, but NumPy has no such character for lists; the value
        # returned assumes a PyArrow large list
        return "+L"

    @property
    def na_value(self) -> libmissing.NAType:
        """
        Scalar used to represent a missing element (``pandas.NA``).
        """
        return libmissing.NA
62+
63+
64+
class ListArray(ExtensionArray):
    """
    Extension array of list values backed by a ``pyarrow.ChunkedArray``.

    The underlying data is kept in ``self._pa_array``; every operation
    below delegates to it.
    """

    dtype = ListDtype()
    __array_priority__ = 1000

    def __init__(self, values: pa.Array | pa.ChunkedArray | list | ListArray) -> None:
        """
        Wrap ``values`` in a ListArray.

        Parameters
        ----------
        values : pa.Array, pa.ChunkedArray, list or ListArray
            Data to store. A ListArray shares its backing data, a
            ChunkedArray is stored as-is, and anything else is converted
            to a single-chunk ChunkedArray.

        Raises
        ------
        NotImplementedError
            If pyarrow is not installed.
        """
        if not HAS_PYARROW:
            raise NotImplementedError("ListArray requires pyarrow to be installed")

        if isinstance(values, type(self)):
            self._pa_array = values._pa_array
        elif isinstance(values, pa.ChunkedArray):
            self._pa_array = values
        else:
            # To support NA, we need to create an Array first :-(
            if isinstance(values, pa.Array):
                arr = values
            else:
                arr = pa.array(values, from_pandas=True)
            self._pa_array = pa.chunked_array([arr])

    @classmethod
    def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False):
        """
        Build a ListArray from a sequence of scalars.

        ``dtype`` and ``copy`` are accepted for interface compatibility
        but are not used.
        """
        if isinstance(scalars, ListArray):
            return cls(scalars)

        values = pa.array(scalars, from_pandas=True)
        if values.type == "null":
            # TODO(wayd): this is a hack to get the tests to pass. When
            # pyarrow infers the plain ``null`` type for the input, re-type
            # the data as list<null> so it is still stored as a list array.
            # The overall issue is that our extension types don't support
            # parametrization.
            values = pa.array(values, type=pa.list_(pa.null()))

        return cls(values)

    def __getitem__(self, item):
        """
        Select by boolean mask, integer position, or any key pyarrow accepts.

        A scalar integer returns the pyarrow scalar at that position; all
        other keys return a new ListArray.
        """
        # PyArrow does not support NumPy's selection with an equal length
        # mask, so let's convert those to integral positions if needed
        if isinstance(item, np.ndarray) and item.dtype == bool:
            return type(self)(self._pa_array.take(np.flatnonzero(item)))

        # scalar case; NumPy integers must be caught here too, otherwise the
        # resulting scalar would be wrapped in a new array below
        if isinstance(item, (int, np.integer)):
            return self._pa_array[item]

        return type(self)(self._pa_array[item])

    def __len__(self) -> int:
        """Return the number of elements in the array."""
        return len(self._pa_array)

    def isna(self):
        """Return a boolean NumPy array that is True where an element is null."""
        return np.array(self._pa_array.is_null())

    def take(self, indexer, allow_fill=False, fill_value=None):
        """
        Take elements at the given positions.

        Parameters
        ----------
        indexer : sequence of int
            Positions to take.
        allow_fill : bool, default False
            If False, negative positions count from the end of the array.
            If True, ``-1`` marks a missing element and any other negative
            position is an error.
        fill_value : object, optional
            Currently ignored: positions marked ``-1`` always become null.

        Returns
        -------
        ListArray

        Raises
        ------
        ValueError
            If ``allow_fill`` is True and ``indexer`` holds a value below -1.
        """
        positions = np.asarray(indexer, dtype=np.intp)

        if allow_fill:
            if (positions < -1).any():
                raise ValueError(
                    "Invalid value in 'indexer'. Must be -1 or a valid "
                    "position when allow_fill=True"
                )
            # A null index makes pyarrow emit a null in the output
            pa_indices = pa.array(positions, mask=positions == -1)
        else:
            # pyarrow does not wrap negative positions, so do it here
            wrapped = np.where(positions < 0, positions + len(self), positions)
            pa_indices = pa.array(wrapped)

        return type(self)(self._pa_array.take(pa_indices))

    def copy(self):
        """Return a new ListArray built by taking every position of this one."""
        return type(self)(self._pa_array.take(pa.array(range(len(self._pa_array)))))

    def astype(self, dtype, copy=True):
        """
        Cast to ``dtype``.

        Parameters
        ----------
        dtype : dtype
            Target dtype.
        copy : bool, default True
            Whether a copy may be made. Only the cast to the array's own
            dtype can honor ``copy=False``.

        Returns
        -------
        ListArray or np.ndarray

        Raises
        ------
        TypeError
            If ``copy`` is False and the target is neither this dtype nor
            a non-object string dtype.
        """
        if isinstance(dtype, type(self.dtype)) and dtype == self.dtype:
            if copy:
                return self.copy()
            return self
        elif is_string_dtype(dtype) and not is_object_dtype(dtype):
            # numpy has problems with astype(str) for nested elements
            # and pyarrow cannot cast from list[string] to string
            return np.array([str(x) for x in self._pa_array], dtype=dtype)

        if not copy:
            raise TypeError(f"astype from ListArray to {dtype} requires a copy")

        return np.array(self._pa_array.to_pylist(), dtype=dtype, copy=copy)

    @classmethod
    def _concat_same_type(cls, to_concat):
        """
        Concatenate several ListArrays into one.

        The chunks of every input are gathered into a single ChunkedArray.
        Handing the constructor a Python list of ChunkedArrays would make
        it treat them as nested data instead of concatenating them.
        """
        chunks = [chunk for arr in to_concat for chunk in arr._pa_array.chunks]
        if chunks:
            combined = pa.chunked_array(chunks)
        else:
            # No physical chunks to infer from: reuse the first input's type
            combined = pa.chunked_array([], type=to_concat[0]._pa_array.type)
        return cls(combined)

pandas/core/internals/blocks.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,10 @@ def convert_dtypes(
576576
@final
577577
@cache_readonly
578578
def dtype(self) -> DtypeObj:
579-
return self.values.dtype
579+
try:
580+
return self.values.dtype
581+
except AttributeError: # PyArrow fallback
582+
return self.values.type
580583

581584
@final
582585
def astype(
@@ -2234,12 +2237,16 @@ def new_block(
22342237
*,
22352238
ndim: int,
22362239
refs: BlockValuesRefs | None = None,
2240+
dtype: DtypeObj | None,
22372241
) -> Block:
22382242
# caller is responsible for ensuring:
22392243
# - values is NOT a NumpyExtensionArray
22402244
# - check_ndim/ensure_block_shape already checked
22412245
# - maybe_coerce_values already called/unnecessary
2242-
klass = get_block_type(values.dtype)
2246+
if dtype:
2247+
klass = get_block_type(dtype)
2248+
else:
2249+
klass = get_block_type(values.dtype)
22432250
return klass(values, ndim=ndim, placement=placement, refs=refs)
22442251

22452252

pandas/core/internals/managers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1976,14 +1976,18 @@ def from_blocks(
19761976

19771977
@classmethod
19781978
def from_array(
1979-
cls, array: ArrayLike, index: Index, refs: BlockValuesRefs | None = None
1979+
cls,
1980+
array: ArrayLike,
1981+
dtype: DtypeObj | None,
1982+
index: Index,
1983+
refs: BlockValuesRefs | None = None,
19801984
) -> SingleBlockManager:
19811985
"""
19821986
Constructor for if we have an array that is not yet a Block.
19831987
"""
19841988
array = maybe_coerce_values(array)
19851989
bp = BlockPlacement(slice(0, len(index)))
1986-
block = new_block(array, placement=bp, ndim=1, refs=refs)
1990+
block = new_block(array, placement=bp, ndim=1, refs=refs, dtype=dtype)
19871991
return cls(block, index)
19881992

19891993
def to_2d_mgr(self, columns: Index) -> BlockManager:

pandas/core/series.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def __init__(
505505
data = data.copy()
506506
else:
507507
data = sanitize_array(data, index, dtype, copy)
508-
data = SingleBlockManager.from_array(data, index, refs=refs)
508+
data = SingleBlockManager.from_array(data, dtype, index, refs=refs)
509509

510510
NDFrame.__init__(self, data)
511511
self.name = name

pandas/io/formats/format.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1103,7 +1103,11 @@ def format_array(
11031103
List[str]
11041104
"""
11051105
fmt_klass: type[_GenericArrayFormatter]
1106-
if lib.is_np_dtype(values.dtype, "M"):
1106+
if hasattr(values, "type") and values.type == "null":
1107+
fmt_klass = _NullFormatter
1108+
if hasattr(values, "type") and str(values.type).startswith("list"):
1109+
fmt_klass = _ListFormatter
1110+
elif lib.is_np_dtype(values.dtype, "M"):
11071111
fmt_klass = _Datetime64Formatter
11081112
values = cast(DatetimeArray, values)
11091113
elif isinstance(values.dtype, DatetimeTZDtype):
@@ -1467,6 +1471,27 @@ def _format_strings(self) -> list[str]:
14671471
return fmt_values
14681472

14691473

1474+
class _NullFormatter(_GenericArrayFormatter):
    """Formatter selected by ``format_array`` for values whose type is "null"."""

    def _format_strings(self) -> list[str]:
        """Return ``str()`` of every element of ``self.values``."""
        return list(map(str, self.values))
1478+
1479+
1480+
class _ListFormatter(_GenericArrayFormatter):
    """Formatter selected by ``format_array`` for values whose type starts with "list"."""

    def _format_strings(self) -> list[str]:
        """
        Return one display string per element of ``self.values``.

        Each element is converted with ``as_py()``. A missing element
        (``None``) is rendered as an empty string; every other value,
        including an empty list, is rendered with ``str()`` so that the
        result really is a list of strings.
        """
        # TODO(wayd): This doesn't seem right - where should missing values
        # be handled
        fmt_values: list[str] = []
        for x in self.values:
            pyval = x.as_py()
            # Test for None explicitly: a truthiness check would also blank
            # out empty lists, making them indistinguishable from missing
            fmt_values.append("" if pyval is None else str(pyval))

        return fmt_values
1493+
1494+
14701495
class _Datetime64Formatter(_GenericArrayFormatter):
14711496
values: DatetimeArray
14721497

pandas/tests/extension/list/__init__.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

0 commit comments

Comments
 (0)