Skip to content

Commit 842f561

Browse files
author
Kei
committed
Update according to pr comments
1 parent affde38 commit 842f561

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

pandas/core/dtypes/cast.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -478,9 +478,7 @@ def maybe_cast_pointwise_result(
478478
return result
479479

480480

481-
def maybe_cast_to_pyarrow_dtype(
482-
result: ArrayLike, converted_result: ArrayLike
483-
) -> ArrayLike:
481+
def maybe_cast_to_pyarrow_dtype(result: ArrayLike) -> ArrayLike:
484482
"""
485483
Try casting result of a pointwise operation to its pyarrow dtype if
486484
appropriate.
@@ -499,6 +497,7 @@ def maybe_cast_to_pyarrow_dtype(
499497
import pyarrow as pa
500498
from pyarrow import (
501499
ArrowInvalid,
500+
ArrowMemoryError,
502501
ArrowNotImplementedError,
503502
)
504503

@@ -508,8 +507,14 @@ def maybe_cast_to_pyarrow_dtype(
508507
pyarrow_result = pa.array(result)
509508
pandas_pyarrow_dtype = ArrowDtype(pyarrow_result.type)
510509
result = pd_array(result, dtype=pandas_pyarrow_dtype)
511-
except (ArrowNotImplementedError, ArrowInvalid):
512-
return converted_result
510+
except (
511+
ArrowNotImplementedError,
512+
ArrowInvalid,
513+
ArrowMemoryError,
514+
TypeError,
515+
ValueError,
516+
):
517+
result = lib.maybe_convert_objects(result, try_float=False)
513518

514519
return result
515520

pandas/core/groupby/ops.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,11 @@ def agg_series(
916916
np.ndarray or ExtensionArray
917917
"""
918918

919+
result = self._aggregate_series_pure_python(obj, func)
920+
if isinstance(obj._values, ArrowExtensionArray):
921+
out = maybe_cast_to_pyarrow_dtype(result)
922+
return out
923+
919924
if not isinstance(obj._values, np.ndarray) and not isinstance(
920925
obj._values, ArrowExtensionArray
921926
):
@@ -925,15 +930,12 @@ def agg_series(
925930
# is sufficiently strict that it casts appropriately.
926931
preserve_dtype = True
927932

928-
result = self._aggregate_series_pure_python(obj, func)
929933
npvalues = lib.maybe_convert_objects(result, try_float=False)
930-
931934
if preserve_dtype:
932935
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
933-
elif isinstance(obj._values, ArrowExtensionArray):
934-
out = maybe_cast_to_pyarrow_dtype(result, npvalues)
935936
else:
936937
out = npvalues
938+
937939
return out
938940

939941
@final

0 commit comments

Comments
 (0)