Skip to content

Commit e94997a

Browse files
committed
Manage IntervalDtype
1 parent 547662d commit e94997a

File tree

3 files changed

+18
-9
lines changed

3 files changed

+18
-9
lines changed

pandas/core/algorithms.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
BaseMaskedDtype,
6363
CategoricalDtype,
6464
ExtensionDtype,
65+
IntervalDtype,
6566
NumpyEADtype,
6667
)
6768
from pandas.core.dtypes.generic import (
@@ -1665,7 +1666,11 @@ def map_array(
16651666
from pandas import Series
16661667

16671668
if len(mapper) == 0:
1668-
if is_extension_array_dtype(arr.dtype) and arr.dtype.na_value is NA:
1669+
if (
1670+
is_extension_array_dtype(arr.dtype)
1671+
and not isinstance(arr.dtype, IntervalDtype)
1672+
and arr.dtype.na_value is NA
1673+
):
16691674
mapper = Series(mapper, dtype=arr.dtype)
16701675
else:
16711676
mapper = Series(mapper, dtype=np.float64)

pandas/core/arrays/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
)
4949
from pandas.core.dtypes.dtypes import (
5050
ExtensionDtype,
51-
IntervalDtype,
5251
NumpyEADtype,
5352
)
5453
from pandas.core.dtypes.generic import (
@@ -2342,12 +2341,13 @@ def map(self, mapper, na_action: Literal["ignore"] | None = None):
23422341
a MultiIndex will be returned.
23432342
"""
23442343
result = map_array(self, mapper, na_action=na_action)
2345-
if isinstance(self.dtype, NumpyEADtype):
2346-
return pd_array(result, dtype=NumpyEADtype(result.dtype))
2347-
if isinstance(self.dtype, IntervalDtype):
2348-
return result
2344+
if isinstance(result, ExtensionArray):
2345+
if isinstance(self.dtype, NumpyEADtype):
2346+
return pd_array(result, dtype=NumpyEADtype(result.dtype))
2347+
else:
2348+
return result
23492349
elif isinstance(result, np.ndarray):
2350-
return pd_array(result)
2350+
return pd_array(result, result.dtype)
23512351
else:
23522352
return result
23532353

pandas/tests/series/methods/test_map.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010

1111
from pandas.core.dtypes.common import is_extension_array_dtype
12+
from pandas.core.dtypes.dtypes import IntervalDtype
1213

1314
import pandas as pd
1415
from pandas import (
@@ -237,8 +238,11 @@ def test_map_empty(request, index):
237238

238239
s = Series(index)
239240
result = s.map({})
240-
241-
if is_extension_array_dtype(s.dtype) and s.dtype.na_value is pd.NA:
241+
if (
242+
is_extension_array_dtype(s.dtype)
243+
and not isinstance(s.dtype, IntervalDtype)
244+
and s.dtype.na_value is pd.NA
245+
):
242246
na_value = s.dtype.na_value
243247
dtype = s.dtype
244248
else:

0 commit comments

Comments
 (0)