diff --git a/asv_bench/benchmarks/strings.py b/asv_bench/benchmarks/strings.py index 467fab857d306..b62b926398c33 100644 --- a/asv_bench/benchmarks/strings.py +++ b/asv_bench/benchmarks/strings.py @@ -8,6 +8,7 @@ DataFrame, Index, Series, + StringDtype, ) from pandas.arrays import StringArray @@ -290,10 +291,10 @@ def setup(self): self.series_arr_nan = np.concatenate([self.series_arr, np.array([NA] * 1000)]) def time_string_array_construction(self): - StringArray(self.series_arr) + StringArray(self.series_arr, dtype=StringDtype()) def time_string_array_with_nan_construction(self): - StringArray(self.series_arr_nan) + StringArray(self.series_arr_nan, dtype=StringDtype()) def peakmem_stringarray_construction(self): - StringArray(self.series_arr) + StringArray(self.series_arr, dtype=StringDtype()) diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index da270da342ee6..8e406877a952a 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -300,7 +300,7 @@ def construct_array_type(self) -> type_t[BaseStringArray]: elif self.storage == "pyarrow" and self._na_value is libmissing.NA: return ArrowStringArray elif self.storage == "python": - return StringArrayNumpySemantics + return StringArray else: return ArrowStringArray @@ -487,8 +487,10 @@ def _str_map_str_or_object( ) # error: "BaseStringArray" has no attribute "_from_pyarrow_array" return self._from_pyarrow_array(result) # type: ignore[attr-defined] - # error: Too many arguments for "BaseStringArray" - return type(self)(result) # type: ignore[call-arg] + else: + # StringArray + # error: Too many arguments for "BaseStringArray" + return type(self)(result, dtype=self.dtype) # type: ignore[call-arg] else: # This is when the result type is object. 
We reach this when @@ -578,6 +580,8 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc] nan-likes(``None``, ``np.nan``) for the ``values`` parameter in addition to strings and :attr:`pandas.NA` + dtype : StringDtype + Dtype for the array. copy : bool, default False Whether to copy the array of data. @@ -632,36 +636,56 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc] # undo the NumpyExtensionArray hack _typ = "extension" - _storage = "python" - _na_value: libmissing.NAType | float = libmissing.NA - def __init__(self, values, copy: bool = False) -> None: + def __init__( + self, values, *, dtype: StringDtype | None = None, copy: bool = False + ) -> None: + if dtype is None: + dtype = StringDtype() values = extract_array(values) super().__init__(values, copy=copy) if not isinstance(values, type(self)): - self._validate() + self._validate(dtype) NDArrayBacked.__init__( self, self._ndarray, - StringDtype(storage=self._storage, na_value=self._na_value), + dtype, ) - def _validate(self) -> None: + def _validate(self, dtype: StringDtype) -> None: """Validate that we only store NA or strings.""" - if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True): - raise ValueError("StringArray requires a sequence of strings or pandas.NA") - if self._ndarray.dtype != "object": - raise ValueError( - "StringArray requires a sequence of strings or pandas.NA. Got " - f"'{self._ndarray.dtype}' dtype instead." 
- ) - # Check to see if need to convert Na values to pd.NA - if self._ndarray.ndim > 2: - # Ravel if ndims > 2 b/c no cythonized version available - lib.convert_nans_to_NA(self._ndarray.ravel("K")) + + if dtype._na_value is libmissing.NA: + if len(self._ndarray) and not lib.is_string_array( + self._ndarray, skipna=True + ): + raise ValueError( + "StringArray requires a sequence of strings or pandas.NA" + ) + if self._ndarray.dtype != "object": + raise ValueError( + "StringArray requires a sequence of strings or pandas.NA. Got " + f"'{self._ndarray.dtype}' dtype instead." + ) + # Check to see if need to convert Na values to pd.NA + if self._ndarray.ndim > 2: + # Ravel if ndims > 2 b/c no cythonized version available + lib.convert_nans_to_NA(self._ndarray.ravel("K")) + else: + lib.convert_nans_to_NA(self._ndarray) else: - lib.convert_nans_to_NA(self._ndarray) + # Validate that we only store NaN or strings. + if len(self._ndarray) and not lib.is_string_array( + self._ndarray, skipna=True + ): + raise ValueError("StringArray requires a sequence of strings or NaN") + if self._ndarray.dtype != "object": + raise ValueError( + "StringArray requires a sequence of strings " + f"or NaN. Got '{self._ndarray.dtype}' dtype instead." 
+ ) + # TODO validate or force NA/None to NaN def _validate_scalar(self, value): # used by NDArrayBackedExtensionIndex.insert @@ -729,8 +753,8 @@ def _cast_pointwise_result(self, values) -> ArrayLike: @classmethod def _empty(cls, shape, dtype) -> StringArray: values = np.empty(shape, dtype=object) - values[:] = libmissing.NA - return cls(values).astype(dtype, copy=False) + values[:] = dtype.na_value + return cls(values, dtype=dtype).astype(dtype, copy=False) def __arrow_array__(self, type=None): """ @@ -930,7 +954,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra if self._hasna: na_mask = cast("npt.NDArray[np.bool_]", isna(ndarray)) if np.all(na_mask): - return type(self)(ndarray) + return type(self)(ndarray, dtype=self.dtype) if skipna: if name == "cumsum": ndarray = np.where(na_mask, "", ndarray) @@ -964,7 +988,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra # Argument 2 to "where" has incompatible type "NAType | float" np_result = np.where(na_mask, self.dtype.na_value, np_result) # type: ignore[arg-type] - result = type(self)(np_result) + result = type(self)(np_result, dtype=self.dtype) return result def _wrap_reduction_result(self, axis: AxisInt | None, result) -> Any: @@ -1043,7 +1067,7 @@ def _cmp_method(self, other, op): and other.dtype.na_value is libmissing.NA ): # NA has priority of NaN semantics - return NotImplemented + return op(self.astype(other.dtype, copy=False), other) if isinstance(other, ArrowExtensionArray): if isinstance(other, BaseStringArray): @@ -1093,29 +1117,3 @@ def _cmp_method(self, other, op): return res_arr _arith_method = _cmp_method - - -class StringArrayNumpySemantics(StringArray): - _storage = "python" - _na_value = np.nan - - def _validate(self) -> None: - """Validate that we only store NaN or strings.""" - if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True): - raise ValueError( - "StringArrayNumpySemantics requires a sequence of 
strings or NaN" - ) - if self._ndarray.dtype != "object": - raise ValueError( - "StringArrayNumpySemantics requires a sequence of strings or NaN. Got " - f"'{self._ndarray.dtype}' dtype instead." - ) - # TODO validate or force NA/None to NaN - - @classmethod - def _from_sequence( - cls, scalars, *, dtype: Dtype | None = None, copy: bool = False - ) -> Self: - if dtype is None: - dtype = StringDtype(storage="python", na_value=np.nan) - return super()._from_sequence(scalars, dtype=dtype, copy=copy) diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index 9dae3ae384255..7139322482216 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -21,7 +21,6 @@ import pandas as pd import pandas._testing as tm -from pandas.core.arrays.string_ import StringArrayNumpySemantics from pandas.core.arrays.string_arrow import ( ArrowStringArray, ) @@ -115,7 +114,7 @@ def test_repr(dtype): arr_name = "ArrowStringArray" expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: str" elif dtype.storage == "python" and dtype.na_value is np.nan: - arr_name = "StringArrayNumpySemantics" + arr_name = "StringArray" expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: str" else: arr_name = "StringArray" @@ -433,44 +432,45 @@ def test_comparison_methods_list(comparison_op, dtype): def test_constructor_raises(cls): if cls is pd.arrays.StringArray: msg = "StringArray requires a sequence of strings or pandas.NA" - elif cls is StringArrayNumpySemantics: - msg = "StringArrayNumpySemantics requires a sequence of strings or NaN" + kwargs = {"dtype": pd.StringDtype()} else: msg = "Unsupported type '' for ArrowExtensionArray" + kwargs = {} with pytest.raises(ValueError, match=msg): - cls(np.array(["a", "b"], dtype="S1")) + cls(np.array(["a", "b"], dtype="S1"), **kwargs) with pytest.raises(ValueError, match=msg): - cls(np.array([])) + cls(np.array([]), **kwargs) - if cls is 
pd.arrays.StringArray or cls is StringArrayNumpySemantics: + if cls is pd.arrays.StringArray: # GH#45057 np.nan and None do NOT raise, as they are considered valid NAs # for string dtype - cls(np.array(["a", np.nan], dtype=object)) - cls(np.array(["a", None], dtype=object)) + cls(np.array(["a", np.nan], dtype=object), **kwargs) + cls(np.array(["a", None], dtype=object), **kwargs) else: with pytest.raises(ValueError, match=msg): - cls(np.array(["a", np.nan], dtype=object)) + cls(np.array(["a", np.nan], dtype=object), **kwargs) with pytest.raises(ValueError, match=msg): - cls(np.array(["a", None], dtype=object)) + cls(np.array(["a", None], dtype=object), **kwargs) with pytest.raises(ValueError, match=msg): - cls(np.array(["a", pd.NaT], dtype=object)) + cls(np.array(["a", pd.NaT], dtype=object), **kwargs) with pytest.raises(ValueError, match=msg): - cls(np.array(["a", np.datetime64("NaT", "ns")], dtype=object)) + cls(np.array(["a", np.datetime64("NaT", "ns")], dtype=object), **kwargs) with pytest.raises(ValueError, match=msg): - cls(np.array(["a", np.timedelta64("NaT", "ns")], dtype=object)) + cls(np.array(["a", np.timedelta64("NaT", "ns")], dtype=object), **kwargs) @pytest.mark.parametrize("na", [np.nan, np.float64("nan"), float("nan"), None, pd.NA]) def test_constructor_nan_like(na): - expected = pd.arrays.StringArray(np.array(["a", pd.NA])) - tm.assert_extension_array_equal( - pd.arrays.StringArray(np.array(["a", na], dtype="object")), expected + expected = pd.arrays.StringArray(np.array(["a", pd.NA]), dtype=pd.StringDtype()) + result = pd.arrays.StringArray( + np.array(["a", na], dtype="object"), dtype=pd.StringDtype() ) + tm.assert_extension_array_equal(result, expected) @pytest.mark.parametrize("copy", [True, False]) @@ -487,10 +487,10 @@ def test_from_sequence_no_mutate(copy, cls, dtype): expected = cls( pa.array(na_arr, type=pa.string(), from_pandas=True), dtype=dtype ) - elif cls is StringArrayNumpySemantics: - expected = cls(nan_arr) + elif dtype.na_value is 
np.nan: + expected = cls(nan_arr, dtype=dtype) else: - expected = cls(na_arr) + expected = cls(na_arr, dtype=dtype) tm.assert_extension_array_equal(result, expected) tm.assert_numpy_array_equal(nan_arr, expected_input) diff --git a/pandas/tests/base/test_conversion.py b/pandas/tests/base/test_conversion.py index 2ef0e49399e21..8890b4509d954 100644 --- a/pandas/tests/base/test_conversion.py +++ b/pandas/tests/base/test_conversion.py @@ -21,9 +21,9 @@ NumpyExtensionArray, PeriodArray, SparseArray, + StringArray, TimedeltaArray, ) -from pandas.core.arrays.string_ import StringArrayNumpySemantics from pandas.core.arrays.string_arrow import ArrowStringArray @@ -222,7 +222,7 @@ def test_iter_box_period(self): ) def test_values_consistent(arr, expected_type, dtype, using_infer_string): if using_infer_string and dtype == "object": - expected_type = ArrowStringArray if HAS_PYARROW else StringArrayNumpySemantics + expected_type = ArrowStringArray if HAS_PYARROW else StringArray l_values = Series(arr)._values r_values = pd.Index(arr)._values assert type(l_values) is expected_type diff --git a/pandas/tests/extension/test_common.py b/pandas/tests/extension/test_common.py index 5eda0f00f54ca..40192cbc83a01 100644 --- a/pandas/tests/extension/test_common.py +++ b/pandas/tests/extension/test_common.py @@ -93,8 +93,13 @@ def __getitem__(self, item): def test_ellipsis_index(): # GH#42430 1D slices over extension types turn into N-dimensional slices # over ExtensionArrays + dtype = pd.StringDtype() df = pd.DataFrame( - {"col1": CapturingStringArray(np.array(["hello", "world"], dtype=object))} + { + "col1": CapturingStringArray( + np.array(["hello", "world"], dtype=object), dtype=dtype + ) + } ) _ = df.iloc[:1] diff --git a/pandas/tests/io/parser/test_upcast.py b/pandas/tests/io/parser/test_upcast.py index bc4c4c2e24e9c..c17b7b6871945 100644 --- a/pandas/tests/io/parser/test_upcast.py +++ b/pandas/tests/io/parser/test_upcast.py @@ -14,7 +14,6 @@ BooleanArray, FloatingArray, 
IntegerArray, - StringArray, ) @@ -95,7 +94,7 @@ def test_maybe_upcast_object(val, string_storage): if string_storage == "python": exp_val = "c" if val == "c" else NA - expected = StringArray(np.array(["a", "b", exp_val], dtype=np.object_)) + expected = pd.array(["a", "b", exp_val], dtype=pd.StringDtype()) else: exp_val = "c" if val == "c" else None expected = ArrowStringArray(pa.array(["a", "b", exp_val])) diff --git a/pandas/tests/io/test_orc.py b/pandas/tests/io/test_orc.py index efb3dffecd856..2c193c968e2b5 100644 --- a/pandas/tests/io/test_orc.py +++ b/pandas/tests/io/test_orc.py @@ -12,7 +12,6 @@ import pandas as pd from pandas import read_orc import pandas._testing as tm -from pandas.core.arrays import StringArray pytest.importorskip("pyarrow.orc") @@ -368,13 +367,9 @@ def test_orc_dtype_backend_numpy_nullable(): expected = pd.DataFrame( { - "string": StringArray(np.array(["a", "b", "c"], dtype=np.object_)), - "string_with_nan": StringArray( - np.array(["a", pd.NA, "c"], dtype=np.object_) - ), - "string_with_none": StringArray( - np.array(["a", pd.NA, "c"], dtype=np.object_) - ), + "string": pd.array(["a", "b", "c"], dtype=pd.StringDtype()), + "string_with_nan": pd.array(["a", pd.NA, "c"], dtype=pd.StringDtype()), + "string_with_none": pd.array(["a", pd.NA, "c"], dtype=pd.StringDtype()), "int": pd.Series([1, 2, 3], dtype="Int64"), "int_with_nan": pd.Series([1, pd.NA, 3], dtype="Int64"), "na_only": pd.Series([pd.NA, pd.NA, pd.NA], dtype="Int64"),