-
-
Notifications
You must be signed in to change notification settings - Fork 19.1k
26302 add typing to assert star equal funcs #29364
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
3ca64e2
17caa42
9d56cfc
f7392dd
daa9c87
0c2f692
657b3de
f9f4e7c
eb4c25a
7a2ae46
a735027
6d38c1b
e3b63c6
7268afa
9222488
ab95c8a
4a19f0a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
from shutil import rmtree | ||
import string | ||
import tempfile | ||
from typing import Union, cast | ||
from typing import Optional, Union, cast | ||
import warnings | ||
import zipfile | ||
|
||
|
@@ -53,6 +53,7 @@ | |
Series, | ||
bdate_range, | ||
) | ||
from pandas._typing import AnyArrayLike | ||
from pandas.core.algorithms import take_1d | ||
from pandas.core.arrays import ( | ||
DatetimeArray, | ||
|
@@ -806,8 +807,12 @@ def assert_is_sorted(seq): | |
|
||
|
||
def assert_categorical_equal( | ||
left, right, check_dtype=True, check_category_order=True, obj="Categorical" | ||
): | ||
left: Categorical, | ||
right: Categorical, | ||
check_dtype: bool = True, | ||
check_category_order: bool = True, | ||
obj: str = "Categorical", | ||
) -> None: | ||
"""Test that Categoricals are equivalent. | ||
|
||
Parameters | ||
|
@@ -852,7 +857,12 @@ def assert_categorical_equal( | |
assert_attr_equal("ordered", left, right, obj=obj) | ||
|
||
|
||
def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray"): | ||
def assert_interval_array_equal( | ||
left: IntervalArray, | ||
right: IntervalArray, | ||
exact: str = "equiv", | ||
obj: str = "IntervalArray", | ||
) -> None: | ||
"""Test that two IntervalArrays are equivalent. | ||
|
||
Parameters | ||
|
@@ -867,8 +877,6 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray") | |
Specify object name being compared, internally used to show appropriate | ||
assertion message | ||
""" | ||
_check_isinstance(left, right, IntervalArray) | ||
|
||
assert_index_equal( | ||
left.left, right.left, exact=exact, obj="{obj}.left".format(obj=obj) | ||
) | ||
|
@@ -878,7 +886,9 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray") | |
assert_attr_equal("closed", left, right, obj=obj) | ||
|
||
|
||
def assert_period_array_equal(left, right, obj="PeriodArray"): | ||
def assert_period_array_equal( | ||
left: PeriodArray, right: PeriodArray, obj: str = "PeriodArray" | ||
) -> None: | ||
_check_isinstance(left, right, PeriodArray) | ||
|
||
assert_numpy_array_equal( | ||
|
@@ -887,7 +897,9 @@ def assert_period_array_equal(left, right, obj="PeriodArray"): | |
assert_attr_equal("freq", left, right, obj=obj) | ||
|
||
|
||
def assert_datetime_array_equal(left, right, obj="DatetimeArray"): | ||
def assert_datetime_array_equal( | ||
left: DatetimeArray, right: DatetimeArray, obj: str = "DatetimeArray" | ||
) -> None: | ||
__tracebackhide__ = True | ||
_check_isinstance(left, right, DatetimeArray) | ||
|
||
|
@@ -896,7 +908,9 @@ def assert_datetime_array_equal(left, right, obj="DatetimeArray"): | |
assert_attr_equal("tz", left, right, obj=obj) | ||
|
||
|
||
def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"): | ||
def assert_timedelta_array_equal( | ||
left: TimedeltaArray, right: TimedeltaArray, obj: str = "TimedeltaArray" | ||
) -> None: | ||
__tracebackhide__ = True | ||
_check_isinstance(left, right, TimedeltaArray) | ||
assert_numpy_array_equal(left._data, right._data, obj="{obj}._data".format(obj=obj)) | ||
|
@@ -931,14 +945,14 @@ def raise_assert_detail(obj, message, left, right, diff=None): | |
|
||
|
||
def assert_numpy_array_equal( | ||
left, | ||
right, | ||
strict_nan=False, | ||
check_dtype=True, | ||
err_msg=None, | ||
check_same=None, | ||
obj="numpy array", | ||
): | ||
left: np.ndarray, | ||
right: np.ndarray, | ||
strict_nan: bool = False, | ||
check_dtype: bool = True, | ||
err_msg: Optional[str] = None, | ||
check_same: Optional[str] = None, | ||
obj: str = "numpy array", | ||
) -> None: | ||
""" Checks that 'np.ndarray' is equivalent | ||
|
||
Parameters | ||
|
@@ -1067,18 +1081,18 @@ def assert_extension_array_equal( | |
|
||
# This could be refactored to use the NDFrame.equals method | ||
def assert_series_equal( | ||
left, | ||
right, | ||
check_dtype=True, | ||
check_index_type="equiv", | ||
check_series_type=True, | ||
check_less_precise=False, | ||
check_names=True, | ||
check_exact=False, | ||
check_datetimelike_compat=False, | ||
check_categorical=True, | ||
obj="Series", | ||
): | ||
left: Series, | ||
right: Series, | ||
check_dtype: bool = True, | ||
check_index_type: str = "equiv", | ||
check_series_type: bool = True, | ||
check_less_precise: bool = False, | ||
check_names: bool = True, | ||
check_exact: bool = False, | ||
check_datetimelike_compat: bool = False, | ||
check_categorical: bool = True, | ||
obj: str = "Series", | ||
) -> None: | ||
""" | ||
Check that left and right Series are equal. | ||
|
||
|
@@ -1186,7 +1200,9 @@ def assert_series_equal( | |
check_dtype=check_dtype, | ||
) | ||
elif is_interval_dtype(left) or is_interval_dtype(right): | ||
assert_interval_array_equal(left.array, right.array) | ||
left_array = cast(IntervalArray, left.array) | ||
jreback marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
right_array = cast(IntervalArray, right.array) | ||
assert_interval_array_equal(left_array, right_array) | ||
elif is_extension_array_dtype(left.dtype) and is_datetime64tz_dtype(left.dtype): | ||
# .values is an ndarray, but ._values is the ExtensionArray. | ||
# TODO: Use .array | ||
|
@@ -1221,21 +1237,21 @@ def assert_series_equal( | |
|
||
# This could be refactored to use the NDFrame.equals method | ||
def assert_frame_equal( | ||
left, | ||
right, | ||
check_dtype=True, | ||
check_index_type="equiv", | ||
check_column_type="equiv", | ||
check_frame_type=True, | ||
check_less_precise=False, | ||
check_names=True, | ||
by_blocks=False, | ||
check_exact=False, | ||
check_datetimelike_compat=False, | ||
check_categorical=True, | ||
check_like=False, | ||
obj="DataFrame", | ||
): | ||
left: DataFrame, | ||
right: DataFrame, | ||
check_dtype: bool = True, | ||
check_index_type: str = "equiv", | ||
check_column_type: str = "equiv", | ||
check_frame_type: bool = True, | ||
check_less_precise: bool = False, | ||
check_names: bool = True, | ||
by_blocks: bool = False, | ||
check_exact: bool = False, | ||
check_datetimelike_compat: bool = False, | ||
check_categorical: bool = True, | ||
check_like: bool = False, | ||
obj: str = "DataFrame", | ||
) -> None: | ||
""" | ||
Check that left and right DataFrame are equal. | ||
|
||
|
@@ -1403,7 +1419,11 @@ def assert_frame_equal( | |
) | ||
|
||
|
||
def assert_equal(left, right, **kwargs): | ||
def assert_equal( | ||
left: Union[DataFrame, AnyArrayLike], | ||
WillAyd marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
right: Union[DataFrame, AnyArrayLike], | ||
**kwargs | ||
) -> None: | ||
""" | ||
Wrapper for tm.assert_*_equal to dispatch to the appropriate test function. | ||
|
||
|
@@ -1415,27 +1435,36 @@ def assert_equal(left, right, **kwargs): | |
""" | ||
__tracebackhide__ = True | ||
|
||
if isinstance(left, pd.Index): | ||
if isinstance(left, Index): | ||
right = cast(Index, right) | ||
|
||
assert_index_equal(left, right, **kwargs) | ||
elif isinstance(left, pd.Series): | ||
elif isinstance(left, Series): | ||
right = cast(Series, right) | ||
|
||
assert_series_equal(left, right, **kwargs) | ||
elif isinstance(left, pd.DataFrame): | ||
elif isinstance(left, DataFrame): | ||
right = cast(DataFrame, right) | ||
assert_frame_equal(left, right, **kwargs) | ||
elif isinstance(left, IntervalArray): | ||
right = cast(IntervalArray, right) | ||
assert_interval_array_equal(left, right, **kwargs) | ||
elif isinstance(left, PeriodArray): | ||
right = cast(PeriodArray, right) | ||
assert_period_array_equal(left, right, **kwargs) | ||
elif isinstance(left, DatetimeArray): | ||
right = cast(DatetimeArray, right) | ||
assert_datetime_array_equal(left, right, **kwargs) | ||
elif isinstance(left, TimedeltaArray): | ||
right = cast(TimedeltaArray, right) | ||
assert_timedelta_array_equal(left, right, **kwargs) | ||
elif isinstance(left, ExtensionArray): | ||
right = cast(ExtensionArray, right) | ||
assert_extension_array_equal(left, right, **kwargs) | ||
elif isinstance(left, np.ndarray): | ||
right = cast(np.ndarray, right) | ||
assert_numpy_array_equal(left, right, **kwargs) | ||
elif isinstance(left, str): | ||
assert kwargs == {} | ||
return left == right | ||
assert left == right | ||
WillAyd marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
else: | ||
raise NotImplementedError(type(left)) | ||
|
||
|
@@ -1497,12 +1526,12 @@ def to_array(obj): | |
|
||
|
||
def assert_sp_array_equal( | ||
left, | ||
right, | ||
check_dtype=True, | ||
check_kind=True, | ||
check_fill_value=True, | ||
consolidate_block_indices=False, | ||
left: pd.SparseArray, | ||
right: pd.SparseArray, | ||
check_dtype: bool = True, | ||
check_kind: bool = True, | ||
check_fill_value: bool = True, | ||
consolidate_block_indices: bool = False, | ||
): | ||
"""Check that the left and right SparseArray are equal. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AnyArrayLike resolves to Any.
If it adds value in the form of code documentation then OK, but mypy is effectively not checking these annotations.