@@ -315,6 +315,8 @@ class BaseStringArray(ExtensionArray):
315
315
Mixin class for StringArray, ArrowStringArray.
316
316
"""
317
317
318
+ dtype: StringDtype
319
+
318
320
@doc(ExtensionArray.tolist)
319
321
def tolist(self):
320
322
if self.ndim > 1:
@@ -328,6 +330,37 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
328
330
raise ValueError
329
331
return cls._from_sequence(scalars, dtype=dtype)
330
332
333
+ def _str_map_str_or_object(
334
+ self,
335
+ dtype,
336
+ na_value,
337
+ arr: np.ndarray,
338
+ f,
339
+ mask: npt.NDArray[np.bool_],
340
+ convert: bool,
341
+ ):
342
+ # _str_map helper for case where dtype is either string dtype or object
343
+ if is_string_dtype(dtype) and not is_object_dtype(dtype):
344
+ # i.e. StringDtype
345
+ result = lib.map_infer_mask(
346
+ arr, f, mask.view("uint8"), convert=False, na_value=na_value
347
+ )
348
+ if self.dtype.storage == "pyarrow":
349
+ import pyarrow as pa
350
+
351
+ result = pa.array(
352
+ result, mask=mask, type=pa.large_string(), from_pandas=True
353
+ )
354
+ # error: Too many arguments for "BaseStringArray"
355
+ return type(self)(result) # type: ignore[call-arg]
356
+
357
+ else:
358
+ # This is when the result type is object. We reach this when
359
+ # -> We know the result type is truly object (e.g. .encode returns bytes
360
+ # or .findall returns a list).
361
+ # -> We don't know the result type. E.g. `.get` can return anything.
362
+ return lib.map_infer_mask(arr, f, mask.view("uint8"))
363
+
331
364
332
365
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
333
366
# incompatible with definition in base class "ExtensionArray"
@@ -682,9 +715,53 @@ def _cmp_method(self, other, op):
682
715
# base class "NumpyExtensionArray" defined the type as "float")
683
716
_str_na_value = libmissing.NA # type: ignore[assignment]
684
717
718
+ def _str_map_nan_semantics(
719
+ self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
720
+ ):
721
+ if dtype is None:
722
+ dtype = self.dtype
723
+ if na_value is None:
724
+ na_value = self.dtype.na_value
725
+
726
+ mask = isna(self)
727
+ arr = np.asarray(self)
728
+ convert = convert and not np.all(mask)
729
+
730
+ if is_integer_dtype(dtype) or is_bool_dtype(dtype):
731
+ na_value_is_na = isna(na_value)
732
+ if na_value_is_na:
733
+ if is_integer_dtype(dtype):
734
+ na_value = 0
735
+ else:
736
+ na_value = True
737
+
738
+ result = lib.map_infer_mask(
739
+ arr,
740
+ f,
741
+ mask.view("uint8"),
742
+ convert=False,
743
+ na_value=na_value,
744
+ dtype=np.dtype(cast(type, dtype)),
745
+ )
746
+ if na_value_is_na and mask.any():
747
+ if is_integer_dtype(dtype):
748
+ result = result.astype("float64")
749
+ else:
750
+ result = result.astype("object")
751
+ result[mask] = np.nan
752
+ return result
753
+
754
+ else:
755
+ return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
756
+
685
757
def _str_map(
686
758
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
687
759
):
760
+ if self.dtype.na_value is np.nan:
761
+ return self._str_map_nan_semantics(
762
+ f, na_value=na_value, dtype=dtype, convert=convert
763
+ )
764
+
688
765
from pandas.arrays import BooleanArray
689
766
690
767
if dtype is None:
@@ -724,18 +801,8 @@ def _str_map(
724
801
725
802
return constructor(result, mask)
726
803
727
- elif is_string_dtype(dtype) and not is_object_dtype(dtype):
728
- # i.e. StringDtype
729
- result = lib.map_infer_mask(
730
- arr, f, mask.view("uint8"), convert=False, na_value=na_value
731
- )
732
- return StringArray(result)
733
804
else:
734
- # This is when the result type is object. We reach this when
735
- # -> We know the result type is truly object (e.g. .encode returns bytes
736
- # or .findall returns a list).
737
- # -> We don't know the result type. E.g. `.get` can return anything.
738
- return lib.map_infer_mask(arr, f, mask.view("uint8"))
805
+ return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
739
806
740
807
741
808
class StringArrayNumpySemantics(StringArray):
@@ -802,52 +869,3 @@ def value_counts(self, dropna: bool = True) -> Series:
802
869
# ------------------------------------------------------------------------
803
870
# String methods interface
804
871
_str_na_value = np.nan
805
-
806
- def _str_map(
807
- self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
808
- ):
809
- if dtype is None:
810
- dtype = self.dtype
811
- if na_value is None:
812
- na_value = self.dtype.na_value
813
-
814
- mask = isna(self)
815
- arr = np.asarray(self)
816
- convert = convert and not np.all(mask)
817
-
818
- if is_integer_dtype(dtype) or is_bool_dtype(dtype):
819
- na_value_is_na = isna(na_value)
820
- if na_value_is_na:
821
- if is_integer_dtype(dtype):
822
- na_value = 0
823
- else:
824
- na_value = True
825
-
826
- result = lib.map_infer_mask(
827
- arr,
828
- f,
829
- mask.view("uint8"),
830
- convert=False,
831
- na_value=na_value,
832
- dtype=np.dtype(cast(type, dtype)),
833
- )
834
- if na_value_is_na and mask.any():
835
- if is_integer_dtype(dtype):
836
- result = result.astype("float64")
837
- else:
838
- result = result.astype("object")
839
- result[mask] = np.nan
840
- return result
841
-
842
- elif is_string_dtype(dtype) and not is_object_dtype(dtype):
843
- # i.e. StringDtype
844
- result = lib.map_infer_mask(
845
- arr, f, mask.view("uint8"), convert=False, na_value=na_value
846
- )
847
- return type(self)(result)
848
- else:
849
- # This is when the result type is object. We reach this when
850
- # -> We know the result type is truly object (e.g. .encode returns bytes
851
- # or .findall returns a list).
852
- # -> We don't know the result type. E.g. `.get` can return anything.
853
- return lib.map_infer_mask(arr, f, mask.view("uint8"))
0 commit comments