| 
 | 1 | +from __future__ import annotations  | 
 | 2 | + | 
 | 3 | +from functools import partial  | 
 | 4 | +from typing import TYPE_CHECKING, Any, Callable  | 
 | 5 | + | 
 | 6 | +from narwhals._utils import qualified_type_name, zip_strict  | 
 | 7 | +from narwhals.dependencies import is_narwhals_series  | 
 | 8 | +from narwhals.dtypes import Array, Boolean, Categorical, List, String, Struct  | 
 | 9 | +from narwhals.functions import new_series  | 
 | 10 | +from narwhals.testing.asserts.utils import raise_series_assertion_error  | 
 | 11 | + | 
 | 12 | +if TYPE_CHECKING:  | 
 | 13 | +    from typing_extensions import TypeAlias  | 
 | 14 | + | 
 | 15 | +    from narwhals.series import Series  | 
 | 16 | +    from narwhals.typing import IntoSeriesT, SeriesT  | 
 | 17 | + | 
 | 18 | +    CheckFn: TypeAlias = Callable[[Series[Any], Series[Any]], None]  | 
 | 19 | + | 
 | 20 | + | 
 | 21 | +def assert_series_equal(  | 
 | 22 | +    left: Series[IntoSeriesT],  | 
 | 23 | +    right: Series[IntoSeriesT],  | 
 | 24 | +    *,  | 
 | 25 | +    check_dtypes: bool = True,  | 
 | 26 | +    check_names: bool = True,  | 
 | 27 | +    check_order: bool = True,  | 
 | 28 | +    check_exact: bool = False,  | 
 | 29 | +    rel_tol: float = 1e-05,  | 
 | 30 | +    abs_tol: float = 1e-08,  | 
 | 31 | +    categorical_as_str: bool = False,  | 
 | 32 | +) -> None:  | 
 | 33 | +    """Assert that the left and right Series are equal.  | 
 | 34 | +
  | 
 | 35 | +    Raises a detailed `AssertionError` if the Series differ.  | 
 | 36 | +    This function is intended for use in unit tests.  | 
 | 37 | +
  | 
 | 38 | +    Arguments:  | 
 | 39 | +        left: The first Series to compare.  | 
 | 40 | +        right: The second Series to compare.  | 
 | 41 | +        check_dtypes: Requires data types to match.  | 
 | 42 | +        check_names: Requires names to match.  | 
 | 43 | +        check_order: Requires elements to appear in the same order.  | 
 | 44 | +        check_exact: Requires float values to match exactly. If set to `False`, values are  | 
 | 45 | +            considered equal when within tolerance of each other (see `rel_tol` and  | 
 | 46 | +            `abs_tol`). Only affects columns with a Float data type.  | 
 | 47 | +        rel_tol: Relative tolerance for inexact checking, given as a fraction of the  | 
 | 48 | +            values in `right`.  | 
 | 49 | +        abs_tol: Absolute tolerance for inexact checking.  | 
 | 50 | +        categorical_as_str: Cast categorical columns to string before comparing.  | 
 | 51 | +            Enabling this helps compare columns that do not share the same string cache.  | 
 | 52 | +
  | 
 | 53 | +    Examples:  | 
 | 54 | +        >>> import pandas as pd  | 
 | 55 | +        >>> import narwhals as nw  | 
 | 56 | +        >>> from narwhals.testing import assert_series_equal  | 
 | 57 | +        >>> s1 = nw.from_native(pd.Series([1, 2, 3]), series_only=True)  | 
 | 58 | +        >>> s2 = nw.from_native(pd.Series([1, 5, 3]), series_only=True)  | 
 | 59 | +        >>> assert_series_equal(s1, s2)  # doctest: +ELLIPSIS  | 
 | 60 | +        Traceback (most recent call last):  | 
 | 61 | +        ...  | 
 | 62 | +        AssertionError: Series are different (exact value mismatch)  | 
 | 63 | +        [left]:  | 
 | 64 | +        ┌───────────────┐  | 
 | 65 | +        |Narwhals Series|  | 
 | 66 | +        |---------------|  | 
 | 67 | +        | 0    1        |  | 
 | 68 | +        | 1    2        |  | 
 | 69 | +        | 2    3        |  | 
 | 70 | +        | dtype: int64  |  | 
 | 71 | +        └───────────────┘  | 
 | 72 | +        [right]:  | 
 | 73 | +        ┌───────────────┐  | 
 | 74 | +        |Narwhals Series|  | 
 | 75 | +        |---------------|  | 
 | 76 | +        | 0    1        |  | 
 | 77 | +        | 1    5        |  | 
 | 78 | +        | 2    3        |  | 
 | 79 | +        | dtype: int64  |  | 
 | 80 | +        └───────────────┘  | 
 | 81 | +    """  | 
 | 82 | +    __tracebackhide__ = True  | 
 | 83 | + | 
 | 84 | +    if any(not is_narwhals_series(obj) for obj in (left, right)):  | 
 | 85 | +        msg = (  | 
 | 86 | +            "Expected `narwhals.Series` instance, found:\n"  | 
 | 87 | +            f"[left]: {qualified_type_name(type(left))}\n"  | 
 | 88 | +            f"[right]: {qualified_type_name(type(right))}\n\n"  | 
 | 89 | +            "Hint: Use `nw.from_native(obj, series_only=True) to convert each native "  | 
 | 90 | +            "object into a `narwhals.Series` first."  | 
 | 91 | +        )  | 
 | 92 | +        raise TypeError(msg)  | 
 | 93 | + | 
 | 94 | +    _check_metadata(left, right, check_dtypes=check_dtypes, check_names=check_names)  | 
 | 95 | + | 
 | 96 | +    if not check_order:  | 
 | 97 | +        if left.dtype.is_nested():  | 
 | 98 | +            msg = "`check_order=False` is not supported (yet) with nested data type."  | 
 | 99 | +            raise NotImplementedError(msg)  | 
 | 100 | +        left, right = left.sort(), right.sort()  | 
 | 101 | + | 
 | 102 | +    left_vals, right_vals = _check_null_values(left, right)  | 
 | 103 | + | 
 | 104 | +    if check_exact or not left.dtype.is_float():  | 
 | 105 | +        _check_exact_values(  | 
 | 106 | +            left_vals,  | 
 | 107 | +            right_vals,  | 
 | 108 | +            check_dtypes=check_dtypes,  | 
 | 109 | +            check_exact=check_exact,  | 
 | 110 | +            rel_tol=rel_tol,  | 
 | 111 | +            abs_tol=abs_tol,  | 
 | 112 | +            categorical_as_str=categorical_as_str,  | 
 | 113 | +        )  | 
 | 114 | +    else:  | 
 | 115 | +        _check_approximate_values(left_vals, right_vals, rel_tol=rel_tol, abs_tol=abs_tol)  | 
 | 116 | + | 
 | 117 | + | 
 | 118 | +def _check_metadata(  | 
 | 119 | +    left: SeriesT, right: SeriesT, *, check_dtypes: bool, check_names: bool  | 
 | 120 | +) -> None:  | 
 | 121 | +    """Check metadata information: implementation, length, dtype, and names."""  | 
 | 122 | +    left_impl, right_impl = left.implementation, right.implementation  | 
 | 123 | +    if left_impl != right_impl:  | 
 | 124 | +        raise_series_assertion_error("implementation mismatch", left_impl, right_impl)  | 
 | 125 | + | 
 | 126 | +    left_len, right_len = len(left), len(right)  | 
 | 127 | +    if left_len != right_len:  | 
 | 128 | +        raise_series_assertion_error("length mismatch", left_len, right_len)  | 
 | 129 | + | 
 | 130 | +    left_dtype, right_dtype = left.dtype, right.dtype  | 
 | 131 | +    if check_dtypes and left_dtype != right_dtype:  | 
 | 132 | +        raise_series_assertion_error("dtype mismatch", left_dtype, right_dtype)  | 
 | 133 | + | 
 | 134 | +    left_name, right_name = left.name, right.name  | 
 | 135 | +    if check_names and left_name != right_name:  | 
 | 136 | +        raise_series_assertion_error("name mismatch", left_name, right_name)  | 
 | 137 | + | 
 | 138 | + | 
 | 139 | +def _check_null_values(left: SeriesT, right: SeriesT) -> tuple[SeriesT, SeriesT]:  | 
 | 140 | +    """Check null value consistency and return non-null values."""  | 
 | 141 | +    left_null_count, right_null_count = left.null_count(), right.null_count()  | 
 | 142 | +    left_null_mask, right_null_mask = left.is_null(), right.is_null()  | 
 | 143 | + | 
 | 144 | +    if left_null_count != right_null_count or (left_null_mask != right_null_mask).any():  | 
 | 145 | +        raise_series_assertion_error(  | 
 | 146 | +            "null value mismatch", left_null_count, right_null_count  | 
 | 147 | +        )  | 
 | 148 | + | 
 | 149 | +    return left.filter(~left_null_mask), right.filter(~right_null_mask)  | 
 | 150 | + | 
 | 151 | + | 
 | 152 | +def _check_exact_values(  | 
 | 153 | +    left: SeriesT,  | 
 | 154 | +    right: SeriesT,  | 
 | 155 | +    *,  | 
 | 156 | +    check_dtypes: bool,  | 
 | 157 | +    check_exact: bool,  | 
 | 158 | +    rel_tol: float,  | 
 | 159 | +    abs_tol: float,  | 
 | 160 | +    categorical_as_str: bool,  | 
 | 161 | +) -> None:  | 
 | 162 | +    """Check exact value equality for various data types."""  | 
 | 163 | +    left_impl = left.implementation  | 
 | 164 | +    left_dtype, right_dtype = left.dtype, right.dtype  | 
 | 165 | + | 
 | 166 | +    is_not_equal_mask: Series[Any]  | 
 | 167 | +    if left_dtype.is_numeric():  | 
 | 168 | +        # For _all_ numeric dtypes, we can use `is_close` with 0-tolerances to handle  | 
 | 169 | +        # inf and nan values out of the box.  | 
 | 170 | +        is_not_equal_mask = ~left.is_close(right, rel_tol=0, abs_tol=0, nans_equal=True)  | 
 | 171 | +    elif (  | 
 | 172 | +        isinstance(left_dtype, (Array, List)) and isinstance(right_dtype, (Array, List))  | 
 | 173 | +    ) and left_dtype == right_dtype:  | 
 | 174 | +        check_fn = partial(  | 
 | 175 | +            assert_series_equal,  | 
 | 176 | +            check_dtypes=check_dtypes,  | 
 | 177 | +            check_names=False,  | 
 | 178 | +            check_order=True,  | 
 | 179 | +            check_exact=check_exact,  | 
 | 180 | +            rel_tol=rel_tol,  | 
 | 181 | +            abs_tol=abs_tol,  | 
 | 182 | +            categorical_as_str=categorical_as_str,  | 
 | 183 | +        )  | 
 | 184 | +        _check_list_like(left, right, left_dtype, right_dtype, check_fn=check_fn)  | 
 | 185 | +        # If `_check_list_like` didn't raise, then every nested element is equal  | 
 | 186 | +        is_not_equal_mask = new_series("", [False], dtype=Boolean(), backend=left_impl)  | 
 | 187 | +    elif isinstance(left_dtype, Struct) and isinstance(right_dtype, Struct):  | 
 | 188 | +        check_fn = partial(  | 
 | 189 | +            assert_series_equal,  | 
 | 190 | +            check_dtypes=True,  | 
 | 191 | +            check_names=True,  | 
 | 192 | +            check_order=True,  | 
 | 193 | +            check_exact=check_exact,  | 
 | 194 | +            rel_tol=rel_tol,  | 
 | 195 | +            abs_tol=abs_tol,  | 
 | 196 | +            categorical_as_str=categorical_as_str,  | 
 | 197 | +        )  | 
 | 198 | +        _check_struct(left, right, left_dtype, right_dtype, check_fn=check_fn)  | 
 | 199 | +        # If `_check_struct` didn't raise, then every nested element is equal  | 
 | 200 | +        is_not_equal_mask = new_series("", [False], dtype=Boolean(), backend=left_impl)  | 
 | 201 | +    elif isinstance(left_dtype, Categorical) and isinstance(right_dtype, Categorical):  | 
 | 202 | +        # If `_check_categorical` didn't raise, then the categories sources/encodings are  | 
 | 203 | +        # the same, and we can use equality  | 
 | 204 | +        _not_equal = _check_categorical(  | 
 | 205 | +            left, right, categorical_as_str=categorical_as_str  | 
 | 206 | +        )  | 
 | 207 | +        is_not_equal_mask = new_series(  | 
 | 208 | +            "", [_not_equal], dtype=Boolean(), backend=left_impl  | 
 | 209 | +        )  | 
 | 210 | +    else:  | 
 | 211 | +        is_not_equal_mask = left != right  | 
 | 212 | + | 
 | 213 | +    if is_not_equal_mask.any():  | 
 | 214 | +        raise_series_assertion_error("exact value mismatch", left, right)  | 
 | 215 | + | 
 | 216 | + | 
 | 217 | +def _check_approximate_values(  | 
 | 218 | +    left: SeriesT, right: SeriesT, *, rel_tol: float, abs_tol: float  | 
 | 219 | +) -> None:  | 
 | 220 | +    """Check approximate value equality with tolerance."""  | 
 | 221 | +    is_not_close_mask = ~left.is_close(  | 
 | 222 | +        right, rel_tol=rel_tol, abs_tol=abs_tol, nans_equal=True  | 
 | 223 | +    )  | 
 | 224 | + | 
 | 225 | +    if is_not_close_mask.any():  | 
 | 226 | +        raise_series_assertion_error(  | 
 | 227 | +            "values not within tolerance",  | 
 | 228 | +            left.filter(is_not_close_mask),  | 
 | 229 | +            right.filter(is_not_close_mask),  | 
 | 230 | +        )  | 
 | 231 | + | 
 | 232 | + | 
 | 233 | +def _check_list_like(  | 
 | 234 | +    left_vals: SeriesT,  | 
 | 235 | +    right_vals: SeriesT,  | 
 | 236 | +    left_dtype: List | Array,  | 
 | 237 | +    right_dtype: List | Array,  | 
 | 238 | +    check_fn: CheckFn,  | 
 | 239 | +) -> None:  | 
 | 240 | +    # Check row by row after transforming each array/list into a new series.  | 
 | 241 | +    # Notice that order within the array/list must be the same, regardless of  | 
 | 242 | +    # `check_order` value at the top level.  | 
 | 243 | +    impl = left_vals.implementation  | 
 | 244 | +    try:  | 
 | 245 | +        for left_val, right_val in zip_strict(left_vals, right_vals):  | 
 | 246 | +            check_fn(  | 
 | 247 | +                new_series("", values=left_val, dtype=left_dtype.inner, backend=impl),  | 
 | 248 | +                new_series("", values=right_val, dtype=right_dtype.inner, backend=impl),  | 
 | 249 | +            )  | 
 | 250 | +    except AssertionError:  | 
 | 251 | +        raise_series_assertion_error("nested value mismatch", left_vals, right_vals)  | 
 | 252 | + | 
 | 253 | + | 
 | 254 | +def _check_struct(  | 
 | 255 | +    left_vals: SeriesT,  | 
 | 256 | +    right_vals: SeriesT,  | 
 | 257 | +    left_dtype: Struct,  | 
 | 258 | +    right_dtype: Struct,  | 
 | 259 | +    check_fn: CheckFn,  | 
 | 260 | +) -> None:  | 
 | 261 | +    # Check field by field as a separate column.  | 
 | 262 | +    # Notice that for struct's polars raises if:  | 
 | 263 | +    #   * field names are different but values are equal  | 
 | 264 | +    #   * dtype differs, regardless of `check_dtypes=False`  | 
 | 265 | +    #   * order applies only at top level  | 
 | 266 | +    try:  | 
 | 267 | +        for left_field, right_field in zip_strict(left_dtype.fields, right_dtype.fields):  | 
 | 268 | +            check_fn(  | 
 | 269 | +                left_vals.struct.field(left_field.name),  | 
 | 270 | +                right_vals.struct.field(right_field.name),  | 
 | 271 | +            )  | 
 | 272 | +    except AssertionError:  | 
 | 273 | +        raise_series_assertion_error("exact value mismatch", left_vals, right_vals)  | 
 | 274 | + | 
 | 275 | + | 
 | 276 | +def _check_categorical(  | 
 | 277 | +    left_vals: SeriesT, right_vals: SeriesT, *, categorical_as_str: bool  | 
 | 278 | +) -> bool:  | 
 | 279 | +    """Try to compare if any element of categorical series' differ.  | 
 | 280 | +
  | 
 | 281 | +    Inability to compare means that the encoding is different, and an exception is raised.  | 
 | 282 | +    """  | 
 | 283 | +    if categorical_as_str:  | 
 | 284 | +        left_vals, right_vals = left_vals.cast(String()), right_vals.cast(String())  | 
 | 285 | + | 
 | 286 | +    try:  | 
 | 287 | +        return (left_vals != right_vals).any()  | 
 | 288 | +    except Exception as exc:  | 
 | 289 | +        msg = "Cannot compare categoricals coming from different sources."  | 
 | 290 | +        # TODO(FBruzzesi): Improve error message?  | 
 | 291 | +        raise AssertionError(msg) from exc  | 
0 commit comments