Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 4 additions & 30 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,18 +246,16 @@ def _disallow_mismatched_datetimelike(value, dtype: DtypeObj) -> None:


@overload
def maybe_downcast_to_dtype(
result: np.ndarray, dtype: str | np.dtype
) -> np.ndarray: ...
def maybe_downcast_to_dtype(result: np.ndarray, dtype: np.dtype) -> np.ndarray: ...


@overload
def maybe_downcast_to_dtype(
result: ExtensionArray, dtype: str | np.dtype
) -> ArrayLike: ...
result: ExtensionArray, dtype: np.dtype
) -> ExtensionArray: ...


def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLike:
def maybe_downcast_to_dtype(result: ArrayLike, dtype: np.dtype) -> ArrayLike:
"""
try to cast to the specified dtype (e.g. convert back to bool/int
or could be an astype of float64->float32
Expand All @@ -266,30 +264,6 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
result = result._values
do_round = False

if isinstance(dtype, str):
if dtype == "infer":
inferred_type = lib.infer_dtype(result, skipna=False)
if inferred_type == "boolean":
dtype = "bool"
elif inferred_type == "integer":
dtype = "int64"
elif inferred_type == "datetime64":
dtype = "datetime64[ns]"
elif inferred_type in ["timedelta", "timedelta64"]:
dtype = "timedelta64[ns]"

# try to upcast here
elif inferred_type == "floating":
dtype = "int64"
if issubclass(result.dtype.type, np.number):
do_round = True

else:
# TODO: complex? what if result is already non-object?
dtype = "object"

dtype = np.dtype(dtype)

if not isinstance(dtype, np.dtype):
# enforce our signature annotation
raise TypeError(dtype) # pragma: no cover
Expand Down
47 changes: 2 additions & 45 deletions pandas/tests/dtypes/cast/test_downcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,20 @@

from pandas import (
Series,
Timedelta,
)
import pandas._testing as tm


@pytest.mark.parametrize(
"arr,dtype,expected",
[
(
np.array([8.5, 8.6, 8.7, 8.8, 8.9999999999995]),
"infer",
np.array([8.5, 8.6, 8.7, 8.8, 8.9999999999995]),
),
(
np.array([8.0, 8.0, 8.0, 8.0, 8.9999999999995]),
"infer",
np.array([8, 8, 8, 8, 9], dtype=np.int64),
),
(
np.array([8.0, 8.0, 8.0, 8.0, 9.0000000000005]),
"infer",
np.array([8, 8, 8, 8, 9], dtype=np.int64),
),
(
# This is a judgement call, but we do _not_ downcast Decimal
# objects
np.array([decimal.Decimal("0.0")]),
"int64",
np.dtype("int64"),
np.array([decimal.Decimal("0.0")]),
),
(
# GH#45837
np.array([Timedelta(days=1), Timedelta(days=2)], dtype=object),
"infer",
np.array([1, 2], dtype="m8[D]").astype("m8[ns]"),
),
# TODO: similar for dt64, dt64tz, Period, Interval?
],
)
def test_downcast(arr, expected, dtype):
Expand All @@ -60,26 +37,6 @@ def test_downcast_booleans():
tm.assert_numpy_array_equal(result, expected)


def test_downcast_conversion_no_nan(any_real_numpy_dtype):
dtype = any_real_numpy_dtype
expected = np.array([1, 2])
arr = np.array([1.0, 2.0], dtype=dtype)

result = maybe_downcast_to_dtype(arr, "infer")
tm.assert_almost_equal(result, expected, check_dtype=False)


def test_downcast_conversion_nan(float_numpy_dtype):
dtype = float_numpy_dtype
data = [1.0, 2.0, np.nan]

expected = np.array(data, dtype=dtype)
arr = np.array(data, dtype=dtype)

result = maybe_downcast_to_dtype(arr, "infer")
tm.assert_almost_equal(result, expected)


def test_downcast_conversion_empty(any_real_numpy_dtype):
dtype = any_real_numpy_dtype
arr = np.array([], dtype=dtype)
Expand All @@ -89,7 +46,7 @@ def test_downcast_conversion_empty(any_real_numpy_dtype):

@pytest.mark.parametrize("klass", [np.datetime64, np.timedelta64])
def test_datetime_likes_nan(klass):
dtype = klass.__name__ + "[ns]"
dtype = np.dtype(klass.__name__ + "[ns]")
arr = np.array([1, 2, np.nan])

exp = np.array([1, 2, klass("NaT")], dtype)
Expand Down
Loading