Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions src/_pytest/assertion/_compare_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from __future__ import annotations

from collections.abc import Callable
from collections.abc import Set as AbstractSet
from typing import Any

from _pytest._io.saferepr import saferepr
from _pytest.assertion._typing import _HighlightFunc


def _set_one_sided_diff(
posn: str,
set1: AbstractSet[Any],
set2: AbstractSet[Any],
highlighter: _HighlightFunc,
) -> list[str]:
explanation = []
diff = set1 - set2
if diff:
explanation.append(f"Extra items in the {posn} set:")
for item in diff:
explanation.append(highlighter(saferepr(item)))
return explanation


def _compare_eq_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = []
explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
return explanation


def _compare_gt_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = _compare_gte_set(left, right, highlighter)
if not explanation:
return ["Both sets are equal"]
return explanation


def _compare_lt_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = _compare_lte_set(left, right, highlighter)
if not explanation:
return ["Both sets are equal"]
return explanation


def _compare_gte_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
return _set_one_sided_diff("right", right, left, highlighter)


def _compare_lte_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
return _set_one_sided_diff("left", left, right, highlighter)


SetComparisonFunction = dict[
str,
Callable[
[AbstractSet[Any], AbstractSet[Any], _HighlightFunc, int],
list[str],
],
]

SET_COMPARISON_FUNCTIONS: SetComparisonFunction = {
# == can't be done here without a prior refactor because there's an additional
# explanation for iterable in _compare_eq_any
# "==": _compare_eq_set,
"!=": lambda *a, **kw: ["Both sets are equal"],
">=": _compare_gte_set,
"<=": _compare_lte_set,
">": _compare_gt_set,
"<": _compare_lt_set,
}
9 changes: 9 additions & 0 deletions src/_pytest/assertion/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from typing import Literal
from typing import Protocol


class _HighlightFunc(Protocol): # noqa: PYI046
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
"""Apply highlighting to the given source."""
116 changes: 18 additions & 98 deletions src/_pytest/assertion/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import Sequence
from collections.abc import Set as AbstractSet
import os
import pprint
from typing import Any
from typing import Literal
from typing import Protocol
from unicodedata import normalize

from _pytest import outcomes
import _pytest._code
from _pytest._io.pprint import PrettyPrinter
from _pytest._io.saferepr import saferepr
from _pytest._io.saferepr import saferepr_unlimited
from _pytest.assertion._compare_set import _compare_eq_set
from _pytest.assertion._compare_set import SET_COMPARISON_FUNCTIONS
from _pytest.assertion._typing import _HighlightFunc
from _pytest.config import Config


Expand All @@ -38,11 +39,6 @@
_config: Config | None = None


class _HighlightFunc(Protocol):
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
"""Apply highlighting to the given source."""


