diff --git a/src/_pytest/assertion/_compare_set.py b/src/_pytest/assertion/_compare_set.py new file mode 100644 index 00000000000..1cb0980e252 --- /dev/null +++ b/src/_pytest/assertion/_compare_set.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from collections.abc import Callable +from collections.abc import Set as AbstractSet +from typing import Any + +from _pytest._io.saferepr import saferepr +from _pytest.assertion._typing import _HighlightFunc + + +def _set_one_sided_diff( + posn: str, + set1: AbstractSet[Any], + set2: AbstractSet[Any], + highlighter: _HighlightFunc, +) -> list[str]: + explanation = [] + diff = set1 - set2 + if diff: + explanation.append(f"Extra items in the {posn} set:") + for item in diff: + explanation.append(highlighter(saferepr(item))) + return explanation + + +def _compare_eq_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + explanation = [] + explanation.extend(_set_one_sided_diff("left", left, right, highlighter)) + explanation.extend(_set_one_sided_diff("right", right, left, highlighter)) + return explanation + + +def _compare_gt_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + explanation = _compare_gte_set(left, right, highlighter) + if not explanation: + return ["Both sets are equal"] + return explanation + + +def _compare_lt_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + explanation = _compare_lte_set(left, right, highlighter) + if not explanation: + return ["Both sets are equal"] + return explanation + + +def _compare_gte_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + return _set_one_sided_diff("right", right, left, highlighter) + + +def _compare_lte_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + return _set_one_sided_diff("left", left, right, highlighter) + + +SetComparisonFunction = dict[ + str, + Callable[ + [AbstractSet[Any], AbstractSet[Any], _HighlightFunc, int], + list[str], + ], +] + +SET_COMPARISON_FUNCTIONS: SetComparisonFunction = { + # == can't be done here without a prior refactor because there's an additional + # explanation for iterable in _compare_eq_any + # "==": _compare_eq_set, + "!=": lambda *a, **kw: ["Both sets are equal"], + ">=": _compare_gte_set, + "<=": _compare_lte_set, + ">": _compare_gt_set, + "<": _compare_lt_set, +} diff --git a/src/_pytest/assertion/_typing.py b/src/_pytest/assertion/_typing.py new file mode 100644 index 00000000000..17093f2a931 --- /dev/null +++ b/src/_pytest/assertion/_typing.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from typing import Literal +from typing import Protocol + + +class _HighlightFunc(Protocol): # noqa: PYI046 + def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str: + """Apply highlighting to the given source.""" diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index cc499f7186f..2086d379421 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -8,12 +8,10 @@ from collections.abc import Iterable from collections.abc import Mapping from collections.abc import Sequence -from collections.abc import Set as AbstractSet import os import pprint from typing import Any from typing import Literal -from typing import Protocol from unicodedata import normalize from _pytest import outcomes @@ -21,6 +19,9 @@ from _pytest._io.pprint import PrettyPrinter from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr_unlimited +from _pytest.assertion._compare_set import _compare_eq_set +from _pytest.assertion._compare_set import SET_COMPARISON_FUNCTIONS +from _pytest.assertion._typing import _HighlightFunc from _pytest.config import Config @@ -38,11 +39,6 @@ _config: Config | None = None -class _HighlightFunc(Protocol): - def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str: - """Apply highlighting to the given source.""" - - def dummy_highlighter(source: str, lexer: Literal["diff", "python"] = "python") -> str: """Dummy highlighter that returns the text unprocessed. @@ -204,30 +200,23 @@ def assertrepr_compare( summary = f"{left_repr} {op} {right_repr}" highlighter = config.get_terminal_writer()._highlight - - explanation = None + explanation: list[str] | None try: - if op == "==": - explanation = _compare_eq_any(left, right, highlighter, verbose) - elif op == "not in": - if istext(left) and istext(right): + match (left, op, right): + case (_, "==", _): + explanation = _compare_eq_any(left, right, highlighter, verbose) + case (str(), "not in", str()): explanation = _notin_text(left, right, verbose) - elif op == "!=": - if isset(left) and isset(right): - explanation = ["Both sets are equal"] - elif op == ">=": - if isset(left) and isset(right): - explanation = _compare_gte_set(left, right, highlighter, verbose) - elif op == "<=": - if isset(left) and isset(right): - explanation = _compare_lte_set(left, right, highlighter, verbose) - elif op == ">": - if isset(left) and isset(right): - explanation = _compare_gt_set(left, right, highlighter, verbose) - elif op == "<": - if isset(left) and isset(right): - explanation = _compare_lt_set(left, right, highlighter, verbose) - + case ( + set() | frozenset(), + "!=" | ">=" | "<=" | ">" | "<", + set() | frozenset(), + ): + explanation = SET_COMPARISON_FUNCTIONS[op]( + left, right, highlighter, verbose + ) + case _: + explanation = None except outcomes.Exit: raise except Exception: @@ -426,75 +415,6 @@ def _compare_eq_sequence( return explanation -def _compare_eq_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - explanation = [] - explanation.extend(_set_one_sided_diff("left", left, right, highlighter)) - explanation.extend(_set_one_sided_diff("right", right, left, highlighter)) - return explanation - - -def _compare_gt_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - explanation = _compare_gte_set(left, right, highlighter) - if not explanation: - return ["Both sets are equal"] - return explanation - - -def _compare_lt_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - explanation = _compare_lte_set(left, right, highlighter) - if not explanation: - return ["Both sets are equal"] - return explanation - - -def _compare_gte_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - return _set_one_sided_diff("right", right, left, highlighter) - - -def _compare_lte_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - return _set_one_sided_diff("left", left, right, highlighter) - - -def _set_one_sided_diff( - posn: str, - set1: AbstractSet[Any], - set2: AbstractSet[Any], - highlighter: _HighlightFunc, -) -> list[str]: - explanation = [] - diff = set1 - set2 - if diff: - explanation.append(f"Extra items in the {posn} set:") - for item in diff: - explanation.append(highlighter(saferepr(item))) - return explanation - - def _compare_eq_dict( left: Mapping[Any, Any], right: Mapping[Any, Any], diff --git a/testing/test_assertion.py b/testing/test_assertion.py index 2c2830eb929..1e5c6804360 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -1976,10 +1976,10 @@ def f(): def test_exit_from_assertrepr_compare(monkeypatch) -> None: - def raise_exit(obj): + def raise_exit(*args, **kwargs): outcomes.exit("Quitting debugger") - monkeypatch.setattr(util, "istext", raise_exit) + monkeypatch.setattr(util, "_compare_eq_any", raise_exit) with pytest.raises(outcomes.Exit, match="Quitting debugger"): callequal(1, 1)