Skip to content

Commit ac60fb8

Browse files
[refactor] Create the SET_COMPARISON_FUNCTIONS once
Instead of creating it at runtime all the time
1 parent 0611968 commit ac60fb8

File tree

3 files changed

+112
-93
lines changed

3 files changed

+112
-93
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable
4+
from collections.abc import Set as AbstractSet
5+
from typing import Any
6+
7+
from _pytest._io.saferepr import saferepr
8+
from _pytest.assertion._typing import _HighlightFunc
9+
10+
11+
def _set_one_sided_diff(
12+
posn: str,
13+
set1: AbstractSet[Any],
14+
set2: AbstractSet[Any],
15+
highlighter: _HighlightFunc,
16+
) -> list[str]:
17+
explanation = []
18+
diff = set1 - set2
19+
if diff:
20+
explanation.append(f"Extra items in the {posn} set:")
21+
for item in diff:
22+
explanation.append(highlighter(saferepr(item)))
23+
return explanation
24+
25+
26+
def _compare_eq_set(
27+
left: AbstractSet[Any],
28+
right: AbstractSet[Any],
29+
highlighter: _HighlightFunc,
30+
verbose: int = 0,
31+
) -> list[str]:
32+
explanation = []
33+
explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
34+
explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
35+
return explanation
36+
37+
38+
def _compare_gt_set(
39+
left: AbstractSet[Any],
40+
right: AbstractSet[Any],
41+
highlighter: _HighlightFunc,
42+
verbose: int = 0,
43+
) -> list[str]:
44+
explanation = _compare_gte_set(left, right, highlighter)
45+
if not explanation:
46+
return ["Both sets are equal"]
47+
return explanation
48+
49+
50+
def _compare_lt_set(
51+
left: AbstractSet[Any],
52+
right: AbstractSet[Any],
53+
highlighter: _HighlightFunc,
54+
verbose: int = 0,
55+
) -> list[str]:
56+
explanation = _compare_lte_set(left, right, highlighter)
57+
if not explanation:
58+
return ["Both sets are equal"]
59+
return explanation
60+
61+
62+
def _compare_gte_set(
63+
left: AbstractSet[Any],
64+
right: AbstractSet[Any],
65+
highlighter: _HighlightFunc,
66+
verbose: int = 0,
67+
) -> list[str]:
68+
return _set_one_sided_diff("right", right, left, highlighter)
69+
70+
71+
def _compare_lte_set(
72+
left: AbstractSet[Any],
73+
right: AbstractSet[Any],
74+
highlighter: _HighlightFunc,
75+
verbose: int = 0,
76+
) -> list[str]:
77+
return _set_one_sided_diff("left", left, right, highlighter)
78+
79+
80+
SetComparisonFunction = dict[
81+
str,
82+
Callable[
83+
[AbstractSet[Any], AbstractSet[Any], _HighlightFunc, int],
84+
list[str],
85+
],
86+
]
87+
88+
SET_COMPARISON_FUNCTIONS: SetComparisonFunction = {
89+
# == can't be done here without a prior refactor because there's an additional
90+
# explanation for iterable in _compare_eq_any
91+
# "==": _compare_eq_set,
92+
"!=": lambda *a, **kw: ["Both sets are equal"],
93+
">=": _compare_gte_set,
94+
"<=": _compare_lte_set,
95+
">": _compare_gt_set,
96+
"<": _compare_lt_set,
97+
}

src/_pytest/assertion/_typing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from __future__ import annotations
2+
3+
from typing import Literal
4+
from typing import Protocol
5+
6+
7+
class _HighlightFunc(Protocol): # noqa: PYI046
8+
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
9+
"""Apply highlighting to the given source."""

src/_pytest/assertion/util.py

Lines changed: 6 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,20 @@
88
from collections.abc import Iterable
99
from collections.abc import Mapping
1010
from collections.abc import Sequence
11-
from collections.abc import Set as AbstractSet
1211
import os
1312
import pprint
1413
from typing import Any
1514
from typing import Literal
16-
from typing import Protocol
1715
from unicodedata import normalize
1816

