Skip to content

Commit 6ba1942

Browse files
committed
handling tz aware ts explicitly, preventing re-deriving of ArrowDtype. Moving test to appropriate file.
1 parent f1a0ede commit 6ba1942

File tree

4 files changed

+38
-27
lines changed

4 files changed

+38
-27
lines changed

pandas/core/dtypes/cast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,7 @@ def convert_dtypes(
11131113
else:
11141114
inferred_dtype = input_array.dtype
11151115

1116-
if dtype_backend == "pyarrow":
1116+
if dtype_backend == "pyarrow" and not isinstance(inferred_dtype, ArrowDtype):
11171117
from pandas.core.arrays.arrow.array import to_pyarrow_type
11181118
from pandas.core.arrays.string_ import StringDtype
11191119

pandas/core/dtypes/dtypes.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2275,18 +2275,31 @@ def name(self) -> str: # type: ignore[override]
22752275
@cache_readonly
22762276
def numpy_dtype(self) -> np.dtype:
22772277
"""Return an instance of the related numpy dtype."""
2278-
# For string-like arrow dtypes, pa.string().to_pandas_dtype() = object
2279-
# so we handle them explicitly.
2280-
if pa.types.is_string(self.pyarrow_dtype) or pa.types.is_large_string(
2281-
self.pyarrow_dtype
2282-
):
2278+
pa_type = self.pyarrow_dtype
2279+
2280+
# handle tz-aware timestamps
2281+
if pa.types.is_timestamp(pa_type):
2282+
if pa_type.tz is not None:
2283+
# preserve tz by NOT calling numpy_dtype for this dtype.
2284+
return np.dtype("datetime64[ns]")
2285+
else:
2286+
# For tz-naive timestamps, just return the corresponding unit
2287+
return np.dtype(f"datetime64[{pa_type.unit}]")
2288+
2289+
if pa.types.is_duration(pa_type):
2290+
return np.dtype(f"timedelta64[{pa_type.unit}]")
2291+
2292+
if pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type):
22832293
return np.dtype(str)
22842294

22852295
try:
2286-
np_dtype = self.pyarrow_dtype.to_pandas_dtype()
2296+
np_dtype = pa_type.to_pandas_dtype()
2297+
if isinstance(np_dtype, DatetimeTZDtype):
2298+
# In theory we shouldn't get here for tz-aware arrow timestamps
2299+
# if we've handled them above. This is a fallback.
2300+
return np.dtype("datetime64[ns]")
22872301
return np.dtype(np_dtype)
22882302
except (NotImplementedError, TypeError):
2289-
# Fallback if something unexpected happens
22902303
return np.dtype(object)
22912304

22922305
@cache_readonly

pandas/tests/dtypes/test_dtypes.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,25 +1103,6 @@ def test_update_dtype_errors(self, bad_dtype):
11031103
with pytest.raises(ValueError, match=msg):
11041104
dtype.update_dtype(bad_dtype)
11051105

1106-
1107-
class TestArrowDtype:
1108-
@pytest.mark.parametrize(
1109-
"tz", ["UTC", "America/New_York", "Europe/London", "Asia/Tokyo"]
1110-
)
1111-
def test_pyarrow_timestamp_tz_preserved(self, tz):
1112-
pytest.importorskip("pyarrow")
1113-
s = Series(
1114-
pd.to_datetime(range(5), unit="h", utc=True).tz_convert(tz),
1115-
dtype=f"timestamp[ns, tz={tz}][pyarrow]",
1116-
)
1117-
1118-
result = s.convert_dtypes(dtype_backend="pyarrow")
1119-
assert result.dtype == s.dtype, f"Expected {s.dtype}, got {result.dtype}"
1120-
1121-
assert str(result.iloc[0].tzinfo) == str(s.iloc[0].tzinfo)
1122-
tm.assert_series_equal(result, s)
1123-
1124-
11251106
@pytest.mark.parametrize(
11261107
"dtype", [CategoricalDtype, IntervalDtype, DatetimeTZDtype, PeriodDtype]
11271108
)

pandas/tests/extension/test_arrow.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252
import pandas as pd
5353
import pandas._testing as tm
54+
from pandas import Series
5455
from pandas.api.extensions import no_default
5556
from pandas.api.types import (
5657
is_bool_dtype,
@@ -3505,3 +3506,19 @@ def test_map_numeric_na_action():
35053506
result = ser.map(lambda x: 42, na_action="ignore")
35063507
expected = pd.Series([42.0, 42.0, np.nan], dtype="float64")
35073508
tm.assert_series_equal(result, expected)
3509+
3510+
3511+
@pytest.mark.parametrize(
3512+
"tz", ["UTC", "America/New_York", "Europe/London", "Asia/Tokyo"]
3513+
)
3514+
def test_pyarrow_timestamp_tz_preserved(tz):
3515+
s = Series(
3516+
pd.to_datetime(range(5), unit="h", utc=True).tz_convert(tz),
3517+
dtype=f"timestamp[ns, tz={tz}][pyarrow]"
3518+
)
3519+
3520+
result = s.convert_dtypes(dtype_backend="pyarrow")
3521+
assert result.dtype == s.dtype, f"Expected {s.dtype}, got {result.dtype}"
3522+
3523+
assert str(result.iloc[0].tzinfo) == str(s.iloc[0].tzinfo)
3524+
tm.assert_series_equal(result, s)

0 commit comments

Comments
 (0)