Skip to content

Commit 9b01de4

Browse files
committed
BUG: Preserve timezone in numpy_dtype for ArrowDtype
1 parent 8a286fa commit 9b01de4

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

pandas/core/dtypes/dtypes.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,21 +2276,17 @@ def name(self) -> str: # type: ignore[override]
22762276
def numpy_dtype(self) -> np.dtype:
22772277
"""Return an instance of the related numpy dtype"""
22782278
if pa.types.is_timestamp(self.pyarrow_dtype):
2279-
# pa.timestamp(unit).to_pandas_dtype() returns ns units
2280-
# regardless of the pyarrow timestamp units.
2281-
# This can be removed if/when pyarrow addresses it:
2282-
# https://github.com/apache/arrow/issues/34462
2279+
# Preserve timezone information if present
2280+
if self.pyarrow_dtype.tz is not None:
2281+
# Use PyArrow's to_pandas_dtype method for timezone-aware types
2282+
return self.pyarrow_dtype.to_pandas_dtype()
2283+
# Fall back to naive datetime64 for timezone-naive timestamps
22832284
return np.dtype(f"datetime64[{self.pyarrow_dtype.unit}]")
22842285
if pa.types.is_duration(self.pyarrow_dtype):
2285-
# pa.duration(unit).to_pandas_dtype() returns ns units
2286-
# regardless of the pyarrow duration units
2287-
# This can be removed if/when pyarrow addresses it:
2288-
# https://github.com/apache/arrow/issues/34462
22892286
return np.dtype(f"timedelta64[{self.pyarrow_dtype.unit}]")
22902287
if pa.types.is_string(self.pyarrow_dtype) or pa.types.is_large_string(
22912288
self.pyarrow_dtype
22922289
):
2293-
# pa.string().to_pandas_dtype() = object which we don't want
22942290
return np.dtype(str)
22952291
try:
22962292
return np.dtype(self.pyarrow_dtype.to_pandas_dtype())

pandas/tests/dtypes/test_dtypes.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
DatetimeTZDtype,
2525
IntervalDtype,
2626
PeriodDtype,
27+
ArrowDtype,
2728
)
2829

2930
import pandas as pd
@@ -1103,6 +1104,32 @@ def test_update_dtype_errors(self, bad_dtype):
11031104
with pytest.raises(ValueError, match=msg):
11041105
dtype.update_dtype(bad_dtype)
11051106

1107+
class TestArrowDtype(Base):
1108+
@pytest.fixture
1109+
def dtype(self):
1110+
"""Fixture for ArrowDtype."""
1111+
import pyarrow as pa
1112+
return ArrowDtype(pa.timestamp("ns", tz="UTC"))
1113+
1114+
def test_numpy_dtype_preserves_timezone(self, dtype):
1115+
# Test timezone-aware timestamp
1116+
assert dtype.numpy_dtype == dtype.pyarrow_dtype.to_pandas_dtype()
1117+
1118+
def test_numpy_dtype_naive_timestamp(self):
1119+
import pyarrow as pa
1120+
arrow_type = pa.timestamp("ns")
1121+
dtype = ArrowDtype(arrow_type)
1122+
assert dtype.numpy_dtype == pa.timestamp("ns").to_pandas_dtype()
1123+
1124+
@pytest.mark.parametrize("tz", ["UTC", "America/New_York", None])
1125+
def test_numpy_dtype_with_varied_timezones(self, tz):
1126+
import pyarrow as pa
1127+
arrow_type = pa.timestamp("ns", tz=tz)
1128+
dtype = ArrowDtype(arrow_type)
1129+
if tz:
1130+
assert dtype.numpy_dtype == arrow_type.to_pandas_dtype()
1131+
else:
1132+
assert dtype.numpy_dtype == pa.timestamp("ns").to_pandas_dtype()
11061133

11071134
@pytest.mark.parametrize(
11081135
"dtype", [CategoricalDtype, IntervalDtype, DatetimeTZDtype, PeriodDtype]

0 commit comments

Comments
 (0)