diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index ed3892817..b5fd68e5c 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -2,6 +2,10 @@ from builtins import ( bool as _bool, str as _str, ) +from collections import ( + OrderedDict, + defaultdict, +) from collections.abc import ( Callable, Hashable, @@ -19,6 +23,7 @@ from typing import ( Generic, Literal, NoReturn, + TypeVar, final, overload, ) @@ -165,6 +170,8 @@ from pandas._typing import ( from pandas.io.formats.style import Styler from pandas.plotting import PlotAccessor +_T_MUTABLE_MAPPING = TypeVar("_T_MUTABLE_MAPPING", bound=MutableMapping, covariant=True) + class _iLocIndexerFrame(_iLocIndexer, Generic[_T]): @overload def __getitem__(self, idx: tuple[int, int]) -> Scalar: ... @@ -392,13 +399,21 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): na_value: Scalar = ..., ) -> np.ndarray: ... @overload + def to_dict( + self, + orient: str = ..., + *, + into: type[defaultdict], + index: Literal[True] = ..., + ) -> Never: ... + @overload def to_dict( self, orient: Literal["records"], *, - into: MutableMapping | type[MutableMapping], + into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING], index: Literal[True] = ..., - ) -> list[MutableMapping[Hashable, Any]]: ... + ) -> list[_T_MUTABLE_MAPPING]: ... @overload def to_dict( self, @@ -410,39 +425,47 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): @overload def to_dict( self, - orient: Literal["dict", "list", "series", "index"], + orient: Literal["index"], *, - into: MutableMapping | type[MutableMapping], + into: defaultdict, index: Literal[True] = ..., - ) -> MutableMapping[Hashable, Any]: ... + ) -> defaultdict[Hashable, dict[Hashable, Any]]: ... @overload def to_dict( self, - orient: Literal["split", "tight"], + orient: Literal["index"], *, - into: MutableMapping | type[MutableMapping], - index: bool = ..., - ) -> MutableMapping[Hashable, Any]: ... + into: OrderedDict | type[OrderedDict], + index: Literal[True] = ..., + ) -> OrderedDict[Hashable, dict[Hashable, Any]]: ... @overload def to_dict( self, - orient: Literal["dict", "list", "series", "index"] = ..., + orient: Literal["index"], *, - into: MutableMapping | type[MutableMapping], + into: type[MutableMapping], index: Literal[True] = ..., - ) -> MutableMapping[Hashable, Any]: ... + ) -> MutableMapping[Hashable, dict[Hashable, Any]]: ... @overload def to_dict( self, - orient: Literal["split", "tight"] = ..., + orient: Literal["index"], *, - into: MutableMapping | type[MutableMapping], - index: bool = ..., - ) -> MutableMapping[Hashable, Any]: ... + into: type[dict] = ..., + index: Literal[True] = ..., + ) -> dict[Hashable, dict[Hashable, Any]]: ... + @overload + def to_dict( + self, + orient: Literal["dict", "list", "series"] = ..., + *, + into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING], + index: Literal[True] = ..., + ) -> _T_MUTABLE_MAPPING: ... @overload def to_dict( self, - orient: Literal["dict", "list", "series", "index"] = ..., + orient: Literal["dict", "list", "series"] = ..., *, into: type[dict] = ..., index: Literal[True] = ..., @@ -450,11 +473,19 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): @overload def to_dict( self, - orient: Literal["split", "tight"] = ..., + orient: Literal["split", "tight"], + *, + into: MutableMapping | type[MutableMapping], + index: bool = ..., + ) -> MutableMapping[str, list]: ... + @overload + def to_dict( + self, + orient: Literal["split", "tight"], *, into: type[dict] = ..., index: bool = ..., - ) -> dict[Hashable, Any]: ... + ) -> dict[str, list]: ... def to_gbq( self, destination_table: str, diff --git a/tests/test_frame.py b/tests/test_frame.py index ffd85d609..1aea31aba 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1,7 +1,11 @@ from __future__ import annotations -from collections import defaultdict +from collections import ( + OrderedDict, + defaultdict, +) from collections.abc import ( + Callable, Hashable, Iterable, Iterator, @@ -20,7 +24,6 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Generic, TypedDict, TypeVar, @@ -39,6 +42,7 @@ ) import pytest from typing_extensions import ( + Never, TypeAlias, assert_never, assert_type, @@ -3638,33 +3642,86 @@ def test_to_records() -> None: ) -def test_to_dict() -> None: +def test_to_dict_simple() -> None: check(assert_type(DF.to_dict(), dict[Hashable, Any]), dict) - check(assert_type(DF.to_dict("split"), dict[Hashable, Any]), dict) + check(assert_type(DF.to_dict("records"), list[dict[Hashable, Any]]), list) + check(assert_type(DF.to_dict("index"), dict[Hashable, dict[Hashable, Any]]), dict) + check(assert_type(DF.to_dict("dict"), dict[Hashable, Any]), dict) + check(assert_type(DF.to_dict("list"), dict[Hashable, Any]), dict) + check(assert_type(DF.to_dict("series"), dict[Hashable, Any]), dict) + check(assert_type(DF.to_dict("split"), dict[str, list]), dict, str) + check(assert_type(DF.to_dict("tight"), dict[str, list]), dict, str) + + if TYPE_CHECKING_INVALID_USAGE: + + def test(mapping: Mapping) -> None: # pyright: ignore[reportUnusedFunction] + DF.to_dict( # type: ignore[call-overload] + into=mapping # pyright: ignore[reportArgumentType,reportCallIssue] + ) + + assert_type(DF.to_dict(into=defaultdict), Never) + assert_type(DF.to_dict("records", into=defaultdict), Never) + assert_type(DF.to_dict("index", into=defaultdict), Never) + assert_type(DF.to_dict("dict", into=defaultdict), Never) + assert_type(DF.to_dict("list", into=defaultdict), Never) + assert_type(DF.to_dict("series", into=defaultdict), Never) + assert_type(DF.to_dict("split", into=defaultdict), Never) + assert_type(DF.to_dict("tight", into=defaultdict), Never) + + +def test_to_dict_into_defaultdict() -> None: + """Test DataFrame.to_dict with `into` is an instance of defaultdict[Any, list]""" + + data = pd.DataFrame({("str", "rts"): [[1, 2, 4], [2, 3], [3]]}) + target: defaultdict[Any, list] = defaultdict(list) - target: MutableMapping = defaultdict(list) check( - assert_type(DF.to_dict(into=target), MutableMapping[Hashable, Any]), defaultdict + assert_type(data.to_dict(into=target), defaultdict[Any, list]), + defaultdict, + tuple, ) - target = defaultdict(list) check( - assert_type(DF.to_dict("tight", into=target), MutableMapping[Hashable, Any]), + assert_type( + data.to_dict("index", into=target), + defaultdict[Hashable, dict[Hashable, Any]], + ), defaultdict, ) - target = defaultdict(list) - check(assert_type(DF.to_dict("records"), list[dict[Hashable, Any]]), list) + check( + assert_type(data.to_dict("tight", into=target), MutableMapping[str, list]), + defaultdict, + str, + ) + check( + assert_type(data.to_dict("records", into=target), list[defaultdict[Any, list]]), + list, + defaultdict, + ) + + +def test_to_dict_into_ordered_dict() -> None: + """Test DataFrame.to_dict with `into=OrderedDict`""" + + data = pd.DataFrame({("str", "rts"): [[1, 2, 4], [2, 3], [3]]}) + + check(assert_type(data.to_dict(into=OrderedDict), OrderedDict), OrderedDict, tuple) check( assert_type( - DF.to_dict("records", into=target), list[MutableMapping[Hashable, Any]] + data.to_dict("index", into=OrderedDict), + OrderedDict[Hashable, dict[Hashable, Any]], ), + OrderedDict, + ) + check( + assert_type(data.to_dict("tight", into=OrderedDict), MutableMapping[str, list]), + OrderedDict, + str, + ) + check( + assert_type(data.to_dict("records", into=OrderedDict), list[OrderedDict]), list, + OrderedDict, ) - if TYPE_CHECKING_INVALID_USAGE: - - def test(mapping: Mapping) -> None: # pyright: ignore[reportUnusedFunction] - DF.to_dict( # type: ignore[call-overload] - into=mapping # pyright: ignore[reportArgumentType,reportCallIssue] - ) def test_neg() -> None: @@ -4111,19 +4168,22 @@ def test_to_dict_index() -> None: assert_type(df.to_dict(orient="series", index=True), dict[Hashable, Any]), dict ) check( - assert_type(df.to_dict(orient="index", index=True), dict[Hashable, Any]), dict + assert_type( + df.to_dict(orient="index", index=True), dict[Hashable, dict[Hashable, Any]] + ), + dict, ) check( - assert_type(df.to_dict(orient="split", index=True), dict[Hashable, Any]), dict + assert_type(df.to_dict(orient="split", index=True), dict[str, list]), dict, str ) check( - assert_type(df.to_dict(orient="tight", index=True), dict[Hashable, Any]), dict + assert_type(df.to_dict(orient="tight", index=True), dict[str, list]), dict, str ) check( - assert_type(df.to_dict(orient="tight", index=False), dict[Hashable, Any]), dict + assert_type(df.to_dict(orient="tight", index=False), dict[str, list]), dict, str ) check( - assert_type(df.to_dict(orient="split", index=False), dict[Hashable, Any]), dict + assert_type(df.to_dict(orient="split", index=False), dict[str, list]), dict, str ) if TYPE_CHECKING_INVALID_USAGE: check(assert_type(df.to_dict(orient="records", index=False), list[dict[Hashable, Any]]), list) # type: ignore[assert-type, call-overload] # pyright: ignore[reportArgumentType,reportAssertTypeFailure,reportCallIssue]