Skip to content

Commit 4ef96f7

Browse files
author
Kei
committed
In agg series, convert to np values, then cast to pyarrow dtype, account for missing pyarrow dtypes
1 parent 6dc40f5 commit 4ef96f7

File tree

2 files changed

+52
-32
lines changed

2 files changed

+52
-32
lines changed

pandas/core/dtypes/cast.py

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -478,10 +478,11 @@ def maybe_cast_pointwise_result(
478478
return result
479479

480480

481-
def maybe_cast_to_pyarrow_dtype(result: ArrayLike, obj_dtype: Dtype) -> ArrayLike:
481+
def maybe_cast_to_pyarrow_result(result: ArrayLike) -> ArrayLike:
482482
"""
483-
Try casting result of a pointwise operation to its pyarrow dtype if
484-
appropriate.
483+
Try casting result of a pointwise operation to its pyarrow dtype
484+
and arrow extension array if appropriate. If not possible,
485+
returns np.ndarray.
485486
486487
Parameters
487488
----------
@@ -493,34 +494,20 @@ def maybe_cast_to_pyarrow_dtype(result: ArrayLike, obj_dtype: Dtype) -> ArrayLik
493494
result : array-like
494495
result maybe casted to the dtype.
495496
"""
496-
try:
497-
import pyarrow as pa
498-
from pyarrow import (
499-
ArrowInvalid,
500-
ArrowMemoryError,
501-
ArrowNotImplementedError,
502-
)
497+
from pandas.core.construction import array as pd_array
503498

504-
from pandas.core.construction import array as pd_array
505-
506-
stripped_result = result[~isna(result)]
507-
if result.size == 0 or all(isna(stripped_result)):
508-
pandas_pyarrow_dtype = obj_dtype
509-
else:
510-
pyarrow_result = pa.array(stripped_result)
511-
pandas_pyarrow_dtype = ArrowDtype(pyarrow_result.type)
499+
# maybe_convert_objects is unable to detect NA as nan
500+
# (detects it as object instead)
501+
stripped_result = result[~isna(result)]
502+
npvalues = lib.maybe_convert_objects(stripped_result, try_float=False)
512503

513-
result = pd_array(result, dtype=pandas_pyarrow_dtype)
514-
except (
515-
ArrowNotImplementedError,
516-
ArrowInvalid,
517-
ArrowMemoryError,
518-
TypeError,
519-
ValueError,
520-
):
521-
result = lib.maybe_convert_objects(result, try_float=False)
504+
try:
505+
dtype = convert_dtypes(npvalues, dtype_backend="pyarrow")
506+
out = pd_array(result, dtype=dtype)
507+
except (TypeError, ValueError, np.ComplexWarning):
508+
out = npvalues
522509

523-
return result
510+
return out
524511

525512

526513
def _maybe_cast_to_extension_array(
@@ -1080,6 +1067,7 @@ def convert_dtypes(
10801067
inferred_dtype = lib.infer_dtype(input_array)
10811068
else:
10821069
inferred_dtype = input_array.dtype
1070+
orig_inferred_dtype = inferred_dtype
10831071

10841072
if is_string_dtype(inferred_dtype):
10851073
if not convert_string or inferred_dtype == "bytes":
@@ -1177,7 +1165,8 @@ def convert_dtypes(
11771165
elif isinstance(inferred_dtype, StringDtype):
11781166
base_dtype = np.dtype(str)
11791167
else:
1180-
base_dtype = inferred_dtype
1168+
base_dtype = _infer_pyarrow_dtype(input_array, orig_inferred_dtype)
1169+
11811170
if (
11821171
base_dtype.kind == "O" # type: ignore[union-attr]
11831172
and input_array.size > 0
@@ -1188,8 +1177,10 @@ def convert_dtypes(
11881177
pa_type = pa.null()
11891178
else:
11901179
pa_type = to_pyarrow_type(base_dtype)
1180+
11911181
if pa_type is not None:
11921182
inferred_dtype = ArrowDtype(pa_type)
1183+
11931184
elif dtype_backend == "numpy_nullable" and isinstance(inferred_dtype, ArrowDtype):
11941185
# GH 53648
11951186
inferred_dtype = _arrow_dtype_mapping()[inferred_dtype.pyarrow_dtype]
@@ -1199,6 +1190,35 @@ def convert_dtypes(
11991190
return inferred_dtype # type: ignore[return-value]
12001191

12011192

1193+
def _infer_pyarrow_dtype(
1194+
input_array: ArrayLike,
1195+
inferred_dtype: str,
1196+
) -> DtypeObj:
1197+
if inferred_dtype not in ["time", "date", "decimal", "bytes"]:
1198+
return input_array.dtype
1199+
1200+
# For a limited set of dtype
1201+
# Let pyarrow infer dtype from input_array
1202+
import pyarrow as pa
1203+
from pyarrow import (
1204+
ArrowInvalid,
1205+
ArrowMemoryError,
1206+
ArrowNotImplementedError,
1207+
)
1208+
1209+
try:
1210+
pyarrow_array = pa.array(input_array)
1211+
return ArrowDtype(pyarrow_array.type)
1212+
except (
1213+
TypeError,
1214+
ValueError,
1215+
ArrowInvalid,
1216+
ArrowMemoryError,
1217+
ArrowNotImplementedError,
1218+
):
1219+
return input_array.dtype
1220+
1221+
12021222
def maybe_infer_to_datetimelike(
12031223
value: npt.NDArray[np.object_],
12041224
) -> np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray | IntervalArray:

pandas/core/groupby/ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
from pandas.core.dtypes.cast import (
3838
maybe_cast_pointwise_result,
39-
maybe_cast_to_pyarrow_dtype,
39+
maybe_cast_to_pyarrow_result,
4040
maybe_downcast_to_dtype,
4141
)
4242
from pandas.core.dtypes.common import (
@@ -917,9 +917,9 @@ def agg_series(
917917
"""
918918

919919
result = self._aggregate_series_pure_python(obj, func)
920+
920921
if isinstance(obj._values, ArrowExtensionArray):
921-
out = maybe_cast_to_pyarrow_dtype(result, obj.dtype)
922-
return out
922+
return maybe_cast_to_pyarrow_result(result)
923923

924924
if not isinstance(obj._values, np.ndarray) and not isinstance(
925925
obj._values, ArrowExtensionArray

0 commit comments

Comments
 (0)