1917
from _pytest import outcomes
2018
import _pytest._code
2119
from _pytest._io.pprint import PrettyPrinter
2220
from _pytest._io.saferepr import saferepr
2321
from _pytest._io.saferepr import saferepr_unlimited
22+
from _pytest.assertion._compare_set import _compare_eq_set
23+
from _pytest.assertion._compare_set import SET_COMPARISON_FUNCTIONS
24+
from _pytest.assertion._typing import _HighlightFunc
2425
from _pytest.config import Config
2526

2627

@@ -38,20 +39,6 @@
3839
_config: Config | None = None
3940

4041

41-
class _HighlightFunc(Protocol):
42-
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
43-
"""Apply highlighting to the given source."""
44-
45-
46-
CompareSetFunction = dict[
47-
str,
48-
Callable[
49-
[AbstractSet[Any], AbstractSet[Any], _HighlightFunc, int],
50-
list[str],
51-
],
52-
]
53-
54-
5542
def dummy_highlighter(source: str, lexer: Literal["diff", "python"] = "python") -> str:
5643
"""Dummy highlighter that returns the text unprocessed.
5744
@@ -225,14 +212,9 @@ def assertrepr_compare(
225212
"!=" | ">=" | "<=" | ">" | "<",
226213
set() | frozenset(),
227214
):
228-
set_compare_func: CompareSetFunction = {
229-
"!=": lambda *a, **kw: ["Both sets are equal"],
230-
">=": _compare_gte_set,
231-
"<=": _compare_lte_set,
232-
">": _compare_gt_set,
233-
"<": _compare_lt_set,
234-
}
235-
explanation = set_compare_func[op](left, right, highlighter, verbose)
215+
explanation = SET_COMPARISON_FUNCTIONS[op](
216+
left, right, highlighter, verbose
217+
)
236218
case _:
237219
explanation = None
238220
except outcomes.Exit:
@@ -433,75 +415,6 @@ def _compare_eq_sequence(
433415
return explanation
434416

435417

436-
def _compare_eq_set(
437-
left: AbstractSet[Any],
438-
right: AbstractSet[Any],
439-
highlighter: _HighlightFunc,
440-
verbose: int = 0,
441-
) -> list[str]:
442-
explanation = []
443-
explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
444-
explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
445-
return explanation
446-
447-
448-
def _compare_gt_set(
449-
left: AbstractSet[Any],
450-
right: AbstractSet[Any],
451-
highlighter: _HighlightFunc,
452-
verbose: int = 0,
453-
) -> list[str]:
454-
explanation = _compare_gte_set(left, right, highlighter)
455-
if not explanation:
456-
return ["Both sets are equal"]
457-
return explanation
458-
459-
460-
def _compare_lt_set(
461-
left: AbstractSet[Any],
462-
right: AbstractSet[Any],
463-
highlighter: _HighlightFunc,
464-
verbose: int = 0,
465-
) -> list[str]:
466-
explanation = _compare_lte_set(left, right, highlighter)
467-
if not explanation:
468-
return ["Both sets are equal"]
469-
return explanation
470-
471-
472-
def _compare_gte_set(
473-
left: AbstractSet[Any],
474-
right: AbstractSet[Any],
475-
highlighter: _HighlightFunc,
476-
verbose: int = 0,
477-
) -> list[str]:
478-
return _set_one_sided_diff("right", right, left, highlighter)
479-
480-
481-
def _compare_lte_set(
482-
left: AbstractSet[Any],
483-
right: AbstractSet[Any],
484-
highlighter: _HighlightFunc,
485-
verbose: int = 0,
486-
) -> list[str]:
487-
return _set_one_sided_diff("left", left, right, highlighter)
488-
489-
490-
def _set_one_sided_diff(
491-
posn: str,
492-
set1: AbstractSet[Any],
493-
set2: AbstractSet[Any],
494-
highlighter: _HighlightFunc,
495-
) -> list[str]:
496-
explanation = []
497-
diff = set1 - set2
498-
if diff:
499-
explanation.append(f"Extra items in the {posn} set:")
500-
for item in diff:
501-
explanation.append(highlighter(saferepr(item)))
502-
return explanation
503-
504-
505418
def _compare_eq_dict(
506419
left: Mapping[Any, Any],
507420
right: Mapping[Any, Any],

0 commit comments

Comments
 (0)