Skip to content

Commit 5859e96

Browse files
committed
Improve test coverage
1 parent 21a69c9 commit 5859e96

File tree

5 files changed

+126
-97
lines changed

5 files changed

+126
-97
lines changed

pandas/core/arrays/list_.py

Lines changed: 78 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,86 @@
11
from __future__ import annotations
22

3-
from typing import (
4-
TYPE_CHECKING,
5-
ClassVar,
6-
)
3+
from typing import TYPE_CHECKING
74

85
import numpy as np
96

10-
from pandas._libs import missing as libmissing
117
from pandas.compat import HAS_PYARROW
128
from pandas.util._decorators import set_module
139

1410
from pandas.core.dtypes.base import (
1511
ExtensionDtype,
1612
register_extension_dtype,
1713
)
18-
from pandas.core.dtypes.common import (
19-
is_object_dtype,
20-
is_string_dtype,
21-
)
14+
from pandas.core.dtypes.dtypes import ArrowDtype
2215

23-
from pandas.core.arrays import ExtensionArray
16+
from pandas.core.arrays.arrow.array import ArrowExtensionArray
2417

2518
if TYPE_CHECKING:
2619
from pandas._typing import (
2720
type_t,
2821
Shape,
2922
)
3023

24+
import re
25+
3126
import pyarrow as pa
3227

3328

29+
def string_to_pyarrow_type(string: str) -> pa.DataType:
30+
# TODO: combine this with to_pyarrow_type in pandas.core.arrays.arrow ?
31+
pater = r"list\[(.*)\]"
32+
33+
if mtch := re.search(pater, string):
34+
value_type = mtch.groups()[0]
35+
match value_type:
36+
# TODO: is there a pyarrow function get a type from the string?
37+
case "string" | "large_string":
38+
return pa.large_list(pa.large_string())
39+
case "int64":
40+
return pa.large_list(pa.int64())
41+
# TODO: need to implement many more here, including nested
42+
43+
raise ValueError(f"Cannot map {string} to a pyarrow list type")
44+
45+
3446
@register_extension_dtype
3547
@set_module("pandas")
36-
class ListDtype(ExtensionDtype):
48+
class ListDtype(ArrowDtype):
3749
"""
3850
An ExtensionDtype suitable for storing homogeneous lists of data.
3951
"""
4052

41-
type = list
42-
name: ClassVar[str] = "list"
53+
def __init__(self, value_dtype: pa.DataType) -> None:
54+
super().__init__(pa.large_list(value_dtype))
55+
56+
@classmethod
57+
def construct_from_string(cls, string: str):
58+
if not isinstance(string, str):
59+
raise TypeError(
60+
f"'construct_from_string' expects a string, got {type(string)}"
61+
)
62+
63+
try:
64+
pa_type = string_to_pyarrow_type(string)
65+
except ValueError as e:
66+
raise TypeError(
67+
f"Cannot construct a '{cls.__name__}' from '{string}'"
68+
) from e
69+
70+
return cls(pa_type)
4371

4472
@property
45-
def na_value(self) -> libmissing.NAType:
46-
return libmissing.NA
73+
def name(self) -> str: # type: ignore[override]
74+
"""
75+
A string identifying the data type.
76+
"""
77+
return f"list[{self.pyarrow_dtype.value_type!s}]"
4778

4879
@property
4980
def kind(self) -> str:
50-
# TODO: our extension interface says this field should be the
81+
# TODO(wayd): our extension interface says this field should be the
5182
# NumPy type character, but no such thing exists for list
52-
# this assumes a PyArrow large list
83+
# This uses the Arrow C Data exchange code instead
5384
return "+L"
5485

5586
@classmethod
@@ -64,22 +95,34 @@ def construct_array_type(cls) -> type_t[ListArray]:
6495
return ListArray
6596

6697

67-
class ListArray(ExtensionArray):
68-
dtype = ListDtype()
98+
class ListArray(ArrowExtensionArray):
6999
__array_priority__ = 1000
70100

71-
def __init__(self, values: pa.Array | pa.ChunkedArray | list | ListArray) -> None:
101+
def __init__(
102+
self, values: pa.Array | pa.ChunkedArray | list | ListArray, value_type=None
103+
) -> None:
72104
if not HAS_PYARROW:
73105
raise NotImplementedError("ListArray requires pyarrow to be installed")
74106

75107
if isinstance(values, type(self)):
76108
self._pa_array = values._pa_array
77-
elif not isinstance(values, pa.ChunkedArray):
78-
# To support NA, we need to create an Array first :-(
79-
arr = pa.array(values, from_pandas=True)
80-
self._pa_array = pa.chunked_array(arr)
81109
else:
82-
self._pa_array = values
110+
if value_type is None:
111+
if isinstance(values, (pa.Array, pa.ChunkedArray)):
112+
value_type = values.type.value_type
113+
else:
114+
value_type = pa.array(values).type.value_type
115+
116+
if not isinstance(values, pa.ChunkedArray):
117+
# To support NA, we need to create an Array first :-(
118+
arr = pa.array(values, type=pa.large_list(value_type), from_pandas=True)
119+
self._pa_array = pa.chunked_array(arr, type=pa.large_list(value_type))
120+
else:
121+
self._pa_array = values
122+
123+
@property
124+
def _dtype(self):
125+
return ListDtype(self._pa_array.type.value_type)
83126

