Skip to content

Commit 12eb748

Browse files
[refactor] Create a file for set comparison in assertion
1 parent bd7d709 commit 12eb748

File tree

3 files changed

+92
-76
lines changed

3 files changed

+92
-76
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable
4+
from collections.abc import Set as AbstractSet
5+
from typing import Any
6+
7+
from _pytest._io.saferepr import saferepr
8+
from _pytest.assertion._typing import _HighlightFunc
9+
10+
11+
def _set_one_sided_diff(
12+
posn: str,
13+
set1: AbstractSet[Any],
14+
set2: AbstractSet[Any],
15+
highlighter: _HighlightFunc,
16+
) -> list[str]:
17+
explanation = []
18+
diff = set1 - set2
19+
if diff:
20+
explanation.append(f"Extra items in the {posn} set:")
21+
for item in diff:
22+
explanation.append(highlighter(saferepr(item)))
23+
return explanation
24+
25+
26+
def _compare_eq_set(
27+
left: AbstractSet[Any],
28+
right: AbstractSet[Any],
29+
highlighter: _HighlightFunc,
30+
verbose: int = 0,
31+
) -> list[str]:
32+
explanation = []
33+
explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
34+
explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
35+
return explanation
36+
37+
38+
def _compare_gt_set(
39+
left: AbstractSet[Any],
40+
right: AbstractSet[Any],
41+
highlighter: _HighlightFunc,
42+
verbose: int = 0,
43+
) -> list[str]:
44+
explanation = _compare_gte_set(left, right, highlighter)
45+
if not explanation:
46+
return ["Both sets are equal"]
47+
return explanation
48+
49+
50+
def _compare_lt_set(
51+
left: AbstractSet[Any],
52+
right: AbstractSet[Any],
53+
highlighter: _HighlightFunc,
54+
verbose: int = 0,
55+
) -> list[str]:
56+
explanation = _compare_lte_set(left, right, highlighter)
57+
if not explanation:
58+
return ["Both sets are equal"]
59+
return explanation
60+
61+
62+
def _compare_gte_set(
63+
left: AbstractSet[Any],
64+
right: AbstractSet[Any],
65+
highlighter: _HighlightFunc,
66+
verbose: int = 0,
67+
) -> list[str]:
68+
return _set_one_sided_diff("right", right, left, highlighter)
69+
70+
71+
def _compare_lte_set(
72+
left: AbstractSet[Any],
73+
right: AbstractSet[Any],
74+
highlighter: _HighlightFunc,
75+
verbose: int = 0,
76+
) -> list[str]:
77+
return _set_one_sided_diff("left", left, right, highlighter)

src/_pytest/assertion/_typing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from __future__ import annotations
2+
3+
from typing import Literal
4+
from typing import Protocol
5+
6+
7+
class _HighlightFunc(Protocol): # noqa: PYI046
8+
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
9+
"""Apply highlighting to the given source."""

src/_pytest/assertion/util.py

Lines changed: 6 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,23 @@
88
from collections.abc import Iterable
99
from collections.abc import Mapping
1010
from collections.abc import Sequence
11-
from collections.abc import Set as AbstractSet
1211
import os
1312
import pprint
1413
from typing import Any
1514
from typing import Literal
16-
from typing import Protocol
1715
from unicodedata import normalize
1816

1917
from _pytest import outcomes
2018
import _pytest._code
2119
from _pytest._io.pprint import PrettyPrinter
2220
from _pytest._io.saferepr import saferepr
2321
from _pytest._io.saferepr import saferepr_unlimited
22+
from _pytest.assertion._compare_set import _compare_eq_set
23+
from _pytest.assertion._compare_set import _compare_gt_set
24+
from _pytest.assertion._compare_set import _compare_gte_set
25+
from _pytest.assertion._compare_set import _compare_lt_set
26+
from _pytest.assertion._compare_set import _compare_lte_set
27+
from _pytest.assertion._typing import _HighlightFunc
2428
from _pytest.config import Config
2529

2630

@@ -38,11 +42,6 @@
3842
_config: Config | None = None
3943

4044

41-
class _HighlightFunc(Protocol):
42-
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
43-
"""Apply highlighting to the given source."""
44-
45-
4645
def dummy_highlighter(source: str, lexer: Literal["diff", "python"] = "python") -> str:
4746
"""Dummy highlighter that returns the text unprocessed.
4847
@@ -426,75 +425,6 @@ def _compare_eq_sequence(
426425
return explanation
427426

428427

429-
def _compare_eq_set(
430-
left: AbstractSet[Any],
431-
right: AbstractSet[Any],
432-
highlighter: _HighlightFunc,
433-
verbose: int = 0,
434-
) -> list[str]:
435-
explanation = []
436-
explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
437-
explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
438-
return explanation
439-
440-
441-
def _compare_gt_set(
442-
left: AbstractSet[Any],
443-
right: AbstractSet[Any],
444-
highlighter: _HighlightFunc,
445-
verbose: int = 0,
446-
) -> list[str]:
447-
explanation = _compare_gte_set(left, right, highlighter)
448-
if not explanation:
449-
return ["Both sets are equal"]
450-
return explanation
451-
452-
453-
def _compare_lt_set(
454-
left: AbstractSet[Any],
455-
right: AbstractSet[Any],
456-
highlighter: _HighlightFunc,
457-
verbose: int = 0,
458-
) -> list[str]:
459-
explanation = _compare_lte_set(left, right, highlighter)
460-
if not explanation:
461-
return ["Both sets are equal"]
462-
return explanation
463-
464-
465-
def _compare_gte_set(
466-
left: AbstractSet[Any],
467-
right: AbstractSet[Any],
468-
highlighter: _HighlightFunc,
469-
verbose: int = 0,
470-
) -> list[str]:
471-
return _set_one_sided_diff("right", right, left, highlighter)
472-
473-
474-
def _compare_lte_set(
475-
left: AbstractSet[Any],
476-
right: AbstractSet[Any],
477-
highlighter: _HighlightFunc,
478-
verbose: int = 0,
479-
) -> list[str]:
480-
return _set_one_sided_diff("left", left, right, highlighter)
481-
482-
483-
def _set_one_sided_diff(
484-
posn: str,
485-
set1: AbstractSet[Any],
486-
set2: AbstractSet[Any],
487-
highlighter: _HighlightFunc,
488-
) -> list[str]:
489-
explanation = []
490-
diff = set1 - set2
491-
if diff:
492-
explanation.append(f"Extra items in the {posn} set:")
493-
for item in diff:
494-
explanation.append(highlighter(saferepr(item)))
495-
return explanation
496-
497-
498428
def _compare_eq_dict(
499429
left: Mapping[Any, Any],
500430
right: Mapping[Any, Any],

0 commit comments

Comments
 (0)