Skip to content

Commit 16d6fcb

Browse files
feat: Add assert_series_equal in testing module (#2983)
--------- Co-authored-by: Dan Redding <[email protected]>
1 parent 849326f commit 16d6fcb

File tree

10 files changed

+809
-6
lines changed

10 files changed

+809
-6
lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,12 @@ repos:
8383
name: don't import from narwhals.dtypes (use `Version.dtypes` instead)
8484
entry: |
8585
(?x)
86-
import\ narwhals.dtypes|
87-
from\ narwhals\ import\ dtypes|
88-
from\ narwhals.dtypes\ import\ [^D_]+|
89-
import\ narwhals.stable.v1.dtypes|
90-
from\ narwhals.stable\.v.\ import\ dtypes|
91-
from\ narwhals.stable\.v.\.dtypes\ import
86+
import\ narwhals(\.stable\.v\d)?\.dtypes|
87+
from\ narwhals(\.stable\.v\d)?\ import\ dtypes|
88+
^from\ narwhals(\.stable\.v\d)?\.dtypes\ import
89+
\ (DType,\ )?
90+
((Datetime|Duration|Enum)(,\ )?)+
91+
((,\ )?DType)?
9292
language: pygrep
9393
files: ^narwhals/
9494
exclude: |

docs/api-reference/testing.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# `narwhals.testing`
2+
3+
::: narwhals.testing
4+
handler: python
5+
options:
6+
members:
7+
- assert_series_equal

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ nav:
6969
- api-reference/dtypes.md
7070
- api-reference/exceptions.md
7171
- api-reference/selectors.md
72+
- api-reference/testing.md
7273
- api-reference/typing.md
7374
- api-reference/utils.md
7475
- This: this.md

narwhals/testing/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from __future__ import annotations
2+
3+
from narwhals.testing.asserts.series import assert_series_equal
4+
5+
__all__ = ("assert_series_equal",)

narwhals/testing/asserts/__init__.py

Whitespace-only changes.

narwhals/testing/asserts/series.py

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
from __future__ import annotations
2+
3+
from functools import partial
4+
from typing import TYPE_CHECKING, Any, Callable
5+
6+
from narwhals._utils import qualified_type_name, zip_strict
7+
from narwhals.dependencies import is_narwhals_series
8+
from narwhals.dtypes import Array, Boolean, Categorical, List, String, Struct
9+
from narwhals.functions import new_series
10+
from narwhals.testing.asserts.utils import raise_series_assertion_error
11+
12+
if TYPE_CHECKING:
13+
from typing_extensions import TypeAlias
14+
15+
from narwhals.series import Series
16+
from narwhals.typing import IntoSeriesT, SeriesT
17+
18+
CheckFn: TypeAlias = Callable[[Series[Any], Series[Any]], None]
19+
20+
21+
def assert_series_equal(
    left: Series[IntoSeriesT],
    right: Series[IntoSeriesT],
    *,
    check_dtypes: bool = True,
    check_names: bool = True,
    check_order: bool = True,
    check_exact: bool = False,
    rel_tol: float = 1e-05,
    abs_tol: float = 1e-08,
    categorical_as_str: bool = False,
) -> None:
    """Assert that the left and right Series are equal.

    Raises a detailed `AssertionError` if the Series differ.
    This function is intended for use in unit tests.

    Arguments:
        left: The first Series to compare.
        right: The second Series to compare.
        check_dtypes: Requires data types to match.
        check_names: Requires names to match.
        check_order: Requires elements to appear in the same order.
        check_exact: Requires float values to match exactly. If set to `False`, values are
            considered equal when within tolerance of each other (see `rel_tol` and
            `abs_tol`). Only affects columns with a Float data type.
        rel_tol: Relative tolerance for inexact checking, given as a fraction of the
            values in `right`.
        abs_tol: Absolute tolerance for inexact checking.
        categorical_as_str: Cast categorical columns to string before comparing.
            Enabling this helps compare columns that do not share the same string cache.

    Raises:
        TypeError: If `left` or `right` is not a `narwhals.Series`.
        NotImplementedError: If `check_order=False` and `left` has a nested data type.
        AssertionError: If the two Series differ in implementation, length, dtype,
            name, null positions or values.

    Examples:
        >>> import pandas as pd
        >>> import narwhals as nw
        >>> from narwhals.testing import assert_series_equal
        >>> s1 = nw.from_native(pd.Series([1, 2, 3]), series_only=True)
        >>> s2 = nw.from_native(pd.Series([1, 5, 3]), series_only=True)
        >>> assert_series_equal(s1, s2)  # doctest: +ELLIPSIS
        Traceback (most recent call last):
        ...
        AssertionError: Series are different (exact value mismatch)
        [left]:
        ┌───────────────┐
        |Narwhals Series|
        |---------------|
        | 0    1        |
        | 1    2        |
        | 2    3        |
        | dtype: int64  |
        └───────────────┘
        [right]:
        ┌───────────────┐
        |Narwhals Series|
        |---------------|
        | 0    1        |
        | 1    5        |
        | 2    3        |
        | dtype: int64  |
        └───────────────┘
    """
    __tracebackhide__ = True

    if any(not is_narwhals_series(obj) for obj in (left, right)):
        msg = (
            "Expected `narwhals.Series` instance, found:\n"
            f"[left]: {qualified_type_name(type(left))}\n"
            f"[right]: {qualified_type_name(type(right))}\n\n"
            "Hint: Use `nw.from_native(obj, series_only=True)` to convert each native "
            "object into a `narwhals.Series` first."
        )
        raise TypeError(msg)

    # Cheap structural checks first: they fail fast before any value is compared.
    _check_metadata(left, right, check_dtypes=check_dtypes, check_names=check_names)

    if not check_order:
        if left.dtype.is_nested():
            msg = "`check_order=False` is not supported (yet) with nested data type."
            raise NotImplementedError(msg)
        # Sorting both sides makes the element-wise comparison order-insensitive.
        left, right = left.sort(), right.sort()

    # Nulls must sit at the same positions; only non-null values are compared below.
    left_vals, right_vals = _check_null_values(left, right)

    # Tolerance-based comparison applies only to float data with `check_exact=False`.
    if check_exact or not left.dtype.is_float():
        _check_exact_values(
            left_vals,
            right_vals,
            check_dtypes=check_dtypes,
            check_exact=check_exact,
            rel_tol=rel_tol,
            abs_tol=abs_tol,
            categorical_as_str=categorical_as_str,
        )
    else:
        _check_approximate_values(left_vals, right_vals, rel_tol=rel_tol, abs_tol=abs_tol)
116+
117+
118+
def _check_metadata(
119+
left: SeriesT, right: SeriesT, *, check_dtypes: bool, check_names: bool
120+
) -> None:
121+
"""Check metadata information: implementation, length, dtype, and names."""
122+
left_impl, right_impl = left.implementation, right.implementation
123+
if left_impl != right_impl:
124+
raise_series_assertion_error("implementation mismatch", left_impl, right_impl)
125+
126+
left_len, right_len = len(left), len(right)
127+
if left_len != right_len:
128+
raise_series_assertion_error("length mismatch", left_len, right_len)
129+
130+
left_dtype, right_dtype = left.dtype, right.dtype
131+
if check_dtypes and left_dtype != right_dtype:
132+
raise_series_assertion_error("dtype mismatch", left_dtype, right_dtype)
133+
134+
left_name, right_name = left.name, right.name
135+
if check_names and left_name != right_name:
136+
raise_series_assertion_error("name mismatch", left_name, right_name)
137+
138+
139+
def _check_null_values(left: SeriesT, right: SeriesT) -> tuple[SeriesT, SeriesT]:
    """Verify that nulls line up between both Series, then strip them out.

    An assertion error (reporting both null counts) is raised when the counts
    differ or when any position is null on one side only.

    Returns:
        The `(left, right)` pair restricted to their non-null entries.
    """
    n_null_lhs = left.null_count()
    n_null_rhs = right.null_count()
    nulls_lhs = left.is_null()
    nulls_rhs = right.is_null()

    # The count comparison short-circuits the element-wise mask comparison.
    counts_differ = n_null_lhs != n_null_rhs
    if counts_differ or (nulls_lhs != nulls_rhs).any():
        raise_series_assertion_error("null value mismatch", n_null_lhs, n_null_rhs)

    non_null_lhs = left.filter(~nulls_lhs)
    non_null_rhs = right.filter(~nulls_rhs)
    return non_null_lhs, non_null_rhs
150+
151+
152+
def _check_exact_values(
    left: SeriesT,
    right: SeriesT,
    *,
    check_dtypes: bool,
    check_exact: bool,
    rel_tol: float,
    abs_tol: float,
    categorical_as_str: bool,
) -> None:
    """Dispatch on dtype and assert that every value of `left` equals `right`.

    Branches, in order:

    * numeric: `is_close` with zero tolerances (treating NaNs as equal).
    * list/array with identical dtypes: row-by-row recursion via `_check_list_like`.
    * struct: field-by-field recursion via `_check_struct`.
    * categorical: `_check_categorical`.
    * anything else: plain element-wise `!=`.

    The tolerance and `check_exact` arguments are only forwarded to the recursive
    comparison of nested values.
    """
    backend = left.implementation
    dtype_lhs, dtype_rhs = left.dtype, right.dtype

    def _nested_check(*, dtypes: bool, names: bool) -> CheckFn:
        # Nested values are always compared in order, whatever the top-level setting.
        return partial(
            assert_series_equal,
            check_dtypes=dtypes,
            check_names=names,
            check_order=True,
            check_exact=check_exact,
            rel_tol=rel_tol,
            abs_tol=abs_tol,
            categorical_as_str=categorical_as_str,
        )

    def _scalar_mask(flag: bool) -> Series[Any]:
        # One-element boolean Series so every branch shares the final `.any()` test.
        return new_series("", [flag], dtype=Boolean(), backend=backend)

    mismatch: Series[Any]
    if dtype_lhs.is_numeric():
        # For _all_ numeric dtypes, `is_close` with 0-tolerances handles inf and nan
        # values out of the box.
        mismatch = ~left.is_close(right, rel_tol=0, abs_tol=0, nans_equal=True)
    elif (
        isinstance(dtype_lhs, (Array, List))
        and isinstance(dtype_rhs, (Array, List))
        and dtype_lhs == dtype_rhs
    ):
        _check_list_like(
            left,
            right,
            dtype_lhs,
            dtype_rhs,
            check_fn=_nested_check(dtypes=check_dtypes, names=False),
        )
        # `_check_list_like` did not raise, so every nested element is equal.
        mismatch = _scalar_mask(False)
    elif isinstance(dtype_lhs, Struct) and isinstance(dtype_rhs, Struct):
        _check_struct(
            left,
            right,
            dtype_lhs,
            dtype_rhs,
            check_fn=_nested_check(dtypes=True, names=True),
        )
        # `_check_struct` did not raise, so every nested element is equal.
        mismatch = _scalar_mask(False)
    elif isinstance(dtype_lhs, Categorical) and isinstance(dtype_rhs, Categorical):
        # `_check_categorical` did not raise, so the category sources/encodings are
        # the same and its equality result can be used directly.
        any_differs = _check_categorical(
            left, right, categorical_as_str=categorical_as_str
        )
        mismatch = _scalar_mask(any_differs)
    else:
        mismatch = left != right

    if mismatch.any():
        raise_series_assertion_error("exact value mismatch", left, right)
215+
216+
217+
def _check_approximate_values(
    left: SeriesT, right: SeriesT, *, rel_tol: float, abs_tol: float
) -> None:
    """Assert that `left` and `right` agree element-wise within the given tolerances.

    NaNs are treated as equal to each other. On failure, only the offending
    elements of each side are included in the error.
    """
    close = left.is_close(right, rel_tol=rel_tol, abs_tol=abs_tol, nans_equal=True)
    outside_tol = ~close

    if not outside_tol.any():
        return

    offending_lhs = left.filter(outside_tol)
    offending_rhs = right.filter(outside_tol)
    raise_series_assertion_error(
        "values not within tolerance", offending_lhs, offending_rhs
    )
231+
232+
233+
def _check_list_like(
    left_vals: SeriesT,
    right_vals: SeriesT,
    left_dtype: List | Array,
    right_dtype: List | Array,
    check_fn: CheckFn,
) -> None:
    """Compare two list/array Series row by row.

    Each row is turned into a new Series (using the inner dtype and the backend of
    `left_vals`) and the resulting pair is handed to `check_fn`. Order within each
    array/list must be the same, regardless of the `check_order` value at the top
    level.

    Arguments:
        left_vals: Left Series of list/array values.
        right_vals: Right Series of list/array values.
        left_dtype: Data type of `left_vals`; its `inner` dtype is used for each row.
        right_dtype: Data type of `right_vals`; its `inner` dtype is used for each row.
        check_fn: Callable that raises `AssertionError` when two row Series differ.

    Raises:
        AssertionError: With detail "nested value mismatch" if any row differs. The
            failure of the offending row is attached as the direct cause.
    """
    impl = left_vals.implementation
    try:
        for left_val, right_val in zip_strict(left_vals, right_vals):
            check_fn(
                new_series("", values=left_val, dtype=left_dtype.inner, backend=impl),
                new_series("", values=right_val, dtype=right_dtype.inner, backend=impl),
            )
    except AssertionError as exc:
        # Forward the row-level failure explicitly: the helper raises with
        # `from cause`, and the default `cause=None` would act as `from None`,
        # hiding which row actually differed.
        raise_series_assertion_error(
            "nested value mismatch", left_vals, right_vals, cause=exc
        )
252+
253+
254+
def _check_struct(
    left_vals: SeriesT,
    right_vals: SeriesT,
    left_dtype: Struct,
    right_dtype: Struct,
    check_fn: CheckFn,
) -> None:
    """Compare two struct Series field by field, each field as a separate column.

    Fields are paired positionally from `left_dtype.fields` and `right_dtype.fields`
    and each pair of field Series is handed to `check_fn`.

    Notice that for structs polars raises if:

    * field names are different but values are equal
    * dtype differs, regardless of `check_dtypes=False`
    * order applies only at top level

    Arguments:
        left_vals: Left Series of struct values.
        right_vals: Right Series of struct values.
        left_dtype: Data type of `left_vals`, providing the field names.
        right_dtype: Data type of `right_vals`, providing the field names.
        check_fn: Callable that raises `AssertionError` when two field Series differ.

    Raises:
        AssertionError: With detail "exact value mismatch" if any field differs. The
            failure of the offending field is attached as the direct cause.
    """
    # NOTE(review): `zip_strict` presumably raises a non-`AssertionError` when the two
    # structs have a different number of fields -- confirm whether that should be
    # reported as a mismatch as well.
    try:
        for left_field, right_field in zip_strict(left_dtype.fields, right_dtype.fields):
            check_fn(
                left_vals.struct.field(left_field.name),
                right_vals.struct.field(right_field.name),
            )
    except AssertionError as exc:
        # Forward the field-level failure explicitly: the helper raises with
        # `from cause`, and the default `cause=None` would act as `from None`,
        # hiding which field actually differed.
        raise_series_assertion_error(
            "exact value mismatch", left_vals, right_vals, cause=exc
        )
274+
275+
276+
def _check_categorical(
    left_vals: SeriesT, right_vals: SeriesT, *, categorical_as_str: bool
) -> bool:
    """Report whether two categorical Series differ in at least one element.

    When `categorical_as_str` is set, both sides are cast to `String` before the
    comparison. If the comparison itself fails, the encodings are taken to be
    different and an `AssertionError` is raised, chained to the original exception.
    """
    lhs, rhs = left_vals, right_vals
    if categorical_as_str:
        lhs = lhs.cast(String())
        rhs = rhs.cast(String())

    try:
        any_differs = (lhs != rhs).any()
    except Exception as exc:
        msg = "Cannot compare categoricals coming from different sources."
        # TODO(FBruzzesi): Improve error message?
        raise AssertionError(msg) from exc
    else:
        return any_differs

narwhals/testing/asserts/utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, Literal
4+
5+
from narwhals.dependencies import is_narwhals_series
6+
7+
if TYPE_CHECKING:
8+
from typing_extensions import Never, TypeAlias
9+
10+
# NOTE: This alias is created to facilitate autocomplete. Feel free to extend it as
11+
# you please when adding a new feature.
12+
# See: https://github.com/narwhals-dev/narwhals/pull/2983#discussion_r2337548736
13+
SeriesDetail: TypeAlias = Literal[
14+
"implementation mismatch",
15+
"length mismatch",
16+
"dtype mismatch",
17+
"name mismatch",
18+
"null value mismatch",
19+
"exact value mismatch",
20+
"values not within tolerance",
21+
"nested value mismatch",
22+
]
23+
24+
25+
def raise_assertion_error(
    objects: str, detail: str, left: Any, right: Any, *, cause: Exception | None = None
) -> Never:
    """Build the "<objects> are different (<detail>)" message and raise it.

    Arguments:
        objects: Plural name of what is being compared (e.g. "Series").
        detail: Short description of the mismatch.
        left: Left-hand value to show in the message.
        right: Right-hand value to show in the message.
        cause: Exception to chain from; `None` raises with `from None`.

    Raises:
        AssertionError: Always.
    """
    __tracebackhide__ = True

    def _render(label: str, value: Any) -> str:
        # A narwhals Series is printed starting on its own line; any other value
        # follows the label after a single space.
        separator = "\n" if is_narwhals_series(value) else " "
        return f"[{label}]:{separator}{value}"

    header = f"{objects} are different ({detail})"
    msg = "\n".join((header, _render("left", left), _render("right", right)))
    raise AssertionError(msg) from cause
40+
41+
42+
def raise_series_assertion_error(
    detail: SeriesDetail, left: Any, right: Any, *, cause: Exception | None = None
) -> Never:
    """Raise an `AssertionError` for a `Series` comparison.

    Thin wrapper around `raise_assertion_error` with `objects` fixed to "Series".
    """
    raise_assertion_error(
        objects="Series", detail=detail, left=left, right=right, cause=cause
    )

tests/testing/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)