Skip to content

Commit f8d9b5d

Browse files
committed
REF: get rid of StringArrayNumpySemantics
1 parent 3940df8 commit f8d9b5d

File tree

6 files changed

+83
-80
lines changed

pandas/core/arrays/string_.py

Lines changed: 46 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def construct_array_type(self) -> type_t[BaseStringArray]:
301301
elif self.storage == "pyarrow" and self._na_value is libmissing.NA:
302302
return ArrowStringArray
303303
elif self.storage == "python":
304-
return StringArrayNumpySemantics
304+
return StringArray
305305
else:
306306
return ArrowStringArrayNumpySemantics
307307

@@ -500,9 +500,14 @@ def _str_map_str_or_object(
500500
result = pa.array(
501501
result, mask=mask, type=pa.large_string(), from_pandas=True
502502
)
503-
# error: Too many arguments for "BaseStringArray"
504-
return type(self)(result) # type: ignore[call-arg]
505-
503+
if self.dtype.storage == "python":
504+
# StringArray
505+
# error: Too many arguments for "BaseStringArray"
506+
return type(self)(result, dtype=self.dtype) # type: ignore[call-arg]
507+
else:
508+
# ArrowStringArray
509+
# error: Too many arguments for "BaseStringArray"
510+
return type(self)(result) # type: ignore[call-arg]
506511
else:
507512
# This is when the result type is object. We reach this when
508513
# -> We know the result type is truly object (e.g. .encode returns bytes
@@ -645,36 +650,52 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc]
645650

646651
# undo the NumpyExtensionArray hack
647652
_typ = "extension"
648-
_storage = "python"
649-
_na_value: libmissing.NAType | float = libmissing.NA
650653

651-
def __init__(self, values, copy: bool = False) -> None:
654+
def __init__(self, values, *, dtype: StringDtype, copy: bool = False) -> None:
652655
values = extract_array(values)
653656

654657
super().__init__(values, copy=copy)
655658
if not isinstance(values, type(self)):
656-
self._validate()
659+
self._validate(dtype)
657660
NDArrayBacked.__init__(
658661
self,
659662
self._ndarray,
660-
StringDtype(storage=self._storage, na_value=self._na_value),
663+
dtype,
661664
)
662665

663-
def _validate(self) -> None:
666+
def _validate(self, dtype: StringDtype) -> None:
664667
"""Validate that we only store NA or strings."""
665-
if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True):
666-
raise ValueError("StringArray requires a sequence of strings or pandas.NA")
667-
if self._ndarray.dtype != "object":
668-
raise ValueError(
669-
"StringArray requires a sequence of strings or pandas.NA. Got "
670-
f"'{self._ndarray.dtype}' dtype instead."
671-
)
672-
# Check to see if need to convert Na values to pd.NA
673-
if self._ndarray.ndim > 2:
674-
# Ravel if ndims > 2 b/c no cythonized version available
675-
lib.convert_nans_to_NA(self._ndarray.ravel("K"))
668+
669+
if dtype._na_value is libmissing.NA:
670+
if len(self._ndarray) and not lib.is_string_array(
671+
self._ndarray, skipna=True
672+
):
673+
raise ValueError(
674+
"StringArray requires a sequence of strings or pandas.NA"
675+
)
676+
if self._ndarray.dtype != "object":
677+
raise ValueError(
678+
"StringArray requires a sequence of strings or pandas.NA. Got "
679+
f"'{self._ndarray.dtype}' dtype instead."
680+
)
681+
# Check to see if need to convert Na values to pd.NA
682+
if self._ndarray.ndim > 2:
683+
# Ravel if ndims > 2 b/c no cythonized version available
684+
lib.convert_nans_to_NA(self._ndarray.ravel("K"))
685+
else:
686+
lib.convert_nans_to_NA(self._ndarray)
676687
else:
677-
lib.convert_nans_to_NA(self._ndarray)
688+
# Validate that we only store NaN or strings.
689+
if len(self._ndarray) and not lib.is_string_array(
690+
self._ndarray, skipna=True
691+
):
692+
raise ValueError("StringArray requires a sequence of strings or NaN")
693+
if self._ndarray.dtype != "object":
694+
raise ValueError(
695+
"StringArray requires a sequence of strings "
696+
f"or NaN. Got '{self._ndarray.dtype}' dtype instead."
697+
)
698+
# TODO validate or force NA/None to NaN
678699