84127
@classmethod
85128
def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False):
@@ -100,10 +143,12 @@ def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False):
100143
scalars[i] = None
101144

102145
values = pa.array(scalars, from_pandas=True)
103-
if values.type == "null":
104-
# TODO(wayd): this is a hack to get the tests to pass, but the overall issue
105-
# is that our extension types don't support parametrization but the pyarrow
106-
values = pa.array(values, type=pa.list_(pa.null()))
146+
147+
if values.type == "null" and dtype is not None:
148+
# TODO: the sequencing here seems wrong; just making the tests pass for now
149+
# but this needs a comprehensive review
150+
pa_type = string_to_pyarrow_type(str(dtype))
151+
values = pa.array(values, type=pa_type)
107152

108153
return cls(values)
109154

@@ -114,21 +159,13 @@ def __getitem__(self, item):
114159
pos = np.array(range(len(item)))
115160
mask = pos[item]
116161
return type(self)(self._pa_array.take(mask))
117-
elif isinstance(item, int): # scalar case
162+
elif isinstance(item, int):
118163
return self._pa_array[item]
164+
elif isinstance(item, list):
165+
return type(self)(self._pa_array.take(item))
119166

120167
return type(self)(self._pa_array[item])
121168

122-
def __len__(self) -> int:
123-
return len(self._pa_array)
124-
125-
def isna(self):
126-
return np.array(self._pa_array.is_null())
127-
128-
def take(self, indexer, allow_fill=False, fill_value=None):
129-
# TODO: what do we need to do with allow_fill and fill_value here?
130-
return type(self)(self._pa_array.take(indexer))
131-
132169
@classmethod
133170
def _empty(cls, shape: Shape, dtype: ExtensionDtype):
134171
"""
@@ -149,32 +186,5 @@ def _empty(cls, shape: Shape, dtype: ExtensionDtype):
149186
length = shape[0]
150187
else:
151188
length = shape
152-
return cls._from_sequence([None] * length, dtype=pa.list_(pa.null()))
153189

154-
def copy(self):
155-
mm = pa.default_cpu_memory_manager()
156-
157-
# TODO(wayd): ChunkedArray does not implement copy_to so this
158-
# ends up creating an Array
159-
copied = self._pa_array.combine_chunks().copy_to(mm.device)
160-
return type(self)(copied)
161-
162-
def astype(self, dtype, copy=True):
163-
if isinstance(dtype, type(self.dtype)) and dtype == self.dtype:
164-
if copy:
165-
return self.copy()
166-
return self
167-
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
168-
# numpy has problems with astype(str) for nested elements
169-
# and pyarrow cannot cast from list[string] to string
170-
return np.array([str(x) for x in self._pa_array], dtype=dtype)
171-
172-
if not copy:
173-
raise TypeError(f"astype from ListArray to {dtype} requires a copy")
174-
175-
return np.array(self._pa_array.to_pylist(), dtype=dtype, copy=copy)
176-
177-
@classmethod
178-
def _concat_same_type(cls, to_concat):
179-
data = [x._pa_array for x in to_concat]
180-
return cls(data)
190+
return cls._from_sequence([None] * length, dtype=dtype)

pandas/core/frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,7 @@ def __init__(
821821
if len(data) > 0:
822822
if is_dataclass(data[0]):
823823
data = dataclasses_to_dicts(data)
824-
if not isinstance(data, np.ndarray) and treat_as_nested(data):
824+
if not isinstance(data, np.ndarray) and treat_as_nested(data, dtype):
825825
# exclude ndarray as we may have cast it a few lines above
826826
if columns is not None:
827827
columns = ensure_index(columns)

pandas/core/generic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import warnings
2424

2525
import numpy as np
26+
import pyarrow as pa
2627

2728
from pandas._config import config
2829

@@ -7036,7 +7037,8 @@ def fillna(
70367037
value = Series(value)
70377038
value = value.reindex(self.index)
70387039
value = value._values
7039-
elif not is_list_like(value):
7040+
elif isinstance(value, pa.ListScalar) or not is_list_like(value):
7041+
# TODO(wayd): maybe is_list_like should return false for ListScalar?
70407042
pass
70417043
else:
70427044
raise TypeError(
@@ -7100,7 +7102,7 @@ def fillna(
71007102
else:
71017103
return result
71027104

7103-
elif not is_list_like(value):
7105+
elif isinstance(value, pa.ListScalar) or not is_list_like(value):
71047106
if axis == 1:
71057107
result = self.T.fillna(value=value, limit=limit).T
71067108
new_data = result._mgr

pandas/core/internals/construction.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
common as com,
4848
)
4949
from pandas.core.arrays import ExtensionArray
50+
from pandas.core.arrays.list_ import ListDtype
5051
from pandas.core.arrays.string_ import StringDtype
5152
from pandas.core.construction import (
5253
array as pd_array,
@@ -453,7 +454,7 @@ def nested_data_to_arrays(
453454
return arrays, columns, index
454455

455456

456-
def treat_as_nested(data) -> bool:
457+
def treat_as_nested(data, dtype) -> bool:
457458
"""
458459
Check if we should use nested_data_to_arrays.
459460
"""
@@ -463,6 +464,7 @@ def treat_as_nested(data) -> bool:
463464
and getattr(data[0], "ndim", 1) == 1
464465
# TODO(wayd): hack so pyarrow list elements don't expand
465466
and not isinstance(data[0], pa.ListScalar)
467+
and not isinstance(dtype, ListDtype)
466468
and not (isinstance(data, ExtensionArray) and data.ndim == 2)
467469
)
468470

pandas/tests/extension/list/test_list.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,19 @@
1+
import pyarrow as pa
12
import pytest
23

34
import pandas as pd
5+
import pandas._testing as tm
46
from pandas.core.arrays.list_ import (
57
ListArray,
68
ListDtype,
79
)
810
from pandas.tests.extension.base.accumulate import BaseAccumulateTests
9-
from pandas.tests.extension.base.casting import BaseCastingTests
1011
from pandas.tests.extension.base.constructors import BaseConstructorsTests
1112
from pandas.tests.extension.base.dim2 import ( # noqa: F401
1213
Dim2CompatTests,
1314
NDArrayBacked2DTests,
1415
)
15-
from pandas.tests.extension.base.dtype import BaseDtypeTests
16-
from pandas.tests.extension.base.getitem import BaseGetitemTests
17-
from pandas.tests.extension.base.groupby import BaseGroupbyTests
1816
from pandas.tests.extension.base.index import BaseIndexTests
19-
from pandas.tests.extension.base.interface import BaseInterfaceTests
20-
from pandas.tests.extension.base.io import BaseParsingTests
21-
from pandas.tests.extension.base.methods import BaseMethodsTests
2217
from pandas.tests.extension.base.missing import BaseMissingTests
2318
from pandas.tests.extension.base.ops import ( # noqa: F401
2419
BaseArithmeticOpsTests,
@@ -28,14 +23,16 @@
2823
)
2924
from pandas.tests.extension.base.printing import BasePrintingTests
3025
from pandas.tests.extension.base.reduce import BaseReduceTests
31-
from pandas.tests.extension.base.reshaping import BaseReshapingTests
32-
from pandas.tests.extension.base.setitem import BaseSetitemTests
3326

27+
# TODO(wayd): This is copied from string tests - is it required here?
28+
# @pytest.fixture(params=[True, False])
29+
# def chunked(request):
30+
# return request.param
3431

3532

3633
@pytest.fixture
3734
def dtype():
38-
return ListDtype()
35+
return ListDtype(pa.large_string())
3936

4037

4138
@pytest.fixture
@@ -46,28 +43,46 @@ def data():
4643
return ListArray(data)
4744

4845

46+
@pytest.fixture
47+
def data_missing(dtype):
48+
"""Length 2 array with [NA, Valid]"""
49+
arr = dtype.construct_array_type()._from_sequence([pd.NA, [1, 2, 3]], dtype=dtype)
50+
return arr
51+
52+
4953
class TestListArray(
5054
BaseAccumulateTests,
51-
#BaseCastingTests,
55+
# BaseCastingTests,
5256
BaseConstructorsTests,
53-
#BaseDtypeTests,
54-
#BaseGetitemTests,
55-
#BaseGroupbyTests,
57+
# BaseDtypeTests,
58+
# BaseGetitemTests,
59+
# BaseGroupbyTests,
5660
BaseIndexTests,
57-
#BaseInterfaceTests,
58-
BaseParsingTests,
59-
#BaseMethodsTests,
60-
#BaseMissingTests,
61-
#BaseArithmeticOpsTests,
62-
#BaseComparisonOpsTests,
63-
#BaseUnaryOpsTests,
64-
#BasePrintingTests,
61+
# BaseInterfaceTests,
62+
# BaseParsingTests,
63+
# BaseMethodsTests,
64+
BaseMissingTests,
65+
# BaseArithmeticOpsTests,
66+
# BaseComparisonOpsTests,
67+
# BaseUnaryOpsTests,
68+
BasePrintingTests,
6569
BaseReduceTests,
66-
#BaseReshapingTests,
67-
#BaseSetitemTests,
70+
# BaseReshapingTests,
71+
# BaseSetitemTests,
6872
Dim2CompatTests,
6973
):
70-
...
74+
# TODO(wayd): The tests here are copied from test_arrow.py
75+
# It appears the TestArrowArray class has different expectations around
76+
# when copies should be made then the base.ExtensionTests
77+
# Assuming intentional, maybe in the long term this should just
78+
# inherit from TestArrowArray
79+
def test_fillna_no_op_returns_copy(self, data):
80+
data = data[~data.isna()]
81+
82+
valid = data[0]
83+
result = data.fillna(valid)
84+
assert result is not data
85+
tm.assert_extension_array_equal(result, data)
7186

7287

7388
def test_to_csv(data):

0 commit comments

Comments
 (0)