@@ -319,6 +319,8 @@ class BaseStringArray(ExtensionArray):
319319 Mixin class for StringArray, ArrowStringArray.
320320 """
321321
322+ dtype : StringDtype
323+
322324 @doc (ExtensionArray .tolist )
323325 def tolist (self ) -> list :
324326 if self .ndim > 1 :
@@ -332,6 +334,37 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
332334 raise ValueError
333335 return cls ._from_sequence (scalars , dtype = dtype )
334336
337+ def _str_map_str_or_object (
338+ self ,
339+ dtype ,
340+ na_value ,
341+ arr : np .ndarray ,
342+ f ,
343+ mask : npt .NDArray [np .bool_ ],
344+ convert : bool ,
345+ ):
346+ # _str_map helper for case where dtype is either string dtype or object
347+ if is_string_dtype (dtype ) and not is_object_dtype (dtype ):
348+ # i.e. StringDtype
349+ result = lib .map_infer_mask (
350+ arr , f , mask .view ("uint8" ), convert = False , na_value = na_value
351+ )
352+ if self .dtype .storage == "pyarrow" :
353+ import pyarrow as pa
354+
355+ result = pa .array (
356+ result , mask = mask , type = pa .large_string (), from_pandas = True
357+ )
358+ # error: Too many arguments for "BaseStringArray"
359+ return type (self )(result ) # type: ignore[call-arg]
360+
361+ else :
362+ # This is when the result type is object. We reach this when
363+ # -> We know the result type is truly object (e.g. .encode returns bytes
364+ # or .findall returns a list).
365+ # -> We don't know the result type. E.g. `.get` can return anything.
366+ return lib .map_infer_mask (arr , f , mask .view ("uint8" ))
367+
335368
336369# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
337370# incompatible with definition in base class "ExtensionArray"
@@ -697,9 +730,53 @@ def _cmp_method(self, other, op):
697730 # base class "NumpyExtensionArray" defined the type as "float")
698731 _str_na_value = libmissing .NA # type: ignore[assignment]
699732
733+ def _str_map_nan_semantics (
734+ self , f , na_value = None , dtype : Dtype | None = None , convert : bool = True
735+ ):
736+ if dtype is None :
737+ dtype = self .dtype
738+ if na_value is None :
739+ na_value = self .dtype .na_value
740+
741+ mask = isna (self )
742+ arr = np .asarray (self )
743+ convert = convert and not np .all (mask )
744+
745+ if is_integer_dtype (dtype ) or is_bool_dtype (dtype ):
746+ na_value_is_na = isna (na_value )
747+ if na_value_is_na :
748+ if is_integer_dtype (dtype ):
749+ na_value = 0
750+ else :
751+ na_value = True
752+
753+ result = lib .map_infer_mask (
754+ arr ,
755+ f ,
756+ mask .view ("uint8" ),
757+ convert = False ,
758+ na_value = na_value ,
759+ dtype = np .dtype (cast (type , dtype )),
760+ )
761+ if na_value_is_na and mask .any ():
762+ if is_integer_dtype (dtype ):
763+ result = result .astype ("float64" )
764+ else :
765+ result = result .astype ("object" )
766+ result [mask ] = np .nan
767+ return result
768+
769+ else :
770+ return self ._str_map_str_or_object (dtype , na_value , arr , f , mask , convert )
771+
700772 def _str_map (
701773 self , f , na_value = None , dtype : Dtype | None = None , convert : bool = True
702774 ):
775+ if self .dtype .na_value is np .nan :
776+ return self ._str_map_nan_semantics (
777+ f , na_value = na_value , dtype = dtype , convert = convert
778+ )
779+
703780 from pandas .arrays import BooleanArray
704781
705782 if dtype is None :
@@ -739,18 +816,8 @@ def _str_map(
739816
740817 return constructor (result , mask )
741818
742- elif is_string_dtype (dtype ) and not is_object_dtype (dtype ):
743- # i.e. StringDtype
744- result = lib .map_infer_mask (
745- arr , f , mask .view ("uint8" ), convert = False , na_value = na_value
746- )
747- return StringArray (result )
748819 else :
749- # This is when the result type is object. We reach this when
750- # -> We know the result type is truly object (e.g. .encode returns bytes
751- # or .findall returns a list).
752- # -> We don't know the result type. E.g. `.get` can return anything.
753- return lib .map_infer_mask (arr , f , mask .view ("uint8" ))
820+ return self ._str_map_str_or_object (dtype , na_value , arr , f , mask , convert )
754821
755822
756823class StringArrayNumpySemantics (StringArray ):
@@ -817,52 +884,3 @@ def value_counts(self, dropna: bool = True) -> Series:
817884 # ------------------------------------------------------------------------
818885 # String methods interface
819886 _str_na_value = np .nan
820-
821- def _str_map (
822- self , f , na_value = None , dtype : Dtype | None = None , convert : bool = True
823- ):
824- if dtype is None :
825- dtype = self .dtype
826- if na_value is None :
827- na_value = self .dtype .na_value
828-
829- mask = isna (self )
830- arr = np .asarray (self )
831- convert = convert and not np .all (mask )
832-
833- if is_integer_dtype (dtype ) or is_bool_dtype (dtype ):
834- na_value_is_na = isna (na_value )
835- if na_value_is_na :
836- if is_integer_dtype (dtype ):
837- na_value = 0
838- else :
839- na_value = True
840-
841- result = lib .map_infer_mask (
842- arr ,
843- f ,
844- mask .view ("uint8" ),
845- convert = False ,
846- na_value = na_value ,
847- dtype = np .dtype (cast (type , dtype )),
848- )
849- if na_value_is_na and mask .any ():
850- if is_integer_dtype (dtype ):
851- result = result .astype ("float64" )
852- else :
853- result = result .astype ("object" )
854- result [mask ] = np .nan
855- return result
856-
857- elif is_string_dtype (dtype ) and not is_object_dtype (dtype ):
858- # i.e. StringDtype
859- result = lib .map_infer_mask (
860- arr , f , mask .view ("uint8" ), convert = False , na_value = na_value
861- )
862- return type (self )(result )
863- else :
864- # This is when the result type is object. We reach this when
865- # -> We know the result type is truly object (e.g. .encode returns bytes
866- # or .findall returns a list).
867- # -> We don't know the result type. E.g. `.get` can return anything.
868- return lib .map_infer_mask (arr , f , mask .view ("uint8" ))
0 commit comments