679700
def _validate_scalar(self, value):
680701
# used by NDArrayBackedExtensionIndex.insert
@@ -736,7 +757,7 @@ def _from_sequence_of_strings(
736757
def _empty(cls, shape, dtype) -> StringArray:
737758
values = np.empty(shape, dtype=object)
738759
values[:] = libmissing.NA
739-
return cls(values).astype(dtype, copy=False)
760+
return cls(values, dtype=dtype).astype(dtype, copy=False)
740761

741762
def __arrow_array__(self, type=None):
742763
"""
@@ -936,7 +957,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra
936957
if self._hasna:
937958
na_mask = cast("npt.NDArray[np.bool_]", isna(ndarray))
938959
if np.all(na_mask):
939-
return type(self)(ndarray)
960+
return type(self)(ndarray, dtype=self.dtype)
940961
if skipna:
941962
if name == "cumsum":
942963
ndarray = np.where(na_mask, "", ndarray)
@@ -970,7 +991,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra
970991
# Argument 2 to "where" has incompatible type "NAType | float"
971992
np_result = np.where(na_mask, self.dtype.na_value, np_result) # type: ignore[arg-type]
972993

973-
result = type(self)(np_result)
994+
result = type(self)(np_result, dtype=self.dtype)
974995
return result
975996

976997
def _wrap_reduction_result(self, axis: AxisInt | None, result) -> Any:
@@ -1099,29 +1120,3 @@ def _cmp_method(self, other, op):
10991120
return res_arr
11001121

11011122
_arith_method = _cmp_method
1102-
1103-
1104-
class StringArrayNumpySemantics(StringArray):
1105-
_storage = "python"
1106-
_na_value = np.nan
1107-
1108-
def _validate(self) -> None:
1109-
"""Validate that we only store NaN or strings."""
1110-
if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True):
1111-
raise ValueError(
1112-
"StringArrayNumpySemantics requires a sequence of strings or NaN"
1113-
)
1114-
if self._ndarray.dtype != "object":
1115-
raise ValueError(
1116-
"StringArrayNumpySemantics requires a sequence of strings or NaN. Got "
1117-
f"'{self._ndarray.dtype}' dtype instead."
1118-
)
1119-
# TODO validate or force NA/None to NaN
1120-
1121-
@classmethod
1122-
def _from_sequence(
1123-
cls, scalars, *, dtype: Dtype | None = None, copy: bool = False
1124-
) -> Self:
1125-
if dtype is None:
1126-
dtype = StringDtype(storage="python", na_value=np.nan)
1127-
return super()._from_sequence(scalars, dtype=dtype, copy=copy)

pandas/tests/arrays/string_/test_string.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
import pandas as pd
2323
import pandas._testing as tm
24-
from pandas.core.arrays.string_ import StringArrayNumpySemantics
2524
from pandas.core.arrays.string_arrow import (
2625
ArrowStringArray,
2726
ArrowStringArrayNumpySemantics,
@@ -116,7 +115,7 @@ def test_repr(dtype):
116115
arr_name = "ArrowStringArrayNumpySemantics"
117116
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: str"
118117
elif dtype.storage == "python" and dtype.na_value is np.nan:
119-
arr_name = "StringArrayNumpySemantics"
118+
arr_name = "StringArray"
120119
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: str"
121120
else:
122121
arr_name = "StringArray"
@@ -434,44 +433,45 @@ def test_comparison_methods_list(comparison_op, dtype):
434433
def test_constructor_raises(cls):
435434
if cls is pd.arrays.StringArray:
436435
msg = "StringArray requires a sequence of strings or pandas.NA"
437-
elif cls is StringArrayNumpySemantics:
438-
msg = "StringArrayNumpySemantics requires a sequence of strings or NaN"
436+
kwargs = {"dtype": pd.StringDtype()}
439437
else:
440438
msg = "Unsupported type '<class 'numpy.ndarray'>' for ArrowExtensionArray"
439+
kwargs = {}
441440

442441
with pytest.raises(ValueError, match=msg):
443-
cls(np.array(["a", "b"], dtype="S1"))
442+
cls(np.array(["a", "b"], dtype="S1"), **kwargs)
444443

445444
with pytest.raises(ValueError, match=msg):
446-
cls(np.array([]))
445+
cls(np.array([]), **kwargs)
447446

448-
if cls is pd.arrays.StringArray or cls is StringArrayNumpySemantics:
447+
if cls is pd.arrays.StringArray:
449448
# GH#45057 np.nan and None do NOT raise, as they are considered valid NAs
450449
# for string dtype
451-
cls(np.array(["a", np.nan], dtype=object))
452-
cls(np.array(["a", None], dtype=object))
450+
cls(np.array(["a", np.nan], dtype=object), **kwargs)
451+
cls(np.array(["a", None], dtype=object), **kwargs)
453452
else:
454453
with pytest.raises(ValueError, match=msg):
455-
cls(np.array(["a", np.nan], dtype=object))
454+
cls(np.array(["a", np.nan], dtype=object), **kwargs)
456455
with pytest.raises(ValueError, match=msg):
457-
cls(np.array(["a", None], dtype=object))
456+
cls(np.array(["a", None], dtype=object), **kwargs)
458457

459458
with pytest.raises(ValueError, match=msg):
460-
cls(np.array(["a", pd.NaT], dtype=object))
459+
cls(np.array(["a", pd.NaT], dtype=object), **kwargs)
461460

462461
with pytest.raises(ValueError, match=msg):
463-
cls(np.array(["a", np.datetime64("NaT", "ns")], dtype=object))
462+
cls(np.array(["a", np.datetime64("NaT", "ns")], dtype=object), **kwargs)
464463

465464
with pytest.raises(ValueError, match=msg):
466-
cls(np.array(["a", np.timedelta64("NaT", "ns")], dtype=object))
465+
cls(np.array(["a", np.timedelta64("NaT", "ns")], dtype=object), **kwargs)
467466

468467

469468
@pytest.mark.parametrize("na", [np.nan, np.float64("nan"), float("nan"), None, pd.NA])
470469
def test_constructor_nan_like(na):
471-
expected = pd.arrays.StringArray(np.array(["a", pd.NA]))
472-
tm.assert_extension_array_equal(
473-
pd.arrays.StringArray(np.array(["a", na], dtype="object")), expected
470+
expected = pd.arrays.StringArray(np.array(["a", pd.NA]), dtype=pd.StringDtype())
471+
result = pd.arrays.StringArray(
472+
np.array(["a", na], dtype="object"), dtype=pd.StringDtype()
474473
)
474+
tm.assert_extension_array_equal(result, expected)
475475

476476

477477
@pytest.mark.parametrize("copy", [True, False])
@@ -486,10 +486,10 @@ def test_from_sequence_no_mutate(copy, cls, dtype):
486486
import pyarrow as pa
487487

488488
expected = cls(pa.array(na_arr, type=pa.string(), from_pandas=True))
489-
elif cls is StringArrayNumpySemantics:
490-
expected = cls(nan_arr)
489+
elif dtype.na_value is np.nan:
490+
expected = cls(nan_arr, dtype=dtype)
491491
else:
492-
expected = cls(na_arr)
492+
expected = cls(na_arr, dtype=dtype)
493493

494494
tm.assert_extension_array_equal(result, expected)
495495
tm.assert_numpy_array_equal(nan_arr, expected_input)

pandas/tests/base/test_conversion.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
NumpyExtensionArray,
2222
PeriodArray,
2323
SparseArray,
24+
StringArray,
2425
TimedeltaArray,
2526
)
26-
from pandas.core.arrays.string_ import StringArrayNumpySemantics
2727
from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics
2828

2929

@@ -222,9 +222,7 @@ def test_iter_box_period(self):
222222
)
223223
def test_values_consistent(arr, expected_type, dtype, using_infer_string):
224224
if using_infer_string and dtype == "object":
225-
expected_type = (
226-
ArrowStringArrayNumpySemantics if HAS_PYARROW else StringArrayNumpySemantics
227-
)
225+
expected_type = ArrowStringArrayNumpySemantics if HAS_PYARROW else StringArray
228226
l_values = Series(arr)._values
229227
r_values = pd.Index(arr)._values
230228
assert type(l_values) is expected_type

pandas/tests/extension/test_common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,13 @@ def __getitem__(self, item):
9393
def test_ellipsis_index():
9494
# GH#42430 1D slices over extension types turn into N-dimensional slices
9595
# over ExtensionArrays
96+
dtype = pd.StringDtype()
9697
df = pd.DataFrame(
97-
{"col1": CapturingStringArray(np.array(["hello", "world"], dtype=object))}
98+
{
99+
"col1": CapturingStringArray(
100+
np.array(["hello", "world"], dtype=object), dtype=dtype
101+
)
102+
}
98103
)
99104
_ = df.iloc[:1]
100105

pandas/tests/io/parser/test_upcast.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,10 @@ def test_maybe_upcast_object(val, string_storage):
9595

9696
if string_storage == "python":
9797
exp_val = "c" if val == "c" else NA
98-
expected = StringArray(np.array(["a", "b", exp_val], dtype=np.object_))
98+
dtype = pd.StringDtype()
99+
expected = StringArray(
100+
np.array(["a", "b", exp_val], dtype=np.object_), dtype=dtype
101+
)
99102
else:
100103
exp_val = "c" if val == "c" else None
101104
expected = ArrowStringArray(pa.array(["a", "b", exp_val]))

pandas/tests/io/test_orc.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,12 +368,14 @@ def test_orc_dtype_backend_numpy_nullable():
368368

369369
expected = pd.DataFrame(
370370
{
371-
"string": StringArray(np.array(["a", "b", "c"], dtype=np.object_)),
371+
"string": StringArray(
372+
np.array(["a", "b", "c"], dtype=np.object_), dtype=pd.StringDtype()
373+
),
372374
"string_with_nan": StringArray(
373-
np.array(["a", pd.NA, "c"], dtype=np.object_)
375+
np.array(["a", pd.NA, "c"], dtype=np.object_), dtype=pd.StringDtype()
374376
),
375377
"string_with_none": StringArray(
376-
np.array(["a", pd.NA, "c"], dtype=np.object_)
378+
np.array(["a", pd.NA, "c"], dtype=np.object_), dtype=pd.StringDtype()
377379
),
378380
"int": pd.Series([1, 2, 3], dtype="Int64"),
379381
"int_with_nan": pd.Series([1, pd.NA, 3], dtype="Int64"),

0 commit comments

Comments
 (0)