Skip to content

Commit 13aa1e0

Browse files
committed
Added typehints and some more tests.
1 parent 36ca942 commit 13aa1e0

File tree

2 files changed

+103
-21
lines changed

2 files changed

+103
-21
lines changed

beetsplug/random.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414

1515
"""Get a random song or album from the library."""
1616

17+
from __future__ import annotations
18+
1719
import random
18-
from itertools import groupby
20+
from itertools import groupby, islice
1921
from operator import attrgetter
22+
from typing import Iterable, Sequence
2023

24+
from beets.library import Album, Item
2125
from beets.plugins import BeetsPlugin
2226
from beets.ui import Subcommand, print_
2327

@@ -69,15 +73,19 @@ def commands(self):
6973
return [random_cmd]
7074

7175

72-
def _length(obj, album):
76+
def _length(obj: Item | Album) -> float:
7377
"""Get the duration of an item or album."""
74-
if album:
78+
if isinstance(obj, Album):
7579
return sum(i.length for i in obj.items())
7680
else:
7781
return obj.length
7882

7983

80-
def _equal_chance_permutation(objs, field="albumartist", random_gen=None):
84+
def _equal_chance_permutation(
85+
objs: Sequence[Item | Album],
86+
field: str = "albumartist",
87+
random_gen: random.Random | None = None,
88+
) -> Iterable[Item | Album]:
8189
"""Generate (lazily) a permutation of the objects where every group
8290
with equal values for `field` have an equal chance of appearing in
8391
any given position.
@@ -86,7 +94,7 @@ def _equal_chance_permutation(objs, field="albumartist", random_gen=None):
8694

8795
# Group the objects by artist so we can sample from them.
8896
key = attrgetter(field)
89-
objs.sort(key=key)
97+
objs = sorted(objs, key=key)
9098
objs_by_artists = {}
9199
for artist, v in groupby(objs, key):
92100
objs_by_artists[artist] = list(v)
@@ -106,36 +114,40 @@ def _equal_chance_permutation(objs, field="albumartist", random_gen=None):
106114
del objs_by_artists[artist]
107115

108116

109-
def _take(iter, num):
117+
def _take(
118+
iter: Iterable,
119+
num: int,
120+
) -> list:
110121
"""Return a list containing the first `num` values in `iter` (or
111122
fewer, if the iterable ends early).
112123
"""
113-
out = []
114-
for val in iter:
115-
out.append(val)
116-
num -= 1
117-
if num <= 0:
118-
break
119-
return out
124+
return list(islice(iter, num))
120125

121126

122-
def _take_time(iter, secs, album):
127+
def _take_time(
128+
iter: Iterable[Item | Album],
129+
secs: float,
130+
) -> list[Item | Album]:
123131
"""Return a list containing the first values in `iter`, which should
124132
be Item or Album objects, that add up to the given amount of time in
125133
seconds.
126134
"""
127-
out = []
135+
out: list[Item | Album] = []
128136
total_time = 0.0
129137
for obj in iter:
130-
length = _length(obj, album)
138+
length = _length(obj)
131139
if total_time + length <= secs:
132140
out.append(obj)
133141
total_time += length
134142
return out
135143

136144

137145
def random_objs(
138-
objs, album, number=1, time=None, equal_chance=False, random_gen=None
146+
objs: Sequence[Item | Album],
147+
number=1,
148+
time: float | None = None,
149+
equal_chance: bool = False,
150+
random_gen: random.Random | None = None,
139151
):
140152
"""Get a random subset of the provided `objs`.
141153
@@ -152,11 +164,11 @@ def random_objs(
152164
if equal_chance:
153165
perm = _equal_chance_permutation(objs)
154166
else:
155-
perm = objs
156-
rand.shuffle(perm) # N.B. This shuffles the original list.
167+
perm = list(objs)
168+
rand.shuffle(perm)
157169

158170
# Select objects by time our count.
159171
if time:
160-
return _take_time(perm, time * 60, album)
172+
return _take_time(perm, time * 60)
161173
else:
162174
return _take(perm, number)

test/plugins/test_random.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
import pytest
2222

23-
from beets import random
2423
from beets.test.helper import TestHelper
24+
from beetsplug import random
2525

2626

2727
class RandomTest(TestHelper, unittest.TestCase):
@@ -77,3 +77,73 @@ def experiment(field, histogram=False):
7777
assert 0 == pytest.approx(median1, abs=1)
7878
assert len(self.items) // 2 == pytest.approx(median2, abs=1)
7979
assert stdev2 > stdev1
80+
81+
def test_equal_permutation_empty_input(self):
82+
"""Test _equal_chance_permutation with empty input."""
83+
result = list(random._equal_chance_permutation([], "artist"))
84+
assert result == []
85+
86+
def test_equal_permutation_single_item(self):
87+
"""Test _equal_chance_permutation with single item."""
88+
result = list(random._equal_chance_permutation([self.item1], "artist"))
89+
assert result == [self.item1]
90+
91+
def test_equal_permutation_single_artist(self):
92+
"""Test _equal_chance_permutation with items from one artist."""
93+
items = [self.create_item(artist=self.artist1) for _ in range(5)]
94+
result = list(random._equal_chance_permutation(items, "artist"))
95+
assert set(result) == set(items)
96+
assert len(result) == len(items)
97+
98+
def test_random_objs_count(self):
99+
"""Test random_objs with count-based selection."""
100+
result = random.random_objs(
101+
self.items, number=3, random_gen=self.random_gen
102+
)
103+
assert len(result) == 3
104+
assert all(item in self.items for item in result)
105+
106+
def test_random_objs_time(self):
107+
"""Test random_objs with time-based selection."""
108+
# Total length is 30 + 60 + 8*45 = 450 seconds
109+
# Requesting 120 seconds should return 2-3 items
110+
result = random.random_objs(
111+
self.items,
112+
time=2,
113+
random_gen=self.random_gen, # 2 minutes = 120 sec
114+
)
115+
total_time = sum(item.length for item in result)
116+
assert total_time <= 120
117+
# Check we got at least some items
118+
assert len(result) > 0
119+
120+
def test_random_objs_equal_chance(self):
121+
"""Test random_objs with equal_chance=True."""
122+
123+
# With equal_chance, artist1 should appear more often in results
124+
def experiment():
125+
"""Run the random_objs function multiple times and collect results."""
126+
results = []
127+
for _ in range(5000):
128+
result = random.random_objs(
129+
[self.item1, self.item2],
130+
number=1,
131+
equal_chance=True,
132+
random_gen=self.random_gen,
133+
)
134+
results.append(result[0].artist)
135+
136+
return results.count(self.artist1), results.count(self.artist2)
137+
138+
count_artist1, count_artist2 = experiment()
139+
assert abs(count_artist1 - count_artist2) < 100 # 2% deviation
140+
141+
def test_random_objs_empty_input(self):
142+
"""Test random_objs with empty input."""
143+
result = random.random_objs([], number=3)
144+
assert result == []
145+
146+
def test_random_objs_zero_number(self):
147+
"""Test random_objs with number=0."""
148+
result = random.random_objs(self.items, number=0)
149+
assert result == []

0 commit comments

Comments
 (0)