Skip to content

Commit cc345db

Browse files
committed
Update list accessor tests
1 parent 5a2b113 commit cc345db

File tree

3 files changed

+24
-23
lines changed

3 files changed

+24
-23
lines changed

pandas/core/arrays/arrow/accessors.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,8 @@
1818

1919
from pandas.core.dtypes.common import is_list_like
2020

21+
from pandas.core.arrays.list_ import ListDtype
22+
2123
if not pa_version_under10p1:
2224
import pyarrow as pa
2325
import pyarrow.compute as pc
@@ -106,7 +108,7 @@ def len(self) -> Series:
106108
... [1, 2, 3],
107109
... [3],
108110
... ],
109-
... dtype=pd.ArrowDtype(pa.list_(pa.int64())),
111+
... dtype=pd.ListDtype(pa.int64()),
110112
... )
111113
>>> s.list.len()
112114
0 3
@@ -189,7 +191,7 @@ def __getitem__(self, key: int | slice) -> Series:
189191
sliced = pc.list_slice(self._pa_array, start, stop, step)
190192
return Series(
191193
sliced,
192-
dtype=ArrowDtype(sliced.type),
194+
dtype=ListDtype(sliced.type.value_type),
193195
index=self._data.index,
194196
name=self._data.name,
195197
)

pandas/core/arrays/list_.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -186,11 +186,7 @@ def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False):
186186
# TypeError: object of type 'NoneType' has no len() if you have
187187
# pa.ListScalar(None). Upstream issue in Arrow - see:
188188
# https://github.com/apache/arrow/issues/40319
189-
for i in range(len(scalars)):
190-
if not scalars[i].is_valid:
191-
scalars[i] = None
192-
193-
values = pa.array(scalars, from_pandas=True)
189+
values = pa.array(scalars.to_pylist(), from_pandas=True)
194190

195191
if values.type == "null" and dtype is not None:
196192
pa_type = string_to_pyarrow_type(str(dtype))

pandas/tests/series/accessors/test_list_accessor.py

Lines changed: 19 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
from pandas import (
66
ArrowDtype,
7+
ListDtype,
78
Series,
89
)
910
import pandas._testing as tm
@@ -16,15 +17,16 @@
1617
@pytest.mark.parametrize(
1718
"list_dtype",
1819
(
19-
pa.list_(pa.int64()),
20-
pa.list_(pa.int64(), list_size=3),
21-
pa.large_list(pa.int64()),
20+
ArrowDtype(pa.list_(pa.int64())),
21+
ArrowDtype(pa.list_(pa.int64(), list_size=3)),
22+
ArrowDtype(pa.large_list(pa.int64())),
23+
ListDtype(pa.int64()),
2224
),
2325
)
2426
def test_list_getitem(list_dtype):
2527
ser = Series(
2628
[[1, 2, 3], [4, None, 5], None],
27-
dtype=ArrowDtype(list_dtype),
29+
dtype=list_dtype,
2830
name="a",
2931
)
3032
actual = ser.list[1]
@@ -36,7 +38,7 @@ def test_list_getitem_index():
3638
# GH 58425
3739
ser = Series(
3840
[[1, 2, 3], [4, None, 5], None],
39-
dtype=ArrowDtype(pa.list_(pa.int64())),
41+
dtype=ListDtype(pa.int64()),
4042
index=[1, 3, 7],
4143
name="a",
4244
)
@@ -53,7 +55,7 @@ def test_list_getitem_index():
5355
def test_list_getitem_slice():
5456
ser = Series(
5557
[[1, 2, 3], [4, None, 5], None],
56-
dtype=ArrowDtype(pa.list_(pa.int64())),
58+
dtype=ListDtype(pa.int64()),
5759
index=[1, 3, 7],
5860
name="a",
5961
)
@@ -66,7 +68,7 @@ def test_list_getitem_slice():
6668
actual = ser.list[1:None:None]
6769
expected = Series(
6870
[[2, 3], [None, 5], None],
69-
dtype=ArrowDtype(pa.list_(pa.int64())),
71+
dtype=ListDtype(pa.int64()),
7072
index=[1, 3, 7],
7173
name="a",
7274
)
@@ -76,18 +78,18 @@ def test_list_getitem_slice():
7678
def test_list_len():
7779
ser = Series(
7880
[[1, 2, 3], [4, None], None],
79-
dtype=ArrowDtype(pa.list_(pa.int64())),
81+
dtype=ListDtype(pa.int64()),
8082
name="a",
8183
)
8284
actual = ser.list.len()
83-
expected = Series([3, 2, None], dtype=ArrowDtype(pa.int32()), name="a")
85+
expected = Series([3, 2, None], dtype=ArrowDtype(pa.int64()), name="a")
8486
tm.assert_series_equal(actual, expected)
8587

8688

8789
def test_list_flatten():
8890
ser = Series(
8991
[[1, 2, 3], None, [4, None], [], [7, 8]],
90-
dtype=ArrowDtype(pa.list_(pa.int64())),
92+
dtype=ListDtype(pa.int64()),
9193
name="a",
9294
)
9395
actual = ser.list.flatten()
@@ -103,7 +105,7 @@ def test_list_flatten():
103105
def test_list_getitem_slice_invalid():
104106
ser = Series(
105107
[[1, 2, 3], [4, None, 5], None],
106-
dtype=ArrowDtype(pa.list_(pa.int64())),
108+
dtype=ListDtype(pa.int64()),
107109
)
108110
if pa_version_under11p0:
109111
with pytest.raises(
@@ -133,15 +135,16 @@ def test_list_accessor_non_list_dtype():
133135
@pytest.mark.parametrize(
134136
"list_dtype",
135137
(
136-
pa.list_(pa.int64()),
137-
pa.list_(pa.int64(), list_size=3),
138-
pa.large_list(pa.int64()),
138+
ArrowDtype(pa.list_(pa.int64())),
139+
ArrowDtype(pa.list_(pa.int64(), list_size=3)),
140+
ArrowDtype(pa.large_list(pa.int64())),
141+
ListDtype(pa.int64()),
139142
),
140143
)
141144
def test_list_getitem_invalid_index(list_dtype):
142145
ser = Series(
143146
[[1, 2, 3], [4, None, 5], None],
144-
dtype=ArrowDtype(list_dtype),
147+
dtype=list_dtype,
145148
)
146149
with pytest.raises(pa.lib.ArrowInvalid, match="Index -1 is out of bounds"):
147150
ser.list[-1]
@@ -154,7 +157,7 @@ def test_list_getitem_invalid_index(list_dtype):
154157
def test_list_accessor_not_iterable():
155158
ser = Series(
156159
[[1, 2, 3], [4, None], None],
157-
dtype=ArrowDtype(pa.list_(pa.int64())),
160+
dtype=ListDtype(pa.int64()),
158161
)
159162
with pytest.raises(TypeError, match="'ListAccessor' object is not iterable"):
160163
iter(ser.list)

0 commit comments

Comments
 (0)