@@ -301,7 +301,7 @@ def construct_array_type(self) -> type_t[BaseStringArray]:
301301 elif self .storage == "pyarrow" and self ._na_value is libmissing .NA :
302302 return ArrowStringArray
303303 elif self .storage == "python" :
304- return StringArrayNumpySemantics
304+ return StringArray
305305 else :
306306 return ArrowStringArrayNumpySemantics
307307
@@ -500,9 +500,14 @@ def _str_map_str_or_object(
500500 result = pa .array (
501501 result , mask = mask , type = pa .large_string (), from_pandas = True
502502 )
503- # error: Too many arguments for "BaseStringArray"
504- return type (self )(result ) # type: ignore[call-arg]
505-
503+ if self .dtype .storage == "python" :
504+ # StringArray
505+ # error: Too many arguments for "BaseStringArray"
506+ return type (self )(result , dtype = self .dtype ) # type: ignore[call-arg]
507+ else :
508+ # ArrowStringArray
509+ # error: Too many arguments for "BaseStringArray"
510+ return type (self )(result ) # type: ignore[call-arg]
506511 else :
507512 # This is when the result type is object. We reach this when
508513 # -> We know the result type is truly object (e.g. .encode returns bytes
@@ -645,36 +650,52 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc]
645650
646651 # undo the NumpyExtensionArray hack
647652 _typ = "extension"
648- _storage = "python"
649- _na_value : libmissing .NAType | float = libmissing .NA
650653
651- def __init__ (self , values , copy : bool = False ) -> None :
654+ def __init__ (self , values , * , dtype : StringDtype , copy : bool = False ) -> None :
652655 values = extract_array (values )
653656
654657 super ().__init__ (values , copy = copy )
655658 if not isinstance (values , type (self )):
656- self ._validate ()
659+ self ._validate (dtype )
657660 NDArrayBacked .__init__ (
658661 self ,
659662 self ._ndarray ,
660- StringDtype ( storage = self . _storage , na_value = self . _na_value ) ,
663+ dtype ,
661664 )
662665
663- def _validate (self ) -> None :
666+ def _validate (self , dtype : StringDtype ) -> None :
664667 """Validate that we only store NA or strings."""
665- if len (self ._ndarray ) and not lib .is_string_array (self ._ndarray , skipna = True ):
666- raise ValueError ("StringArray requires a sequence of strings or pandas.NA" )
667- if self ._ndarray .dtype != "object" :
668- raise ValueError (
669- "StringArray requires a sequence of strings or pandas.NA. Got "
670- f"'{ self ._ndarray .dtype } ' dtype instead."
671- )
672- # Check to see if need to convert Na values to pd.NA
673- if self ._ndarray .ndim > 2 :
674- # Ravel if ndims > 2 b/c no cythonized version available
675- lib .convert_nans_to_NA (self ._ndarray .ravel ("K" ))
668+
669+ if dtype ._na_value is libmissing .NA :
670+ if len (self ._ndarray ) and not lib .is_string_array (
671+ self ._ndarray , skipna = True
672+ ):
673+ raise ValueError (
674+ "StringArray requires a sequence of strings or pandas.NA"
675+ )
676+ if self ._ndarray .dtype != "object" :
677+ raise ValueError (
678+ "StringArray requires a sequence of strings or pandas.NA. Got "
679+ f"'{ self ._ndarray .dtype } ' dtype instead."
680+ )
681+ # Check to see if need to convert Na values to pd.NA
682+ if self ._ndarray .ndim > 2 :
683+ # Ravel if ndims > 2 b/c no cythonized version available
684+ lib .convert_nans_to_NA (self ._ndarray .ravel ("K" ))
685+ else :
686+ lib .convert_nans_to_NA (self ._ndarray )
676687 else :
677- lib .convert_nans_to_NA (self ._ndarray )
688+ # Validate that we only store NaN or strings.
689+ if len (self ._ndarray ) and not lib .is_string_array (
690+ self ._ndarray , skipna = True
691+ ):
692+ raise ValueError ("StringArray requires a sequence of strings or NaN" )
693+ if self ._ndarray .dtype != "object" :
694+ raise ValueError (
695+ "StringArray requires a sequence of strings "
696+ "or NaN. Got '{self._ndarray.dtype}' dtype instead."
697+ )
698+ # TODO validate or force NA/None to NaN
678699
679700 def _validate_scalar (self , value ):
680701 # used by NDArrayBackedExtensionIndex.insert
@@ -736,7 +757,7 @@ def _from_sequence_of_strings(
736757 def _empty (cls , shape , dtype ) -> StringArray :
737758 values = np .empty (shape , dtype = object )
738759 values [:] = libmissing .NA
739- return cls (values ).astype (dtype , copy = False )
760+ return cls (values , dtype = dtype ).astype (dtype , copy = False )
740761
741762 def __arrow_array__ (self , type = None ):
742763 """
@@ -936,7 +957,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra
936957 if self ._hasna :
937958 na_mask = cast ("npt.NDArray[np.bool_]" , isna (ndarray ))
938959 if np .all (na_mask ):
939- return type (self )(ndarray )
960+ return type (self )(ndarray , dtype = self . dtype )
940961 if skipna :
941962 if name == "cumsum" :
942963 ndarray = np .where (na_mask , "" , ndarray )
@@ -970,7 +991,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra
970991 # Argument 2 to "where" has incompatible type "NAType | float"
971992 np_result = np .where (na_mask , self .dtype .na_value , np_result ) # type: ignore[arg-type]
972993
973- result = type (self )(np_result )
994+ result = type (self )(np_result , dtype = self . dtype )
974995 return result
975996
976997 def _wrap_reduction_result (self , axis : AxisInt | None , result ) -> Any :
@@ -1099,29 +1120,3 @@ def _cmp_method(self, other, op):
10991120 return res_arr
11001121
11011122 _arith_method = _cmp_method
1102-
1103-
1104- class StringArrayNumpySemantics (StringArray ):
1105- _storage = "python"
1106- _na_value = np .nan
1107-
1108- def _validate (self ) -> None :
1109- """Validate that we only store NaN or strings."""
1110- if len (self ._ndarray ) and not lib .is_string_array (self ._ndarray , skipna = True ):
1111- raise ValueError (
1112- "StringArrayNumpySemantics requires a sequence of strings or NaN"
1113- )
1114- if self ._ndarray .dtype != "object" :
1115- raise ValueError (
1116- "StringArrayNumpySemantics requires a sequence of strings or NaN. Got "
1117- f"'{ self ._ndarray .dtype } ' dtype instead."
1118- )
1119- # TODO validate or force NA/None to NaN
1120-
1121- @classmethod
1122- def _from_sequence (
1123- cls , scalars , * , dtype : Dtype | None = None , copy : bool = False
1124- ) -> Self :
1125- if dtype is None :
1126- dtype = StringDtype (storage = "python" , na_value = np .nan )
1127- return super ()._from_sequence (scalars , dtype = dtype , copy = copy )
0 commit comments