|
64 | 64 | is_string_dtype, |
65 | 65 | is_unsigned_integer_dtype, |
66 | 66 | ) |
| 67 | +from pandas.core.dtypes.common import is_timedelta64_dtype |
67 | 68 | from pandas.tests.extension import base |
68 | 69 |
|
69 | 70 | pa = pytest.importorskip("pyarrow") |
@@ -281,15 +282,33 @@ def test_compare_scalar(self, data, comparison_op): |
281 | 282 | @pytest.mark.parametrize("na_action", [None, "ignore"]) |
282 | 283 | def test_map(self, data_missing, na_action): |
283 | 284 | if data_missing.dtype.kind in "mM": |
284 | | - result = pd.Series( |
285 | | - np.asarray( |
286 | | - data_missing.map(lambda x: x, na_action=na_action), dtype="int64" |
287 | | - ) |
288 | | - ) |
289 | | - expected = pd.Series( |
290 | | - data_missing.to_numpy().astype(result.dtype).view("int64") |
291 | | - ) |
292 | | - tm.assert_series_equal(result, expected, check_dtype=False) |
| 285 | + mapped = data_missing.map(lambda x: x, na_action=na_action) |
| 286 | + result = pd.Series(mapped) |
| 287 | + expected = pd.Series(data_missing.to_numpy()) |
| 288 | + |
| 289 | + orig_dtype = expected.dtype |
| 290 | + |
| 291 | + if result.dtype == "float64" and ( |
| 292 | + is_datetime64_any_dtype(orig_dtype) |
| 293 | + or is_timedelta64_dtype(orig_dtype) |
| 294 | + or isinstance(orig_dtype, pd.DatetimeTZDtype) |
| 295 | + ): |
| 296 | + result = result.astype(orig_dtype) |
| 297 | + |
| 298 | + if isinstance(orig_dtype, pd.DatetimeTZDtype): |
| 299 | + pass |
| 300 | + elif is_datetime64_any_dtype(orig_dtype): |
| 301 | + result = result.astype("datetime64[ns]").astype("int64") |
| 302 | + expected = expected.astype("datetime64[ns]").astype("int64") |
| 303 | + result = pd.Series(result) |
| 304 | + expected = pd.Series(expected) |
| 305 | + elif is_timedelta64_dtype(orig_dtype): |
| 306 | + result = result.astype("timedelta64[ns]") |
| 307 | + expected = expected.astype("timedelta64[ns]") |
| 308 | + |
| 309 | + |
| 310 | + tm.assert_series_equal(result, expected, check_dtype=False, check_exact=False) |
| 311 | + |
293 | 312 | else: |
294 | 313 | result = data_missing.map(lambda x: x, na_action=na_action) |
295 | 314 | if data_missing.dtype == "float32[pyarrow]": |
|
0 commit comments