Skip to content

Commit 48f6a8b

Browse files
committed
ENH: EA._cast_pointwise_result
1 parent 1d22331 commit 48f6a8b

File tree

14 files changed

+76
-204
lines changed

14 files changed

+76
-204
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,18 @@ def _from_sequence_of_strings(
392392
)
393393
return cls._from_sequence(scalars, dtype=pa_type, copy=copy)
394394

395+
def _cast_pointwise_result(self, values) -> ArrayLike:
396+
if len(values) == 0:
397+
# Retain our dtype
398+
return self[:0].copy()
399+
arr = pa.array(values, from_pandas=True)
400+
if isinstance(self.dtype, StringDtype):
401+
if pa.types.is_string(arr.type) or pa.types.is_large_string(arr.type):
402+
# ArrowStringArrayNumpySemantics
403+
return type(self)(arr)
404+
return ArrowExtensionArray(arr)
405+
return type(self)(arr)
406+
395407
@classmethod
396408
def _box_pa(
397409
cls, value, pa_type: pa.DataType | None = None

pandas/core/arrays/base.py

Lines changed: 9 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
cast,
2020
overload,
2121
)
22-
import warnings
2322

2423
import numpy as np
2524

@@ -35,13 +34,11 @@
3534
Substitution,
3635
cache_readonly,
3736
)
38-
from pandas.util._exceptions import find_stack_level
3937
from pandas.util._validators import (
4038
validate_bool_kwarg,
4139
validate_insert_loc,
4240
)
4341

