diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index 32f2f6297..a66b96513 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -41,6 +41,7 @@ from pandas.core.base import ( NumListLike, _ListLike, ) +from pandas.core.indexes.category import CategoricalIndex from pandas.core.strings.accessor import StringMethods from typing_extensions import ( Never, @@ -57,6 +58,7 @@ from pandas._typing import ( AnyAll, ArrayLike, AxesData, + CategoryDtypeArg, DropKeep, Dtype, DtypeArg, @@ -228,6 +230,16 @@ class Index(IndexOpsMixin[S1]): tupleize_cols: bool = ..., ) -> TimedeltaIndex: ... @overload + def __new__( + cls, + data: AxesData, + *, + dtype: CategoryDtypeArg, + copy: bool = ..., + name: Hashable = ..., + tupleize_cols: bool = ..., + ) -> CategoricalIndex: ... + @overload def __new__( cls, data: Sequence[Interval[_OrderableT]] | IndexOpsMixin[Interval[_OrderableT]], diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 04b7a712f..123b3f286 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -359,6 +359,16 @@ class Series(IndexOpsMixin[S1], NDFrame): copy: bool = ..., ) -> Series[Timestamp]: ... @overload + def __new__( + cls, + data: _ListLike, + index: AxesData | None = ..., + *, + dtype: CategoryDtypeArg, + name: Hashable = ..., + copy: bool = ..., + ) -> Series[CategoricalDtype]: ... + @overload def __new__( cls, data: PeriodIndex | Sequence[Period], diff --git a/tests/indexes/test_indexes.py b/tests/indexes/test_indexes.py index 1a21d9149..ce3296d0d 100644 --- a/tests/indexes/test_indexes.py +++ b/tests/indexes/test_indexes.py @@ -18,6 +18,7 @@ from pandas.core.arrays.interval import IntervalArray from pandas.core.arrays.timedeltas import TimedeltaArray from pandas.core.indexes.base import Index +from pandas.core.indexes.category import CategoricalIndex from typing_extensions import ( Never, assert_type, @@ -1369,6 +1370,12 @@ def test_index_factorize() -> None: check(assert_type(idx_uniques, np_1darray | Index | Categorical), pd.Index) +def test_index_categorical() -> None: + """Test creating an index with Categorical type GH1383.""" + sr = pd.Index([1], dtype="category") + check(assert_type(sr, CategoricalIndex), CategoricalIndex) + + def test_disallow_empty_index() -> None: # From GH 826 if TYPE_CHECKING_INVALID_USAGE: diff --git a/tests/series/test_properties.py b/tests/series/test_properties.py index 72cf432a6..2b76e50db 100644 --- a/tests/series/test_properties.py +++ b/tests/series/test_properties.py @@ -1,6 +1,5 @@ from typing import ( TYPE_CHECKING, - cast, ) import numpy as np @@ -54,12 +53,9 @@ def test_dt_property() -> None: def test_array_property() -> None: """Test that Series.array returns ExtensionArray and its subclasses""" - # casting due to pandas-dev/pandas-stubs#1383 check( assert_type( - cast( - "pd.Series[pd.CategoricalDtype]", pd.Series([1], dtype="category") - ).array, + pd.Series([1], dtype="category").array, pd.Categorical, ), pd.Categorical, diff --git a/tests/series/test_series.py b/tests/series/test_series.py index 12c75f50d..9079708c1 100644 --- a/tests/series/test_series.py +++ b/tests/series/test_series.py @@ -50,6 +50,8 @@ Scalar, ) +from pandas.core.dtypes.dtypes import CategoricalDtype # noqa F401 + from tests import ( PD_LTE_23, TYPE_CHECKING_INVALID_USAGE, @@ -1819,6 +1821,10 @@ def test_categorical_codes(): cat = pd.Categorical(["a", "b", "a"]) check(assert_type(cat.codes, np_1darray[np.signedinteger]), np_1darray[np.int8]) + # GH1383 + sr = pd.Series([1], dtype="category") + check(assert_type(sr, "pd.Series[CategoricalDtype]"), pd.Series, np.integer) + def test_relops() -> None: # GH 175 @@ -2908,8 +2914,6 @@ def test_astype_categorical(cast_arg: CategoryDtypeArg, target_type: type) -> No # pandas category assert_type(s.astype(pd.CategoricalDtype()), "pd.Series[pd.CategoricalDtype]") assert_type(s.astype(cast_arg), "pd.Series[pd.CategoricalDtype]") - # pyarrow dictionary - # assert_type(s.astype("dictionary[pyarrow]"), "pd.Series[Categorical]") @pytest.mark.parametrize("cast_arg, target_type", ASTYPE_OBJECT_ARGS, ids=repr)