Skip to content

Commit 6f35c0e

Browse files
author
Kei
committed
Fallback convert to input dtype is output is all nan or empty array
1 parent 93b5bf3 commit 6f35c0e

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

pandas/core/dtypes/cast.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def maybe_cast_pointwise_result(
478478
return result
479479

480480

481-
def maybe_cast_to_pyarrow_dtype(result: ArrayLike) -> ArrayLike:
481+
def maybe_cast_to_pyarrow_dtype(result: ArrayLike, obj_dtype: Dtype) -> ArrayLike:
482482
"""
483483
Try casting result of a pointwise operation to its pyarrow dtype if
484484
appropriate.
@@ -504,8 +504,12 @@ def maybe_cast_to_pyarrow_dtype(result: ArrayLike) -> ArrayLike:
504504
from pandas.core.construction import array as pd_array
505505

506506
result[isna(result)] = np.nan
507-
pyarrow_result = pa.array(result)
508-
pandas_pyarrow_dtype = ArrowDtype(pyarrow_result.type)
507+
if result.size == 0 or all(isna(result)):
508+
pandas_pyarrow_dtype = obj_dtype
509+
else:
510+
pyarrow_result = pa.array(result)
511+
pandas_pyarrow_dtype = ArrowDtype(pyarrow_result.type)
512+
509513
result = pd_array(result, dtype=pandas_pyarrow_dtype)
510514
except (
511515
ArrowNotImplementedError,

pandas/core/groupby/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,7 @@ def agg_series(
918918

919919
result = self._aggregate_series_pure_python(obj, func)
920920
if isinstance(obj._values, ArrowExtensionArray):
921-
out = maybe_cast_to_pyarrow_dtype(result)
921+
out = maybe_cast_to_pyarrow_dtype(result, obj.dtype)
922922
return out
923923

924924
if not isinstance(obj._values, np.ndarray) and not isinstance(

0 commit comments

Comments
 (0)