Skip to content

Commit 6e83ae0

Browse files
committed
Implement interface tests
1 parent 9d404e5 commit 6e83ae0

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

pandas/core/arrays/list_.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ class ListDtype(ArrowDtype):
5454
An ExtensionDtype suitable for storing homogeneous lists of data.
5555
"""
5656

57+
_is_immutable = True # TODO(wayd): should we allow mutability?
58+
5759
def __init__(self, value_dtype: pa.DataType) -> None:
5860
super().__init__(pa.large_list(value_dtype))
5961

@@ -211,3 +213,19 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
211213
return np.array([str(x) for x in self], dtype=dtype)
212214

213215
return super().astype(dtype, copy)
216+
217+
def __eq__(self, other):
218+
if isinstance(other, (pa.ListScalar, pa.LargeListScalar)):
219+
from pandas.arrays import BooleanArray
220+
221+
# TODO: pyarrow.compute does not implement broadcasting equality
222+
# for an array of lists to a listscalar
223+
# TODO: pyarrow doesn't compare missing values as missing???
224+
# arr = pa.array([1, 2, None])
225+
# pc.equal(arr, arr[2]) returns all nulls but
226+
# arr[2] == arr[2] returns True
227+
mask = np.array([False] * len(self))
228+
values = np.array([x == other for x in self._pa_array])
229+
return BooleanArray(values, mask)
230+
231+
return super().__eq__(other)

pandas/tests/extension/list/test_list.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pandas.tests.extension.base.dtype import BaseDtypeTests
1818
from pandas.tests.extension.base.groupby import BaseGroupbyTests
1919
from pandas.tests.extension.base.index import BaseIndexTests
20+
from pandas.tests.extension.base.interface import BaseInterfaceTests
2021
from pandas.tests.extension.base.missing import BaseMissingTests
2122
from pandas.tests.extension.base.ops import ( # noqa: F401
2223
BaseArithmeticOpsTests,
@@ -70,7 +71,7 @@ class TestListArray(
7071
# BaseGetitemTests,
7172
BaseGroupbyTests,
7273
BaseIndexTests,
73-
# BaseInterfaceTests,
74+
BaseInterfaceTests,
7475
# BaseParsingTests,
7576
# BaseMethodsTests,
7677
BaseMissingTests,
@@ -112,6 +113,9 @@ def test_groupby_extension_transform(self, data_for_grouping):
112113
def test_groupby_extension_apply(self, data_for_grouping, groupby_apply_op):
113114
pytest.skip(reason="ListArray does not implement dictionary_encode")
114115

116+
def test_array_interface(self, data):
117+
pytest.skip(reason="ListArrayScalar does not compare to numpy object-dtype")
118+
115119

116120
def test_to_csv(data):
117121
# https://github.com/pandas-dev/pandas/issues/28840

0 commit comments

Comments
 (0)