@@ -478,10 +478,11 @@ def maybe_cast_pointwise_result(
478
478
return result
479
479
480
480
481
- def maybe_cast_to_pyarrow_dtype (result : ArrayLike , obj_dtype : Dtype ) -> ArrayLike :
481
+ def maybe_cast_to_pyarrow_result (result : ArrayLike ) -> ArrayLike :
482
482
"""
483
- Try casting result of a pointwise operation to its pyarrow dtype if
484
- appropriate.
483
+ Try casting result of a pointwise operation to its pyarrow dtype
484
+ and arrow extension array if appropriate. If not possible,
485
+ returns np.ndarray.
485
486
486
487
Parameters
487
488
----------
@@ -493,34 +494,20 @@ def maybe_cast_to_pyarrow_dtype(result: ArrayLike, obj_dtype: Dtype) -> ArrayLik
493
494
result : array-like
494
495
result maybe casted to the dtype.
495
496
"""
496
- try :
497
- import pyarrow as pa
498
- from pyarrow import (
499
- ArrowInvalid ,
500
- ArrowMemoryError ,
501
- ArrowNotImplementedError ,
502
- )
497
+ from pandas .core .construction import array as pd_array
503
498
504
- from pandas .core .construction import array as pd_array
505
-
506
- stripped_result = result [~ isna (result )]
507
- if result .size == 0 or all (isna (stripped_result )):
508
- pandas_pyarrow_dtype = obj_dtype
509
- else :
510
- pyarrow_result = pa .array (stripped_result )
511
- pandas_pyarrow_dtype = ArrowDtype (pyarrow_result .type )
499
+ # maybe_convert_objects is unable to detect NA as nan
500
+ # (detects it as object instead)
501
+ stripped_result = result [~ isna (result )]
502
+ npvalues = lib .maybe_convert_objects (stripped_result , try_float = False )
512
503
513
- result = pd_array (result , dtype = pandas_pyarrow_dtype )
514
- except (
515
- ArrowNotImplementedError ,
516
- ArrowInvalid ,
517
- ArrowMemoryError ,
518
- TypeError ,
519
- ValueError ,
520
- ):
521
- result = lib .maybe_convert_objects (result , try_float = False )
504
+ try :
505
+ dtype = convert_dtypes (npvalues , dtype_backend = "pyarrow" )
506
+ out = pd_array (result , dtype = dtype )
507
+ except (TypeError , ValueError , np .ComplexWarning ):
508
+ out = npvalues
522
509
523
- return result
510
+ return out
524
511
525
512
526
513
def _maybe_cast_to_extension_array (
@@ -1080,6 +1067,7 @@ def convert_dtypes(
1080
1067
inferred_dtype = lib .infer_dtype (input_array )
1081
1068
else :
1082
1069
inferred_dtype = input_array .dtype
1070
+ orig_inferred_dtype = inferred_dtype
1083
1071
1084
1072
if is_string_dtype (inferred_dtype ):
1085
1073
if not convert_string or inferred_dtype == "bytes" :
@@ -1177,7 +1165,8 @@ def convert_dtypes(
1177
1165
elif isinstance (inferred_dtype , StringDtype ):
1178
1166
base_dtype = np .dtype (str )
1179
1167
else :
1180
- base_dtype = inferred_dtype
1168
+ base_dtype = _infer_pyarrow_dtype (input_array , orig_inferred_dtype )
1169
+
1181
1170
if (
1182
1171
base_dtype .kind == "O" # type: ignore[union-attr]
1183
1172
and input_array .size > 0
@@ -1188,8 +1177,10 @@ def convert_dtypes(
1188
1177
pa_type = pa .null ()
1189
1178
else :
1190
1179
pa_type = to_pyarrow_type (base_dtype )
1180
+
1191
1181
if pa_type is not None :
1192
1182
inferred_dtype = ArrowDtype (pa_type )
1183
+
1193
1184
elif dtype_backend == "numpy_nullable" and isinstance (inferred_dtype , ArrowDtype ):
1194
1185
# GH 53648
1195
1186
inferred_dtype = _arrow_dtype_mapping ()[inferred_dtype .pyarrow_dtype ]
@@ -1199,6 +1190,35 @@ def convert_dtypes(
1199
1190
return inferred_dtype # type: ignore[return-value]
1200
1191
1201
1192
1193
+ def _infer_pyarrow_dtype (
1194
+ input_array : ArrayLike ,
1195
+ inferred_dtype : str ,
1196
+ ) -> DtypeObj :
1197
+ if inferred_dtype not in ["time" , "date" , "decimal" , "bytes" ]:
1198
+ return input_array .dtype
1199
+
1200
+ # For a limited set of dtype
1201
+ # Let pyarrow infer dtype from input_array
1202
+ import pyarrow as pa
1203
+ from pyarrow import (
1204
+ ArrowInvalid ,
1205
+ ArrowMemoryError ,
1206
+ ArrowNotImplementedError ,
1207
+ )
1208
+
1209
+ try :
1210
+ pyarrow_array = pa .array (input_array )
1211
+ return ArrowDtype (pyarrow_array .type )
1212
+ except (
1213
+ TypeError ,
1214
+ ValueError ,
1215
+ ArrowInvalid ,
1216
+ ArrowMemoryError ,
1217
+ ArrowNotImplementedError ,
1218
+ ):
1219
+ return input_array .dtype
1220
+
1221
+
1202
1222
def maybe_infer_to_datetimelike (
1203
1223
value : npt .NDArray [np .object_ ],
1204
1224
) -> np .ndarray | DatetimeArray | TimedeltaArray | PeriodArray | IntervalArray :
0 commit comments