44-
from pandas.core.dtypes.cast import maybe_cast_pointwise_result
4542
from pandas.core.dtypes.common import (
4643
is_list_like,
4744
is_scalar,
@@ -89,7 +86,6 @@
8986
AstypeArg,
9087
AxisInt,
9188
Dtype,
92-
DtypeObj,
9389
FillnaOptions,
9490
InterpolateOptions,
9591
NumpySorter,
@@ -311,38 +307,6 @@ def _from_sequence(
311307
"""
312308
raise AbstractMethodError(cls)
313309

314-
@classmethod
315-
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
316-
"""
317-
Strict analogue to _from_sequence, allowing only sequences of scalars
318-
that should be specifically inferred to the given dtype.
319-
320-
Parameters
321-
----------
322-
scalars : sequence
323-
dtype : ExtensionDtype
324-
325-
Raises
326-
------
327-
TypeError or ValueError
328-
329-
Notes
330-
-----
331-
This is called in a try/except block when casting the result of a
332-
pointwise operation.
333-
"""
334-
try:
335-
return cls._from_sequence(scalars, dtype=dtype, copy=False)
336-
except (ValueError, TypeError):
337-
raise
338-
except Exception:
339-
warnings.warn(
340-
"_from_scalars should only raise ValueError or TypeError. "
341-
"Consider overriding _from_scalars where appropriate.",
342-
stacklevel=find_stack_level(),
343-
)
344-
raise
345-
346310
@classmethod
347311
def _from_sequence_of_strings(
348312
cls, strings, *, dtype: ExtensionDtype, copy: bool = False
@@ -371,9 +335,6 @@ def _from_sequence_of_strings(
371335
from a sequence of scalars.
372336
api.extensions.ExtensionArray._from_factorized : Reconstruct an ExtensionArray
373337
after factorization.
374-
api.extensions.ExtensionArray._from_scalars : Strict analogue to _from_sequence,
375-
allowing only sequences of scalars that should be specifically inferred to
376-
the given dtype.
377338
378339
Examples
379340
--------
@@ -416,6 +377,14 @@ def _from_factorized(cls, values, original):
416377
"""
417378
raise AbstractMethodError(cls)
418379

380+
def _cast_pointwise_result(self, values) -> ArrayLike:
381+
"""
382+
Cast the result of a pointwise operation (e.g. Series.map) to an
383+
array, preserve dtype_backend if possible.
384+
"""
385+
values = np.asarray(values, dtype=object)
386+
return lib.maybe_convert_objects(values, convert_non_numeric=True)
387+
419388
# ------------------------------------------------------------------------
420389
# Must be a Sequence
421390
# ------------------------------------------------------------------------
@@ -2842,7 +2811,7 @@ def _maybe_convert(arr):
28422811
# https://github.com/pandas-dev/pandas/issues/22850
28432812
# We catch all regular exceptions here, and fall back
28442813
# to an ndarray.
2845-
res = maybe_cast_pointwise_result(arr, self.dtype, same_dtype=False)
2814+
res = self._cast_pointwise_result(arr)
28462815
if not isinstance(res, type(self)):
28472816
# exception raised in _from_sequence; ensure we have ndarray
28482817
res = np.asarray(arr)

pandas/core/arrays/categorical.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@
103103
AstypeArg,
104104
AxisInt,
105105
Dtype,
106-
DtypeObj,
107106
NpDtype,
108107
Ordered,
109108
Shape,
@@ -529,20 +528,12 @@ def _from_sequence(
529528
) -> Self:
530529
return cls(scalars, dtype=dtype, copy=copy)
531530

532-
@classmethod
533-
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
534-
if dtype is None:
535-
# The _from_scalars strictness doesn't make much sense in this case.
536-
raise NotImplementedError
537-
538-
res = cls._from_sequence(scalars, dtype=dtype)
539-
540-
# if there are any non-category elements in scalars, these will be
541-
# converted to NAs in res.
542-
mask = isna(scalars)
543-
if not (mask == res.isna()).all():
544-
# Some non-category element in scalars got converted to NA in res.
545-
raise ValueError
531+
def _cast_pointwise_result(self, values) -> ArrayLike:
532+
res = super()._cast_pointwise_result(values)
533+
cat = type(self)._from_sequence(res, dtype=self.dtype)
534+
if (cat.isna() == isna(res)).all():
535+
# i.e. the conversion was non-lossy
536+
return cat
546537
return res
547538

548539
@overload

pandas/core/arrays/datetimes.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@
8383
from pandas._typing import (
8484
ArrayLike,
8585
DateTimeErrorChoices,
86-
DtypeObj,
8786
IntervalClosedType,
8887
TimeAmbiguous,
8988
TimeNonexistent,
@@ -293,14 +292,6 @@ def _scalar_type(self) -> type[Timestamp]:
293292
_dtype: np.dtype[np.datetime64] | DatetimeTZDtype
294293
_freq: BaseOffset | None = None
295294

296-
@classmethod
297-
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
298-
if lib.infer_dtype(scalars, skipna=True) not in ["datetime", "datetime64"]:
299-
# TODO: require any NAs be valid-for-DTA
300-
# TODO: if dtype is passed, check for tzawareness compat?
301-
raise ValueError
302-
return cls._from_sequence(scalars, dtype=dtype)
303-
304295
@classmethod
305296
def _validate_dtype(cls, values, dtype):
306297
# used in TimeLikeOps.__init__

pandas/core/arrays/masked.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False) -> Self:
147147
values, mask = cls._coerce_to_array(scalars, dtype=dtype, copy=copy)
148148
return cls(values, mask)
149149

150+
def _cast_pointwise_result(self, values) -> ArrayLike:
151+
values = np.asarray(values, dtype=object)
152+
return lib.maybe_convert_objects(values, convert_to_nullable_dtype=True)
153+
150154
@classmethod
151155
@doc(ExtensionArray._empty)
152156
def _empty(cls, shape: Shape, dtype: ExtensionDtype) -> Self:

pandas/core/arrays/sparse/array.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,10 @@ def _from_sequence(
607607
def _from_factorized(cls, values, original) -> Self:
608608
return cls(values, dtype=original.dtype)
609609

610+
def _cast_pointwise_result(self, values):
611+
result = super()._cast_pointwise_result(values)
612+
return type(self)._from_sequence(result)
613+
610614
# ------------------------------------------------------------------------
611615
# Data
612616
# ------------------------------------------------------------------------

pandas/core/arrays/string_.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -412,13 +412,6 @@ def tolist(self) -> list:
412412
return [x.tolist() for x in self]
413413
return list(self.to_numpy())
414414

415-
@classmethod
416-
def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
417-
if lib.infer_dtype(scalars, skipna=True) not in ["string", "empty"]:
418-
# TODO: require any NAs be valid-for-string
419-
raise ValueError
420-
return cls._from_sequence(scalars, dtype=dtype)
421-
422415
def _formatter(self, boxed: bool = False):
423416
formatter = partial(
424417
printing.pprint_thing,
@@ -732,6 +725,13 @@ def _from_sequence_of_strings(
732725
) -> Self:
733726
return cls._from_sequence(strings, dtype=dtype, copy=copy)
734727

728+
def _cast_pointwise_result(self, values) -> ArrayLike:
729+
result = super()._cast_pointwise_result(values)
730+
if isinstance(result.dtype, StringDtype):
731+
# Ensure we retain our same na_value/storage
732+
result = result.astype(self.dtype)
733+
return result
734+
735735
@classmethod
736736
def _empty(cls, shape, dtype) -> StringArray:
737737
values = np.empty(shape, dtype=object)

pandas/core/dtypes/cast.py

Lines changed: 0 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -437,80 +437,6 @@ def maybe_upcast_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT:
437437
return arr
438438

439439

440-
def maybe_cast_pointwise_result(
441-
result: ArrayLike,
442-
dtype: DtypeObj,
443-
numeric_only: bool = False,
444-
same_dtype: bool = True,
445-
) -> ArrayLike:
446-
"""
447-
Try casting result of a pointwise operation back to the original dtype if
448-
appropriate.
449-
450-
Parameters
451-
----------
452-
result : array-like
453-
Result to cast.
454-
dtype : np.dtype or ExtensionDtype
455-
Input Series from which result was calculated.
456-
numeric_only : bool, default False
457-
Whether to cast only numerics or datetimes as well.
458-
same_dtype : bool, default True
459-
Specify dtype when calling _from_sequence
460-
461-
Returns
462-
-------
463-
result : array-like
464-
result maybe casted to the dtype.
465-
"""
466-
467-
if isinstance(dtype, ExtensionDtype):
468-
cls = dtype.construct_array_type()
469-
if same_dtype:
470-
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
471-
else:
472-
result = _maybe_cast_to_extension_array(cls, result)
473-
474-
elif (numeric_only and dtype.kind in "iufcb") or not numeric_only:
475-
result = maybe_downcast_to_dtype(result, dtype)
476-
477-
return result
478-
479-
480-
def _maybe_cast_to_extension_array(
481-
cls: type[ExtensionArray], obj: ArrayLike, dtype: ExtensionDtype | None = None
482-
) -> ArrayLike:
483-
"""
484-
Call to `_from_sequence` that returns the object unchanged on Exception.
485-
486-
Parameters
487-
----------
488-
cls : class, subclass of ExtensionArray
489-
obj : arraylike
490-
Values to pass to cls._from_sequence
491-
dtype : ExtensionDtype, optional
492-
493-
Returns
494-
-------
495-
ExtensionArray or obj
496-
"""
497-
result: ArrayLike
498-
499-
if dtype is not None:
500-
try:
501-
result = cls._from_scalars(obj, dtype=dtype)
502-
except (TypeError, ValueError):
503-
return obj
504-
return result
505-
506-
try:
507-
result = cls._from_sequence(obj, dtype=dtype)
508-
except Exception:
509-
# We can't predict what downstream EA constructors may raise
510-
result = obj
511-
return result
512-
513-
514440
@overload
515441
def ensure_dtype_can_hold_na(dtype: np.dtype) -> np.dtype: ...
516442

pandas/core/groupby/ops.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from pandas.util._decorators import cache_readonly
3636

3737
from pandas.core.dtypes.cast import (
38-
maybe_cast_pointwise_result,
3938
maybe_downcast_to_dtype,
4039
)
4140
from pandas.core.dtypes.common import (
@@ -44,15 +43,13 @@
4443
ensure_platform_int,
4544
ensure_uint64,
4645
is_1d_only_ea_dtype,
47-
is_string_dtype,
4846
)
4947
from pandas.core.dtypes.missing import (
5048
isna,
5149
maybe_fill,
5250
)
5351

5452
from pandas.core.arrays import Categorical
55-
from pandas.core.arrays.arrow.array import ArrowExtensionArray
5653
from pandas.core.frame import DataFrame
5754
from pandas.core.groupby import grouper
5855
from pandas.core.indexes.api import (
@@ -966,29 +963,7 @@ def agg_series(
966963
np.ndarray or ExtensionArray
967964
"""
968965
result = self._aggregate_series_pure_python(obj, func)
969-
npvalues = lib.maybe_convert_objects(result, try_float=False)
970-
971-
if isinstance(obj._values, ArrowExtensionArray):
972-
# When obj.dtype is a string, any object can be cast. Only do so if the
973-
# UDF returned strings or NA values.
974-
if not is_string_dtype(obj.dtype) or lib.is_string_array(
975-
npvalues, skipna=True
976-
):
977-
out = maybe_cast_pointwise_result(
978-
npvalues, obj.dtype, numeric_only=True, same_dtype=preserve_dtype
979-
)
980-
else:
981-
out = npvalues
982-
983-
elif not isinstance(obj._values, np.ndarray):
984-
# we can preserve a little bit more aggressively with EA dtype
985-
# because maybe_cast_pointwise_result will do a try/except
986-
# with _from_sequence. NB we are assuming here that _from_sequence
987-
# is sufficiently strict that it casts appropriately.
988-
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
989-
else:
990-
out = npvalues
991-
return out
966+
return obj.array._cast_pointwise_result(result)
992967

993968
@final
994969
def _aggregate_series_pure_python(

0 commit comments

Comments
 (0)