Skip to content

Commit cb7410f

Browse files
fix propagating na_value to Array class + fix some tests
1 parent e29ca8d commit cb7410f

File tree

6 files changed

+35
-39
lines changed

6 files changed

+35
-39
lines changed

pandas/core/arrays/string_.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def __init__(self, storage=None, na_value=libmissing.NA) -> None:
138138
if storage == "pyarrow_numpy":
139139
# TODO raise a deprecation warning
140140
storage = "pyarrow"
141+
na_value = np.nan
141142

142143
if storage not in {"python", "pyarrow"}:
143144
raise ValueError(

pandas/core/arrays/string_arrow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr
131131
# base class "ArrowExtensionArray" defined the type as "ArrowDtype")
132132
_dtype: StringDtype # type: ignore[assignment]
133133
_storage = "pyarrow"
134+
_na_value = libmissing.NA
134135

135136
def __init__(self, values) -> None:
136137
_chk_pyarrow_available()
@@ -140,7 +141,7 @@ def __init__(self, values) -> None:
140141
values = pc.cast(values, pa.large_string())
141142

142143
super().__init__(values)
143-
self._dtype = StringDtype(storage=self._storage)
144+
self._dtype = StringDtype(storage=self._storage, na_value=self._na_value)
144145

145146
if not pa.types.is_large_string(self._pa_array.type) and not (
146147
pa.types.is_dictionary(self._pa_array.type)
@@ -598,6 +599,7 @@ def _rank(
598599

599600
class ArrowStringArrayNumpySemantics(ArrowStringArray):
600601
_storage = "pyarrow"
602+
_na_value = np.nan
601603

602604
@classmethod
603605
def _result_converter(cls, values, na=None):

pandas/tests/arrays/string_/test_string.py

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@
2020
)
2121

2222

23-
def na_val(dtype):
24-
if dtype.storage == "pyarrow_numpy":
25-
return np.nan
26-
else:
27-
return pd.NA
28-
29-
3023
@pytest.fixture
3124
def dtype(string_storage):
3225
"""Fixture giving StringDtype from parametrized 'string_storage'"""
@@ -41,22 +34,22 @@ def cls(dtype):
4134

4235
def test_repr(dtype):
4336
df = pd.DataFrame({"A": pd.array(["a", pd.NA, "b"], dtype=dtype)})
44-
if dtype.storage == "pyarrow_numpy":
37+
if dtype.na_value is np.nan:
4538
expected = " A\n0 a\n1 NaN\n2 b"
4639
else:
4740
expected = " A\n0 a\n1 <NA>\n2 b"
4841
assert repr(df) == expected
4942

50-
if dtype.storage == "pyarrow_numpy":
43+
if dtype.na_value is np.nan:
5144
expected = "0 a\n1 NaN\n2 b\nName: A, dtype: string"
5245
else:
5346
expected = "0 a\n1 <NA>\n2 b\nName: A, dtype: string"
5447
assert repr(df.A) == expected
5548

56-
if dtype.storage == "pyarrow":
49+
if dtype.storage == "pyarrow" and dtype.na_value is pd.NA:
5750
arr_name = "ArrowStringArray"
5851
expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
59-
elif dtype.storage == "pyarrow_numpy":
52+
elif dtype.storage == "pyarrow" and dtype.na_value is np.nan:
6053
arr_name = "ArrowStringArrayNumpySemantics"
6154
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: string"
6255
else:
@@ -68,7 +61,7 @@ def test_repr(dtype):
6861
def test_none_to_nan(cls, dtype):
6962
a = cls._from_sequence(["a", None, "b"], dtype=dtype)
7063
assert a[1] is not None
71-
assert a[1] is na_val(a.dtype)
64+
assert a[1] is a.dtype.na_value
7265

7366

7467
def test_setitem_validates(cls, dtype):
@@ -225,7 +218,7 @@ def test_comparison_methods_scalar(comparison_op, dtype):
225218
a = pd.array(["a", None, "c"], dtype=dtype)
226219
other = "a"
227220
result = getattr(a, op_name)(other)
228-
if dtype.storage == "pyarrow_numpy":
221+
if dtype.na_value is np.nan:
229222
expected = np.array([getattr(item, op_name)(other) for item in a])
230223
if comparison_op == operator.ne:
231224
expected[1] = True
@@ -244,7 +237,7 @@ def test_comparison_methods_scalar_pd_na(comparison_op, dtype):
244237
a = pd.array(["a", None, "c"], dtype=dtype)
245238
result = getattr(a, op_name)(pd.NA)
246239

247-
if dtype.storage == "pyarrow_numpy":
240+
if dtype.na_value is np.nan:
248241
if operator.ne == comparison_op:
249242
expected = np.array([True, True, True])
250243
else:
@@ -271,7 +264,7 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype):
271264

272265
result = getattr(a, op_name)(other)
273266

274-
if dtype.storage == "pyarrow_numpy":
267+
if dtype.na_value is np.nan:
275268
expected_data = {
276269
"__eq__": [False, False, False],
277270
"__ne__": [True, True, True],
@@ -293,7 +286,7 @@ def test_comparison_methods_array(comparison_op, dtype):
293286
a = pd.array(["a", None, "c"], dtype=dtype)
294287
other = [None, None, "c"]
295288
result = getattr(a, op_name)(other)
296-
if dtype.storage == "pyarrow_numpy":
289+
if dtype.na_value is np.nan:
297290
if operator.ne == comparison_op:
298291
expected = np.array([True, True, False])
299292
else:
@@ -387,7 +380,7 @@ def test_astype_int(dtype):
387380
tm.assert_numpy_array_equal(result, expected)
388381

389382
arr = pd.array(["1", pd.NA, "3"], dtype=dtype)
390-
if dtype.storage == "pyarrow_numpy":
383+
if dtype.na_value is np.nan:
391384
err = ValueError
392385
msg = "cannot convert float NaN to integer"
393386
else:
@@ -441,7 +434,7 @@ def test_min_max(method, skipna, dtype):
441434
expected = "a" if method == "min" else "c"
442435
assert result == expected
443436
else:
444-
assert result is na_val(arr.dtype)
437+
assert result is arr.dtype.na_value
445438

446439

447440
@pytest.mark.parametrize("method", ["min", "max"])
@@ -522,7 +515,7 @@ def test_arrow_roundtrip(dtype, string_storage2, request, using_infer_string):
522515
expected = df.astype(f"string[{string_storage2}]")
523516
tm.assert_frame_equal(result, expected)
524517
# ensure the missing value is represented by NA and not np.nan or None
525-
assert result.loc[2, "a"] is na_val(result["a"].dtype)
518+
assert result.loc[2, "a"] is result["a"].dtype.na_value
526519

527520

528521
@pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning")
@@ -556,10 +549,10 @@ def test_arrow_load_from_zero_chunks(
556549

557550

558551
def test_value_counts_na(dtype):
559-
if getattr(dtype, "storage", "") == "pyarrow":
560-
exp_dtype = "int64[pyarrow]"
561-
elif getattr(dtype, "storage", "") == "pyarrow_numpy":
552+
if dtype.na_value is np.nan:
562553
exp_dtype = "int64"
554+
elif dtype.storage == "pyarrow":
555+
exp_dtype = "int64[pyarrow]"
563556
else:
564557
exp_dtype = "Int64"
565558
arr = pd.array(["a", "b", "a", pd.NA], dtype=dtype)
@@ -573,10 +566,10 @@ def test_value_counts_na(dtype):
573566

574567

575568
def test_value_counts_with_normalize(dtype):
576-
if getattr(dtype, "storage", "") == "pyarrow":
577-
exp_dtype = "double[pyarrow]"
578-
elif getattr(dtype, "storage", "") == "pyarrow_numpy":
569+
if dtype.na_value is np.nan:
579570
exp_dtype = np.float64
571+
elif dtype.storage == "pyarrow":
572+
exp_dtype = "double[pyarrow]"
580573
else:
581574
exp_dtype = "Float64"
582575
ser = pd.Series(["a", "b", "a", pd.NA], dtype=dtype)
@@ -586,10 +579,10 @@ def test_value_counts_with_normalize(dtype):
586579

587580

588581
def test_value_counts_sort_false(dtype):
589-
if getattr(dtype, "storage", "") == "pyarrow":
590-
exp_dtype = "int64[pyarrow]"
591-
elif getattr(dtype, "storage", "") == "pyarrow_numpy":
582+
if dtype.na_value is np.nan:
592583
exp_dtype = "int64"
584+
elif dtype.storage == "pyarrow":
585+
exp_dtype = "int64[pyarrow]"
593586
else:
594587
exp_dtype = "Int64"
595588
ser = pd.Series(["a", "b", "c", "b"], dtype=dtype)
@@ -621,7 +614,7 @@ def test_astype_from_float_dtype(float_dtype, dtype):
621614
def test_to_numpy_returns_pdna_default(dtype):
622615
arr = pd.array(["a", pd.NA, "b"], dtype=dtype)
623616
result = np.array(arr)
624-
expected = np.array(["a", na_val(dtype), "b"], dtype=object)
617+
expected = np.array(["a", dtype.na_value, "b"], dtype=object)
625618
tm.assert_numpy_array_equal(result, expected)
626619

627620

@@ -661,7 +654,7 @@ def test_setitem_scalar_with_mask_validation(dtype):
661654
mask = np.array([False, True, False])
662655

663656
ser[mask] = None
664-
assert ser.array[1] is na_val(ser.dtype)
657+
assert ser.array[1] is ser.dtype.na_value
665658

666659
# for other non-string we should also raise an error
667660
ser = pd.Series(["a", "b", "c"], dtype=dtype)

pandas/tests/arrays/string_/test_string_arrow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,6 @@ def test_pickle_roundtrip(dtype):
260260
def test_string_dtype_error_message():
261261
# GH#55051
262262
pytest.importorskip("pyarrow")
263-
msg = "Storage must be 'python', 'pyarrow' or 'pyarrow_numpy'."
263+
msg = "Storage must be 'python' or 'pyarrow'."
264264
with pytest.raises(ValueError, match=msg):
265265
StringDtype("bla")

pandas/tests/extension/base/methods.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,14 @@ def test_value_counts_with_normalize(self, data):
6666
expected = pd.Series(0.0, index=result.index, name="proportion")
6767
expected[result > 0] = 1 / len(values)
6868

69-
if getattr(data.dtype, "storage", "") == "pyarrow" or isinstance(
69+
if isinstance(data.dtype, pd.StringDtype) and data.dtype.na_value is np.nan:
70+
# TODO: avoid special-casing
71+
expected = expected.astype("float64")
72+
elif getattr(data.dtype, "storage", "") == "pyarrow" or isinstance(
7073
data.dtype, pd.ArrowDtype
7174
):
7275
# TODO: avoid special-casing
7376
expected = expected.astype("double[pyarrow]")
74-
elif getattr(data.dtype, "storage", "") == "pyarrow_numpy":
75-
# TODO: avoid special-casing
76-
expected = expected.astype("float64")
7777
elif na_value_for_dtype(data.dtype) is pd.NA:
7878
# TODO(GH#44692): avoid special-casing
7979
expected = expected.astype("Float64")

pandas/tests/extension/test_string.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,18 +190,18 @@ def _get_expected_exception(
190190
def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
191191
return (
192192
op_name in ["min", "max"]
193-
or ser.dtype.storage == "pyarrow_numpy" # type: ignore[union-attr]
193+
or ser.dtype.na_value is np.nan # type: ignore[union-attr]
194194
and op_name in ("any", "all")
195195
)
196196

197197
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
198198
dtype = cast(StringDtype, tm.get_dtype(obj))
199199
if op_name in ["__add__", "__radd__"]:
200200
cast_to = dtype
201+
elif dtype.na_value is np.nan:
202+
cast_to = np.bool_ # type: ignore[assignment]
201203
elif dtype.storage == "pyarrow":
202204
cast_to = "boolean[pyarrow]" # type: ignore[assignment]
203-
elif dtype.storage == "pyarrow_numpy":
204-
cast_to = np.bool_ # type: ignore[assignment]
205205
else:
206206
cast_to = "boolean" # type: ignore[assignment]
207207
return pointwise_result.astype(cast_to)

0 commit comments

Comments
 (0)