Skip to content

Commit 4cef00b

Browse files
committed
Implement Reshaping tests
1 parent cf2fb6f commit 4cef00b

File tree

3 files changed

+103
-2
lines changed

3 files changed

+103
-2
lines changed

pandas/core/arrays/list_.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pandas.core.arrays.arrow.array import ArrowExtensionArray
1818

1919
if TYPE_CHECKING:
20+
from collections.abc import Sequence
2021
from pandas._typing import (
2122
type_t,
2223
ArrayLike,
@@ -47,6 +48,20 @@ def string_to_pyarrow_type(string: str) -> pa.DataType:
4748
raise ValueError(f"Cannot map {string} to a pyarrow list type")
4849

4950

51+
def transpose_homogeneous_list(
52+
arrays: Sequence[ListArray],
53+
) -> list[ListArray]:
54+
# TODO: this is the same as transpose_homogeneous_pyarrow
55+
# but returns the ListArray instead of an ArrowExtensionArray
56+
# should consolidate these
57+
arrays = list(arrays)
58+
nrows, ncols = len(arrays[0]), len(arrays)
59+
indices = np.arange(nrows * ncols).reshape(ncols, nrows).T.reshape(-1)
60+
arr = pa.chunked_array([chunk for arr in arrays for chunk in arr._pa_array.chunks])
61+
arr = arr.take(indices)
62+
return [ListArray(arr.slice(i * ncols, ncols)) for i in range(nrows)]
63+
64+
5065
@register_extension_dtype
5166
@set_module("pandas")
5267
class ListDtype(ArrowDtype):
@@ -80,7 +95,10 @@ def name(self) -> str: # type: ignore[override]
8095
"""
8196
A string identifying the data type.
8297
"""
83-
return f"list[{self.pyarrow_dtype.value_type!s}]"
98+
# TODO: reshaping tests require the name list to match the large_list
99+
# implementation; assumedly there are some astype(str(dtype)) casts
100+
# going on. Should fix so this can just be "list[...]" for end user
101+
return f"large_list[{self.pyarrow_dtype.value_type!s}]"
84102

85103
@property
86104
def kind(self) -> str:
@@ -132,6 +150,10 @@ def __init__(
132150
else:
133151
value_type = pa.array(values).type.value_type
134152

153+
# Internally always use large_string instead of string
154+
if value_type == pa.string():
155+
value_type = pa.large_string()
156+
135157
if not isinstance(values, pa.ChunkedArray):
136158
# To support NA, we need to create an Array first :-(
137159
arr = pa.array(values, type=pa.large_list(value_type), from_pandas=True)

pandas/core/frame.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
PeriodArray,
136136
TimedeltaArray,
137137
)
138+
from pandas.core.arrays.list_ import ListDtype
138139
from pandas.core.arrays.sparse import SparseFrameAccessor
139140
from pandas.core.construction import (
140141
ensure_wrapped_if_datetimelike,
@@ -3800,6 +3801,15 @@ def transpose(
38003801
new_values = transpose_homogeneous_masked_arrays(
38013802
cast(Sequence[BaseMaskedArray], self._iter_column_arrays())
38023803
)
3804+
elif isinstance(first_dtype, ListDtype):
3805+
from pandas.core.arrays.list_ import (
3806+
ListArray,
3807+
transpose_homogeneous_list,
3808+
)
3809+
3810+
new_values = transpose_homogeneous_list(
3811+
cast(Sequence[ListArray], self._iter_column_arrays())
3812+
)
38033813
elif isinstance(first_dtype, ArrowDtype):
38043814
# We have arrow EAs with the same dtype. We can transpose faster.
38053815
from pandas.core.arrays.arrow.array import (

pandas/tests/extension/list/test_list.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import operator
23

34
import pyarrow as pa
@@ -30,6 +31,7 @@
3031
)
3132
from pandas.tests.extension.base.printing import BasePrintingTests
3233
from pandas.tests.extension.base.reduce import BaseReduceTests
34+
from pandas.tests.extension.base.reshaping import BaseReshapingTests
3335

3436
# TODO(wayd): This is copied from string tests - is it required here?
3537
# @pytest.fixture(params=[True, False])
@@ -83,7 +85,7 @@ class TestListArray(
8385
BaseUnaryOpsTests,
8486
BasePrintingTests,
8587
BaseReduceTests,
86-
# BaseReshapingTests,
88+
BaseReshapingTests,
8789
# BaseSetitemTests,
8890
Dim2CompatTests,
8991
):
@@ -159,6 +161,73 @@ def test_compare_array(self, data, comparison_op):
159161
def test_invert(self, data):
160162
pytest.skip("ListArray does not implement invert")
161163

164+
def test_merge_on_extension_array(self, data):
165+
pytest.skip("ListArray cannot be factorized")
166+
167+
def test_merge_on_extension_array_duplicates(self, data):
168+
pytest.skip("ListArray cannot be factorized")
169+
170+
@pytest.mark.parametrize(
171+
"index",
172+
[
173+
# Two levels, uniform.
174+
pd.MultiIndex.from_product(([["A", "B"], ["a", "b"]]), names=["a", "b"]),
175+
# non-uniform
176+
pd.MultiIndex.from_tuples([("A", "a"), ("A", "b"), ("B", "b")]),
177+
# three levels, non-uniform
178+
pd.MultiIndex.from_product([("A", "B"), ("a", "b", "c"), (0, 1, 2)]),
179+
pd.MultiIndex.from_tuples(
180+
[
181+
("A", "a", 1),
182+
("A", "b", 0),
183+
("A", "a", 0),
184+
("B", "a", 0),
185+
("B", "c", 1),
186+
]
187+
),
188+
],
189+
)
190+
@pytest.mark.parametrize("obj", ["series", "frame"])
191+
def test_unstack(self, data, index, obj):
192+
# TODO: the base class test casts everything to object
193+
# If you remove the object casts, these tests pass...
194+
# Check if still needed in base class
195+
data = data[: len(index)]
196+
if obj == "series":
197+
ser = pd.Series(data, index=index)
198+
else:
199+
ser = pd.DataFrame({"A": data, "B": data}, index=index)
200+
201+
n = index.nlevels
202+
levels = list(range(n))
203+
# [0, 1, 2]
204+
# [(0,), (1,), (2,), (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
205+
combinations = itertools.chain.from_iterable(
206+
itertools.permutations(levels, i) for i in range(1, n)
207+
)
208+
209+
for level in combinations:
210+
result = ser.unstack(level=level)
211+
assert all(
212+
isinstance(result[col].array, type(data)) for col in result.columns
213+
)
214+
215+
if obj == "series":
216+
# We should get the same result with to_frame+unstack+droplevel
217+
df = ser.to_frame()
218+
219+
alt = df.unstack(level=level).droplevel(0, axis=1)
220+
tm.assert_frame_equal(result, alt)
221+
222+
# obj_ser = ser.astype(object)
223+
224+
expected = ser.unstack(level=level, fill_value=data.dtype.na_value)
225+
# if obj == "series":
226+
# assert (expected.dtypes == object).all()
227+
228+
# result = result.astype(object)
229+
tm.assert_frame_equal(result, expected)
230+
162231

163232
def test_to_csv(data):
164233
# https://github.com/pandas-dev/pandas/issues/28840

0 commit comments

Comments
 (0)