Skip to content

Commit 64330f0

Browse files
author
Kei
committed
Update implementation to use pyarrow array method
1 parent a54bf58 commit 64330f0

File tree

2 files changed

+40
-14
lines changed

2 files changed

+40
-14
lines changed

pandas/core/dtypes/cast.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,42 @@ def maybe_cast_pointwise_result(
478478
return result
479479

480480

481+
def maybe_cast_to_pyarrow_dtype(
482+
result: ArrayLike, converted_result: ArrayLike
483+
) -> ArrayLike:
484+
"""
485+
Try casting result of a pointwise operation to its pyarrow dtype if
486+
appropriate.
487+
488+
Parameters
489+
----------
490+
result : array-like
491+
Result to cast.
492+
493+
Returns
494+
-------
495+
result : array-like
496+
result maybe casted to the dtype.
497+
"""
498+
try:
499+
import pyarrow as pa
500+
from pyarrow import (
501+
ArrowInvalid,
502+
ArrowNotImplementedError,
503+
)
504+
505+
from pandas.core.construction import array as pd_array
506+
507+
result[isna(result)] = np.nan
508+
pyarrow_result = pa.array(result)
509+
pandas_pyarrow_dtype = ArrowDtype(pyarrow_result.type)
510+
result = pd_array(result, dtype=pandas_pyarrow_dtype)
511+
except (ArrowNotImplementedError, ArrowInvalid):
512+
return converted_result
513+
514+
return result
515+
516+
481517
def _maybe_cast_to_extension_array(
482518
cls: type[ExtensionArray], obj: ArrayLike, dtype: ExtensionDtype | None = None
483519
) -> ArrayLike:

pandas/core/groupby/ops.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
from pandas.core.dtypes.cast import (
3838
maybe_cast_pointwise_result,
39+
maybe_cast_to_pyarrow_dtype,
3940
maybe_downcast_to_dtype,
4041
)
4142
from pandas.core.dtypes.common import (
@@ -45,15 +46,13 @@
4546
ensure_uint64,
4647
is_1d_only_ea_dtype,
4748
)
48-
from pandas.core.dtypes.dtypes import ArrowDtype
4949
from pandas.core.dtypes.missing import (
5050
isna,
5151
maybe_fill,
5252
)
5353

5454
from pandas.core.arrays import Categorical
5555
from pandas.core.arrays.arrow.array import ArrowExtensionArray
56-
from pandas.core.construction import array as pd_array
5756
from pandas.core.frame import DataFrame
5857
from pandas.core.groupby import grouper
5958
from pandas.core.indexes.api import (
@@ -927,21 +926,12 @@ def agg_series(
927926
preserve_dtype = True
928927

929928
result = self._aggregate_series_pure_python(obj, func)
930-
931929
npvalues = lib.maybe_convert_objects(result, try_float=False)
930+
932931
if preserve_dtype:
933932
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
934-
elif (
935-
isinstance(obj._values, ArrowExtensionArray)
936-
and npvalues.dtype != np.dtype("object")
937-
and npvalues.dtype != np.dtype("complex128")
938-
):
939-
import pyarrow as pa
940-
941-
pyarrow_dtype = pa.from_numpy_dtype(npvalues.dtype)
942-
pandas_pyarrow_dtype = ArrowDtype(pyarrow_dtype)
943-
out = pd_array(npvalues, dtype=pandas_pyarrow_dtype)
944-
933+
elif isinstance(obj._values, ArrowExtensionArray):
934+
out = maybe_cast_to_pyarrow_dtype(result, npvalues)
945935
else:
946936
out = npvalues
947937
return out

0 commit comments

Comments
 (0)