Skip to content

Commit c14dc12

Browse files
Add support for Python 3.13
The algorithm selector Enum was failing to initialize provided split algorithms as members due to some changes in Python 3.13. The changes include the conversion of the algorithm implementations to classes with proper hashing capability for better enum support, and tests.
1 parent c7a3272 commit c14dc12

File tree

4 files changed

+160
-79
lines changed

4 files changed

+160
-79
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
55

66
## [Unreleased]
7+
### Added
8+
- Support for Python 3.13.
79

810
## [0.9.0] - 2024-06-19
911
### Changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: 3.10",
2424
"Programming Language :: Python :: 3.11",
2525
"Programming Language :: Python :: 3.12",
26+
"Programming Language :: Python :: 3.13",
2627
"Topic :: Software Development :: Libraries :: Python Modules",
2728
"Typing :: Typed",
2829
]

src/pytest_split/algorithms.py

Lines changed: 104 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import enum
2-
import functools
32
import heapq
3+
from abc import ABC, abstractmethod
44
from operator import itemgetter
55
from typing import TYPE_CHECKING, NamedTuple
66

@@ -16,9 +16,25 @@ class TestGroup(NamedTuple):
1616
duration: float
1717

1818

19-
def least_duration(
20-
splits: int, items: "List[nodes.Item]", durations: "Dict[str, float]"
21-
) -> "List[TestGroup]":
19+
class AlgorithmBase(ABC):
20+
"""Abstract base class for the algorithm implementations."""
21+
22+
@abstractmethod
23+
def __call__(
24+
self, splits: int, items: "List[nodes.Item]", durations: "Dict[str, float]"
25+
) -> "List[TestGroup]":
26+
pass
27+
28+
def __hash__(self) -> int:
29+
return hash(self.__class__.__name__)
30+
31+
def __eq__(self, other: object) -> bool:
32+
if not isinstance(other, AlgorithmBase):
33+
return NotImplemented
34+
return self.__class__.__name__ == other.__class__.__name__
35+
36+
37+
class LeastDurationAlgorithm(AlgorithmBase):
2238
"""
2339
Split tests into groups by runtime.
2440
It walks the test items, starting with the test with largest duration.
@@ -34,60 +50,65 @@ def least_duration(
3450
:return:
3551
List of groups
3652
"""
37-
items_with_durations = _get_items_with_durations(items, durations)
3853

39-
# add index of item in list
40-
items_with_durations_indexed = [
41-
(*tup, i) for i, tup in enumerate(items_with_durations)
42-
]
54+
def __call__(
55+
self, splits: int, items: "List[nodes.Item]", durations: "Dict[str, float]"
56+
) -> "List[TestGroup]":
57+
items_with_durations = _get_items_with_durations(items, durations)
4358

44-
# Sort by name to ensure it's always the same order
45-
items_with_durations_indexed = sorted(
46-
items_with_durations_indexed, key=lambda tup: str(tup[0])
47-
)
48-
49-
# sort in ascending order
50-
sorted_items_with_durations = sorted(
51-
items_with_durations_indexed, key=lambda tup: tup[1], reverse=True
52-
)
53-
54-
selected: List[List[Tuple[nodes.Item, int]]] = [[] for _ in range(splits)]
55-
deselected: List[List[nodes.Item]] = [[] for _ in range(splits)]
56-
duration: List[float] = [0 for _ in range(splits)]
57-
58-
# create a heap of the form (summed_durations, group_index)
59-
heap: List[Tuple[float, int]] = [(0, i) for i in range(splits)]
60-
heapq.heapify(heap)
61-
for item, item_duration, original_index in sorted_items_with_durations:
62-
# get group with smallest sum
63-
summed_durations, group_idx = heapq.heappop(heap)
64-
new_group_durations = summed_durations + item_duration
65-
66-
# store assignment
67-
selected[group_idx].append((item, original_index))
68-
duration[group_idx] = new_group_durations
69-
for i in range(splits):
70-
if i != group_idx:
71-
deselected[i].append(item)
72-
73-
# store new duration - in case of ties it sorts by the group_idx
74-
heapq.heappush(heap, (new_group_durations, group_idx))
75-
76-
groups = []
77-
for i in range(splits):
78-
# sort the items by their original index to maintain relative ordering
79-
# we don't care about the order of deselected items
80-
s = [
81-
item for item, original_index in sorted(selected[i], key=lambda tup: tup[1])
59+
# add index of item in list
60+
items_with_durations_indexed = [
61+
(*tup, i) for i, tup in enumerate(items_with_durations)
8262
]
83-
group = TestGroup(selected=s, deselected=deselected[i], duration=duration[i])
84-
groups.append(group)
85-
return groups
86-
8763

88-
def duration_based_chunks(
89-
splits: int, items: "List[nodes.Item]", durations: "Dict[str, float]"
90-
) -> "List[TestGroup]":
64+
# Sort by name to ensure it's always the same order
65+
items_with_durations_indexed = sorted(
66+
items_with_durations_indexed, key=lambda tup: str(tup[0])
67+
)
68+
69+
# sort in ascending order
70+
sorted_items_with_durations = sorted(
71+
items_with_durations_indexed, key=lambda tup: tup[1], reverse=True
72+
)
73+
74+
selected: List[List[Tuple[nodes.Item, int]]] = [[] for _ in range(splits)]
75+
deselected: List[List[nodes.Item]] = [[] for _ in range(splits)]
76+
duration: List[float] = [0 for _ in range(splits)]
77+
78+
# create a heap of the form (summed_durations, group_index)
79+
heap: List[Tuple[float, int]] = [(0, i) for i in range(splits)]
80+
heapq.heapify(heap)
81+
for item, item_duration, original_index in sorted_items_with_durations:
82+
# get group with smallest sum
83+
summed_durations, group_idx = heapq.heappop(heap)
84+
new_group_durations = summed_durations + item_duration
85+
86+
# store assignment
87+
selected[group_idx].append((item, original_index))
88+
duration[group_idx] = new_group_durations
89+
for i in range(splits):
90+
if i != group_idx:
91+
deselected[i].append(item)
92+
93+
# store new duration - in case of ties it sorts by the group_idx
94+
heapq.heappush(heap, (new_group_durations, group_idx))
95+
96+
groups = []
97+
for i in range(splits):
98+
# sort the items by their original index to maintain relative ordering
99+
# we don't care about the order of deselected items
100+
s = [
101+
item
102+
for item, original_index in sorted(selected[i], key=lambda tup: tup[1])
103+
]
104+
group = TestGroup(
105+
selected=s, deselected=deselected[i], duration=duration[i]
106+
)
107+
groups.append(group)
108+
return groups
109+
110+
111+
class DurationBasedChunksAlgorithm(AlgorithmBase):
91112
"""
92113
Split tests into groups by runtime.
93114
Ensures tests are split into non-overlapping groups.
@@ -99,28 +120,34 @@ def duration_based_chunks(
99120
:param durations: Our cached test runtimes. Assumes contains timings only of relevant tests
100121
:return: List of TestGroup
101122
"""
102-
items_with_durations = _get_items_with_durations(items, durations)
103-
time_per_group = sum(map(itemgetter(1), items_with_durations)) / splits
104-
105-
selected: List[List[nodes.Item]] = [[] for i in range(splits)]
106-
deselected: List[List[nodes.Item]] = [[] for i in range(splits)]
107-
duration: List[float] = [0 for i in range(splits)]
108123

109-
group_idx = 0
110-
for item, item_duration in items_with_durations:
111-
if duration[group_idx] >= time_per_group:
112-
group_idx += 1
113-
114-
selected[group_idx].append(item)
115-
for i in range(splits):
116-
if i != group_idx:
117-
deselected[i].append(item)
118-
duration[group_idx] += item_duration
119-
120-
return [
121-
TestGroup(selected=selected[i], deselected=deselected[i], duration=duration[i])
122-
for i in range(splits)
123-
]
124+
def __call__(
125+
self, splits: int, items: "List[nodes.Item]", durations: "Dict[str, float]"
126+
) -> "List[TestGroup]":
127+
items_with_durations = _get_items_with_durations(items, durations)
128+
time_per_group = sum(map(itemgetter(1), items_with_durations)) / splits
129+
130+
selected: List[List[nodes.Item]] = [[] for i in range(splits)]
131+
deselected: List[List[nodes.Item]] = [[] for i in range(splits)]
132+
duration: List[float] = [0 for i in range(splits)]
133+
134+
group_idx = 0
135+
for item, item_duration in items_with_durations:
136+
if duration[group_idx] >= time_per_group:
137+
group_idx += 1
138+
139+
selected[group_idx].append(item)
140+
for i in range(splits):
141+
if i != group_idx:
142+
deselected[i].append(item)
143+
duration[group_idx] += item_duration
144+
145+
return [
146+
TestGroup(
147+
selected=selected[i], deselected=deselected[i], duration=duration[i]
148+
)
149+
for i in range(splits)
150+
]
124151

125152

126153
def _get_items_with_durations(
@@ -153,9 +180,8 @@ def _remove_irrelevant_durations(
153180

154181

155182
class Algorithms(enum.Enum):
156-
# values have to wrapped inside functools to avoid them being considered method definitions
157-
duration_based_chunks = functools.partial(duration_based_chunks)
158-
least_duration = functools.partial(least_duration)
183+
duration_based_chunks = DurationBasedChunksAlgorithm()
184+
least_duration = LeastDurationAlgorithm()
159185

160186
@staticmethod
161187
def names() -> "List[str]":

tests/test_algorithms.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
from _pytest.nodes import Item
1111

12-
from pytest_split.algorithms import Algorithms
12+
from pytest_split.algorithms import (
13+
AlgorithmBase,
14+
Algorithms,
15+
)
1316

1417
item = namedtuple("item", "nodeid") # noqa: PYI024
1518

@@ -132,3 +135,52 @@ def test__split_tests_same_set_regardless_of_order(self):
132135
if not selected_each[i]:
133136
selected_each[i] = set(group.selected)
134137
assert selected_each[i] == set(group.selected)
138+
139+
def test__algorithms_members_derived_correctly(self):
140+
for a in Algorithms.names():
141+
assert issubclass(Algorithms[a].value.__class__, AlgorithmBase)
142+
143+
144+
class MyAlgorithm(AlgorithmBase):
145+
def __call__(self, a, b, c):
146+
"""no-op"""
147+
148+
149+
class MyOtherAlgorithm(AlgorithmBase):
150+
def __call__(self, a, b, c):
151+
"""no-op"""
152+
153+
154+
class TestAbstractAlgorithm:
155+
def test__hash__returns_correct_result(self):
156+
algo = MyAlgorithm()
157+
assert algo.__hash__() == hash(algo.__class__.__name__)
158+
159+
def test__hash__returns_same_hash_for_same_class_instances(self):
160+
algo1 = MyAlgorithm()
161+
algo2 = MyAlgorithm()
162+
assert algo1.__hash__() == algo2.__hash__()
163+
164+
def test__hash__returns_different_hash_for_different_classes(self):
165+
algo1 = MyAlgorithm()
166+
algo2 = MyOtherAlgorithm()
167+
assert algo1.__hash__() != algo2.__hash__()
168+
169+
def test__eq__returns_true_for_same_instance(self):
170+
algo = MyAlgorithm()
171+
assert algo.__eq__(algo) is True
172+
173+
def test__eq__returns_false_for_different_instance(self):
174+
algo1 = MyAlgorithm()
175+
algo2 = MyOtherAlgorithm()
176+
assert algo1.__eq__(algo2) is False
177+
178+
def test__eq__returns_true_for_same_algorithm_different_instance(self):
179+
algo1 = MyAlgorithm()
180+
algo2 = MyAlgorithm()
181+
assert algo1.__eq__(algo2) is True
182+
183+
def test__eq__returns_false_for_non_algorithm_object(self):
184+
algo = MyAlgorithm()
185+
other = "not an algorithm"
186+
assert algo.__eq__(other) is NotImplemented

0 commit comments

Comments
 (0)