def dummy_highlighter(source: str, lexer: Literal["diff", "python"] = "python") -> str:
"""Dummy highlighter that returns the text unprocessed.

Expand Down Expand Up @@ -204,30 +200,23 @@ def assertrepr_compare(

summary = f"{left_repr} {op} {right_repr}"
highlighter = config.get_terminal_writer()._highlight

explanation = None
explanation: list[str] | None
try:
if op == "==":
explanation = _compare_eq_any(left, right, highlighter, verbose)
elif op == "not in":
if istext(left) and istext(right):
match (left, op, right):
case (_, "==", _):
explanation = _compare_eq_any(left, right, highlighter, verbose)
case (str(), "not in", str()):
explanation = _notin_text(left, right, verbose)
elif op == "!=":
if isset(left) and isset(right):
explanation = ["Both sets are equal"]
elif op == ">=":
if isset(left) and isset(right):
explanation = _compare_gte_set(left, right, highlighter, verbose)
elif op == "<=":
if isset(left) and isset(right):
explanation = _compare_lte_set(left, right, highlighter, verbose)
elif op == ">":
if isset(left) and isset(right):
explanation = _compare_gt_set(left, right, highlighter, verbose)
elif op == "<":
if isset(left) and isset(right):
explanation = _compare_lt_set(left, right, highlighter, verbose)

case (
set() | frozenset(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This duplicates the isset logic, while that seems unlikely to change it seems conceptually wrong to duplicate since the idea is to have a single place which defines what a set is. WDYT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I have a follow up:

commit ae1fbffd1b5e7b2dc54c5cccda218f1da5173a80
Author: Pierre Sassoulas <[email protected]>
Date:   Sun Sep 28 07:40:28 2025 +0200

    [refactor] Group set together and two less isinstance check

diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py
index 2086d3794..3dff746f2 100644
--- a/src/_pytest/assertion/util.py
+++ b/src/_pytest/assertion/util.py
@@ -126,10 +126,6 @@ def isdict(x: Any) -> bool:
     return isinstance(x, dict)
 
 
-def isset(x: Any) -> bool:
-    return isinstance(x, set | frozenset)
-
-
 def isnamedtuple(obj: Any) -> bool:
     return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None
 
@@ -203,18 +199,18 @@ def assertrepr_compare(
     explanation: list[str] | None
     try:
         match (left, op, right):
-            case (_, "==", _):
-                explanation = _compare_eq_any(left, right, highlighter, verbose)
-            case (str(), "not in", str()):
-                explanation = _notin_text(left, right, verbose)
             case (
                 set() | frozenset(),
-                "!=" | ">=" | "<=" | ">" | "<",
+                "==" | "!=" | ">=" | "<=" | ">" | "<",
                 set() | frozenset(),
             ):
                 explanation = SET_COMPARISON_FUNCTIONS[op](
                     left, right, highlighter, verbose
                 )
+            case (_, "==", _):
+                explanation = _compare_eq_any(left, right, highlighter, verbose)
+            case (str(), "not in", str()):
+                explanation = _notin_text(left, right, verbose)
             case _:
                 explanation = None
     except outcomes.Exit:
@@ -259,8 +255,6 @@ def _compare_eq_any(
             explanation = _compare_eq_cls(left, right, highlighter, verbose)
         elif issequence(left) and issequence(right):
             explanation = _compare_eq_sequence(left, right, highlighter, verbose)
-        elif isset(left) and isset(right):
-            explanation = _compare_eq_set(left, right, highlighter, verbose)
         elif isdict(left) and isdict(right):
             explanation = _compare_eq_dict(left, right, highlighter, verbose)

But this require a prior refactor because there's additional information that are added only for iterable in _compare_eq_any. Actually finding the right order of refactors for optimal reviewing ease was a little tricky and I decided to open this one instead of entering rebase hell.

The goal would be to remove all the isx function once the match case structure make them redundant because the complexity can be contained in the match case itself while still being readable.

Copy link
Member Author

@Pierre-Sassoulas Pierre-Sassoulas Sep 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also wondering if merging the entire _compare_eq_any function in the match case of assertrepr_compare is going to make it too big. But it's probably worth it because the whole logic for the assert_repr decision tree is easy to locate.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your follow up makes sense to me (moving the type-specific checking before the generic _ == _ case).

I understand your comment about rebase hell, though it would be helpful to see the final state if you have it, even if it's one huge commit blob :)

BTW, regarding isset, I wonder if we should use collections.abc.Set (immutable set operations, includes set and frozenset, and possibly user types which implement the interface) instead of set | frozenset. I haven't checked if it makes sense.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pushed the concept a little and _compare_eq_any is recursive through _compare_eq_cls (launched on each element inside iterables) so what I envisioned is not possible. I'm going to make a second switch to match/case in _compare_eq_any and keep them separated. Here's a messy unfinished glob with what I currently have: https://github.com/Pierre-Sassoulas/pytest/pull/2/files#diff-0e1605330bae69222e30a50a4c573ae5eadb529e8f64662c0c9134e431af9d4aR27-R120

"!=" | ">=" | "<=" | ">" | "<",
set() | frozenset(),
):
explanation = SET_COMPARISON_FUNCTIONS[op](
left, right, highlighter, verbose
)
case _:
explanation = None
except outcomes.Exit:
raise
except Exception:
Expand Down Expand Up @@ -426,75 +415,6 @@ def _compare_eq_sequence(
return explanation


def _compare_eq_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = []
explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
return explanation


def _compare_gt_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = _compare_gte_set(left, right, highlighter)
if not explanation:
return ["Both sets are equal"]
return explanation


def _compare_lt_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = _compare_lte_set(left, right, highlighter)
if not explanation:
return ["Both sets are equal"]
return explanation


def _compare_gte_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
return _set_one_sided_diff("right", right, left, highlighter)


def _compare_lte_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
return _set_one_sided_diff("left", left, right, highlighter)


def _set_one_sided_diff(
posn: str,
set1: AbstractSet[Any],
set2: AbstractSet[Any],
highlighter: _HighlightFunc,
) -> list[str]:
explanation = []
diff = set1 - set2
if diff:
explanation.append(f"Extra items in the {posn} set:")
for item in diff:
explanation.append(highlighter(saferepr(item)))
return explanation


def _compare_eq_dict(
left: Mapping[Any, Any],
right: Mapping[Any, Any],
Expand Down
4 changes: 2 additions & 2 deletions testing/test_assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1976,10 +1976,10 @@ def f():


def test_exit_from_assertrepr_compare(monkeypatch) -> None:
def raise_exit(obj):
def raise_exit(*args, **kwargs):
outcomes.exit("Quitting debugger")

monkeypatch.setattr(util, "istext", raise_exit)
monkeypatch.setattr(util, "_compare_eq_any", raise_exit)

with pytest.raises(outcomes.Exit, match="Quitting debugger"):
callequal(1, 1)
Expand Down
Loading