From 12eb748856220be1d2634562e222c73820336162 Mon Sep 17 00:00:00 2001 From: Pierre Sassoulas Date: Sun, 28 Sep 2025 07:59:39 +0200 Subject: [PATCH 1/3] [refactor] Create a file for set comparison in assertion --- src/_pytest/assertion/_compare_set.py | 77 +++++++++++++++++++++++++ src/_pytest/assertion/_typing.py | 9 +++ src/_pytest/assertion/util.py | 82 ++------------------------- 3 files changed, 92 insertions(+), 76 deletions(-) create mode 100644 src/_pytest/assertion/_compare_set.py create mode 100644 src/_pytest/assertion/_typing.py diff --git a/src/_pytest/assertion/_compare_set.py b/src/_pytest/assertion/_compare_set.py new file mode 100644 index 00000000000..fe06698b79f --- /dev/null +++ b/src/_pytest/assertion/_compare_set.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from collections.abc import Callable +from collections.abc import Set as AbstractSet +from typing import Any + +from _pytest._io.saferepr import saferepr +from _pytest.assertion._typing import _HighlightFunc + + +def _set_one_sided_diff( + posn: str, + set1: AbstractSet[Any], + set2: AbstractSet[Any], + highlighter: _HighlightFunc, +) -> list[str]: + explanation = [] + diff = set1 - set2 + if diff: + explanation.append(f"Extra items in the {posn} set:") + for item in diff: + explanation.append(highlighter(saferepr(item))) + return explanation + + +def _compare_eq_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + explanation = [] + explanation.extend(_set_one_sided_diff("left", left, right, highlighter)) + explanation.extend(_set_one_sided_diff("right", right, left, highlighter)) + return explanation + + +def _compare_gt_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + explanation = _compare_gte_set(left, right, highlighter) + if not explanation: + return ["Both sets are equal"] + return explanation + + +def _compare_lt_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + explanation = _compare_lte_set(left, right, highlighter) + if not explanation: + return ["Both sets are equal"] + return explanation + + +def _compare_gte_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + return _set_one_sided_diff("right", right, left, highlighter) + + +def _compare_lte_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> list[str]: + return _set_one_sided_diff("left", left, right, highlighter) diff --git a/src/_pytest/assertion/_typing.py b/src/_pytest/assertion/_typing.py new file mode 100644 index 00000000000..17093f2a931 --- /dev/null +++ b/src/_pytest/assertion/_typing.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from typing import Literal +from typing import Protocol + + +class _HighlightFunc(Protocol): # noqa: PYI046 + def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str: + """Apply highlighting to the given source.""" diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index cc499f7186f..d980740f67f 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -8,12 +8,10 @@ from collections.abc import Iterable from collections.abc import Mapping from collections.abc import Sequence -from collections.abc import Set as AbstractSet import os import pprint from typing import Any from typing import Literal -from typing import Protocol from unicodedata import normalize from _pytest import outcomes @@ -21,6 +19,12 @@ from _pytest._io.pprint import PrettyPrinter from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr_unlimited +from _pytest.assertion._compare_set import _compare_eq_set +from _pytest.assertion._compare_set import _compare_gt_set +from _pytest.assertion._compare_set import _compare_gte_set +from _pytest.assertion._compare_set import _compare_lt_set +from _pytest.assertion._compare_set import _compare_lte_set +from _pytest.assertion._typing import _HighlightFunc from _pytest.config import Config @@ -38,11 +42,6 @@ _config: Config | None = None -class _HighlightFunc(Protocol): - def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str: - """Apply highlighting to the given source.""" - - def dummy_highlighter(source: str, lexer: Literal["diff", "python"] = "python") -> str: """Dummy highlighter that returns the text unprocessed. @@ -426,75 +425,6 @@ def _compare_eq_sequence( return explanation -def _compare_eq_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - explanation = [] - explanation.extend(_set_one_sided_diff("left", left, right, highlighter)) - explanation.extend(_set_one_sided_diff("right", right, left, highlighter)) - return explanation - - -def _compare_gt_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - explanation = _compare_gte_set(left, right, highlighter) - if not explanation: - return ["Both sets are equal"] - return explanation - - -def _compare_lt_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - explanation = _compare_lte_set(left, right, highlighter) - if not explanation: - return ["Both sets are equal"] - return explanation - - -def _compare_gte_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - return _set_one_sided_diff("right", right, left, highlighter) - - -def _compare_lte_set( - left: AbstractSet[Any], - right: AbstractSet[Any], - highlighter: _HighlightFunc, - verbose: int = 0, -) -> list[str]: - return _set_one_sided_diff("left", left, right, highlighter) - - -def _set_one_sided_diff( - posn: str, - set1: AbstractSet[Any], - set2: AbstractSet[Any], - highlighter: _HighlightFunc, -) -> list[str]: - explanation = [] - diff = set1 - set2 - if diff: - explanation.append(f"Extra items in the {posn} set:") - for item in diff: - explanation.append(highlighter(saferepr(item))) - return explanation - - def _compare_eq_dict( left: Mapping[Any, Any], right: Mapping[Any, Any], From 2ff18815b3dff91853b4df05725a01981fe8db30 Mon Sep 17 00:00:00 2001 From: Pierre Sassoulas Date: Sat, 27 Sep 2025 18:14:07 +0200 Subject: [PATCH 2/3] [match case] Use match case in assertrepr_compare --- src/_pytest/assertion/_compare_set.py | 9 ++++++ src/_pytest/assertion/util.py | 43 +++++++++++++-------------- testing/test_assertion.py | 4 +-- 3 files changed, 32 insertions(+), 24 deletions(-) diff --git a/src/_pytest/assertion/_compare_set.py b/src/_pytest/assertion/_compare_set.py index fe06698b79f..27a9df017d6 100644 --- a/src/_pytest/assertion/_compare_set.py +++ b/src/_pytest/assertion/_compare_set.py @@ -75,3 +75,12 @@ def _compare_lte_set( verbose: int = 0, ) -> list[str]: return _set_one_sided_diff("left", left, right, highlighter) + + +SetComparisonFunction = dict[ + str, + Callable[ + [AbstractSet[Any], AbstractSet[Any], _HighlightFunc, int], + list[str], + ], +] diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index d980740f67f..ab46dad7fc6 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -24,6 +24,7 @@ from _pytest.assertion._compare_set import _compare_gte_set from _pytest.assertion._compare_set import _compare_lt_set from _pytest.assertion._compare_set import _compare_lte_set +from _pytest.assertion._compare_set import SetComparisonFunction from _pytest.assertion._typing import _HighlightFunc from _pytest.config import Config @@ -203,30 +204,28 @@ def assertrepr_compare( summary = f"{left_repr} {op} {right_repr}" highlighter = config.get_terminal_writer()._highlight - - explanation = None + explanation: list[str] | None try: - if op == "==": - explanation = _compare_eq_any(left, right, highlighter, verbose) - elif op == "not in": - if istext(left) and istext(right): + match (left, op, right): + case (_, "==", _): + explanation = _compare_eq_any(left, right, highlighter, verbose) + case (str(), "not in", str()): explanation = _notin_text(left, right, verbose) - elif op == "!=": - if isset(left) and isset(right): - explanation = ["Both sets are equal"] - elif op == ">=": - if isset(left) and isset(right): - explanation = _compare_gte_set(left, right, highlighter, verbose) - elif op == "<=": - if isset(left) and isset(right): - explanation = _compare_lte_set(left, right, highlighter, verbose) - elif op == ">": - if isset(left) and isset(right): - explanation = _compare_gt_set(left, right, highlighter, verbose) - elif op == "<": - if isset(left) and isset(right): - explanation = _compare_lt_set(left, right, highlighter, verbose) - + case ( + set() | frozenset(), + "!=" | ">=" | "<=" | ">" | "<", + set() | frozenset(), + ): + set_compare_func: SetComparisonFunction = { + "!=": lambda *a, **kw: ["Both sets are equal"], + ">=": _compare_gte_set, + "<=": _compare_lte_set, + ">": _compare_gt_set, + "<": _compare_lt_set, + } + explanation = set_compare_func[op](left, right, highlighter, verbose) + case _: + explanation = None except outcomes.Exit: raise except Exception: diff --git a/testing/test_assertion.py b/testing/test_assertion.py index 2c2830eb929..1e5c6804360 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -1976,10 +1976,10 @@ def f(): def test_exit_from_assertrepr_compare(monkeypatch) -> None: - def raise_exit(obj): + def raise_exit(*args, **kwargs): outcomes.exit("Quitting debugger") - monkeypatch.setattr(util, "istext", raise_exit) + monkeypatch.setattr(util, "_compare_eq_any", raise_exit) with pytest.raises(outcomes.Exit, match="Quitting debugger"): callequal(1, 1) From e35f1d97482ee22abe31eea2ebbd6421db2bb98e Mon Sep 17 00:00:00 2001 From: Pierre Sassoulas Date: Sun, 28 Sep 2025 07:59:39 +0200 Subject: [PATCH 3/3] [refactor] Create the SET_COMPARISON_FUNCTIONS once Instead of creating it at runtime all the time --- src/_pytest/assertion/_compare_set.py | 11 +++++++++++ src/_pytest/assertion/util.py | 17 ++++------------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/_pytest/assertion/_compare_set.py b/src/_pytest/assertion/_compare_set.py index 27a9df017d6..1cb0980e252 100644 --- a/src/_pytest/assertion/_compare_set.py +++ b/src/_pytest/assertion/_compare_set.py @@ -84,3 +84,14 @@ def _compare_lte_set( list[str], ], ] + +SET_COMPARISON_FUNCTIONS: SetComparisonFunction = { + # == can't be done here without a prior refactor because there's an additional + # explanation for iterable in _compare_eq_any + # "==": _compare_eq_set, + "!=": lambda *a, **kw: ["Both sets are equal"], + ">=": _compare_gte_set, + "<=": _compare_lte_set, + ">": _compare_gt_set, + "<": _compare_lt_set, +} diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index ab46dad7fc6..2086d379421 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -20,11 +20,7 @@ from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr_unlimited from _pytest.assertion._compare_set import _compare_eq_set -from _pytest.assertion._compare_set import _compare_gt_set -from _pytest.assertion._compare_set import _compare_gte_set -from _pytest.assertion._compare_set import _compare_lt_set -from _pytest.assertion._compare_set import _compare_lte_set -from _pytest.assertion._compare_set import SetComparisonFunction +from _pytest.assertion._compare_set import SET_COMPARISON_FUNCTIONS from _pytest.assertion._typing import _HighlightFunc from _pytest.config import Config @@ -216,14 +212,9 @@ def assertrepr_compare( "!=" | ">=" | "<=" | ">" | "<", set() | frozenset(), ): - set_compare_func: SetComparisonFunction = { - "!=": lambda *a, **kw: ["Both sets are equal"], - ">=": _compare_gte_set, - "<=": _compare_lte_set, - ">": _compare_gt_set, - "<": _compare_lt_set, - } - explanation = set_compare_func[op](left, right, highlighter, verbose) + explanation = SET_COMPARISON_FUNCTIONS[op]( + left, right, highlighter, verbose + ) case _: explanation = None except outcomes.Exit: