Skip to content

Commit 3941fd3

Browse files
committed
Added typehints and some more tests.
1 parent 36ca942 commit 3941fd3

File tree

3 files changed

+108
-21
lines changed

3 files changed

+108
-21
lines changed

beetsplug/random.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414

1515
"""Get a random song or album from the library."""
1616

17+
from __future__ import annotations
18+
1719
import random
18-
from itertools import groupby
20+
from itertools import groupby, islice
1921
from operator import attrgetter
22+
from typing import Iterable, Sequence, TypeVar
2023

24+
from beets.library import Album, Item
2125
from beets.plugins import BeetsPlugin
2226
from beets.ui import Subcommand, print_
2327

@@ -69,15 +73,19 @@ def commands(self):
6973
return [random_cmd]
7074

7175

72-
def _length(obj, album):
76+
def _length(obj: Item | Album) -> float:
7377
"""Get the duration of an item or album."""
74-
if album:
78+
if isinstance(obj, Album):
7579
return sum(i.length for i in obj.items())
7680
else:
7781
return obj.length
7882

7983

80-
def _equal_chance_permutation(objs, field="albumartist", random_gen=None):
84+
def _equal_chance_permutation(
85+
objs: Sequence[Item | Album],
86+
field: str = "albumartist",
87+
random_gen: random.Random | None = None,
88+
) -> Iterable[Item | Album]:
8189
"""Generate (lazily) a permutation of the objects where every group
8290
with equal values for `field` have an equal chance of appearing in
8391
any given position.
@@ -86,7 +94,7 @@ def _equal_chance_permutation(objs, field="albumartist", random_gen=None):
8694

8795
# Group the objects by artist so we can sample from them.
8896
key = attrgetter(field)
89-
objs.sort(key=key)
97+
objs = sorted(objs, key=key)
9098
objs_by_artists = {}
9199
for artist, v in groupby(objs, key):
92100
objs_by_artists[artist] = list(v)
@@ -106,36 +114,43 @@ def _equal_chance_permutation(objs, field="albumartist", random_gen=None):
106114
del objs_by_artists[artist]
107115

108116

109-
def _take(iter, num):
117+
T = TypeVar("T")
118+
119+
120+
def _take(
121+
iter: Iterable[T],
122+
num: int,
123+
) -> list[T]:
110124
"""Return a list containing the first `num` values in `iter` (or
111125
fewer, if the iterable ends early).
112126
"""
113-
out = []
114-
for val in iter:
115-
out.append(val)
116-
num -= 1
117-
if num <= 0:
118-
break
119-
return out
127+
return list(islice(iter, num))
120128

121129

122-
def _take_time(iter, secs, album):
130+
def _take_time(
131+
iter: Iterable[Item | Album],
132+
secs: float,
133+
) -> list[Item | Album]:
123134
"""Return a list containing the first values in `iter`, which should
124135
be Item or Album objects, that add up to the given amount of time in
125136
seconds.
126137
"""
127-
out = []
138+
out: list[Item | Album] = []
128139
total_time = 0.0
129140
for obj in iter:
130-
length = _length(obj, album)
141+
length = _length(obj)
131142
if total_time + length <= secs:
132143
out.append(obj)
133144
total_time += length
134145
return out
135146

136147

137148
def random_objs(
138-
objs, album, number=1, time=None, equal_chance=False, random_gen=None
149+
objs: Sequence[Item | Album],
150+
number=1,
151+
time: float | None = None,
152+
equal_chance: bool = False,
153+
random_gen: random.Random | None = None,
139154
):
140155
"""Get a random subset of the provided `objs`.
141156
@@ -152,11 +167,11 @@ def random_objs(
152167
if equal_chance:
153168
perm = _equal_chance_permutation(objs)
154169
else:
155-
perm = objs
156-
rand.shuffle(perm) # N.B. This shuffles the original list.
170+
perm = list(objs)
171+
rand.shuffle(perm)
157172

158173
# Select objects by time our count.
159174
if time:
160-
return _take_time(perm, time * 60, album)
175+
return _take_time(perm, time * 60)
161176
else:
162177
return _take(perm, number)

docs/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ Other changes:
107107
* Refactored library.py file by splitting it into multiple modules within the
108108
beets/library directory.
109109
* Added a test to check that all plugins can be imported without errors.
110+
* Moved `beets/random.py` into `beetsplug/random.py` to cleanup core module.
110111

111112
2.3.1 (May 14, 2025)
112113
--------------------

test/plugins/test_random.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
import pytest
2222

23-
from beets import random
2423
from beets.test.helper import TestHelper
24+
from beetsplug import random
2525

2626

2727
class RandomTest(TestHelper, unittest.TestCase):
@@ -77,3 +77,74 @@ def experiment(field, histogram=False):
7777
assert 0 == pytest.approx(median1, abs=1)
7878
assert len(self.items) // 2 == pytest.approx(median2, abs=1)
7979
assert stdev2 > stdev1
80+
81+
def test_equal_permutation_empty_input(self):
82+
"""Test _equal_chance_permutation with empty input."""
83+
result = list(random._equal_chance_permutation([], "artist"))
84+
assert result == []
85+
86+
def test_equal_permutation_single_item(self):
87+
"""Test _equal_chance_permutation with single item."""
88+
result = list(random._equal_chance_permutation([self.item1], "artist"))
89+
assert result == [self.item1]
90+
91+
def test_equal_permutation_single_artist(self):
92+
"""Test _equal_chance_permutation with items from one artist."""
93+
items = [self.create_item(artist=self.artist1) for _ in range(5)]
94+
result = list(random._equal_chance_permutation(items, "artist"))
95+
assert set(result) == set(items)
96+
assert len(result) == len(items)
97+
98+
def test_random_objs_count(self):
99+
"""Test random_objs with count-based selection."""
100+
result = random.random_objs(
101+
self.items, number=3, random_gen=self.random_gen
102+
)
103+
assert len(result) == 3
104+
assert all(item in self.items for item in result)
105+
106+
def test_random_objs_time(self):
107+
"""Test random_objs with time-based selection."""
108+
# Total length is 30 + 60 + 8*45 = 450 seconds
109+
# Requesting 120 seconds should return 2-3 items
110+
result = random.random_objs(
111+
self.items,
112+
time=2,
113+
random_gen=self.random_gen, # 2 minutes = 120 sec
114+
)
115+
total_time = sum(item.length for item in result)
116+
assert total_time <= 120
117+
# Check we got at least some items
118+
assert len(result) > 0
119+
120+
def test_random_objs_equal_chance(self):
121+
"""Test random_objs with equal_chance=True."""
122+
123+
# With equal_chance, artist1 should appear more often in results
124+
def experiment():
125+
"""Run the random_objs function multiple times and collect results."""
126+
results = []
127+
for _ in range(5000):
128+
result = random.random_objs(
129+
[self.item1, self.item2],
130+
number=1,
131+
equal_chance=True,
132+
random_gen=self.random_gen,
133+
)
134+
results.append(result[0].artist)
135+
136+
# Return ratio
137+
return results.count(self.artist1), results.count(self.artist2)
138+
139+
count_artist1, count_artist2 = experiment()
140+
assert 1 - count_artist1 / count_artist2 < 0.1 # 10% deviation
141+
142+
def test_random_objs_empty_input(self):
143+
"""Test random_objs with empty input."""
144+
result = random.random_objs([], number=3)
145+
assert result == []
146+
147+
def test_random_objs_zero_number(self):
148+
"""Test random_objs with number=0."""
149+
result = random.random_objs(self.items, number=0)
150+
assert result == []

0 commit comments

Comments
 (0)