Skip to content

Commit 69c36ce

Browse files
Fix seeding issues with correlated Thompson sampling (#29)
* Add regression test for #28 * Fix seeding problem in correlated Thompson sampling Closes #28
1 parent 78acd67 commit 69c36ce

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

pyrff/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
from . thompson import sample_batch, sampling_probabilities
44
from . utils import multi_start_fmin
55

6-
__version__ = '2.0.1'
6+
__version__ = '2.0.2'

pyrff/test_thompson.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,30 @@ def test_correlated_sampling(self):
5353
assert batch.count('C') == 0
5454
pass
5555

56+
@pytest.mark.parametrize("correlated", [False, True])
57+
def test_seeding(self, correlated):
58+
"""This is a regression test for https://github.com/michaelosthege/pyrff/issues/28"""
59+
rng = numpy.random.RandomState(123)
60+
samples = [
61+
rng.uniform(size=200),
62+
rng.uniform(size=200),
63+
rng.uniform(size=200),
64+
]
65+
batches = []
66+
for _ in range(10):
67+
batch = thompson.sample_batch(
68+
candidate_samples=samples,
69+
ids=["A", "B", "C"],
70+
correlated=correlated,
71+
batch_size=10,
72+
seed=123,
73+
)
74+
batches.append("".join(batch))
75+
76+
# Assert that all batches are identical.
77+
assert len(set(batches)) == 1
78+
pass
79+
5680

5781
class TestExceptions:
5882
def test_id_count(self):

pyrff/thompson.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def sample_batch(
6161
# to prevent always selecting lower-numbered candidates when >=2 samples are equal
6262
col_order = random.permutation(n_candidates)
6363
if correlated:
64-
idx = numpy.repeat(numpy.random.randint(n_samples[0]), n_candidates)
64+
idx = numpy.repeat(random.randint(n_samples[0]), n_candidates)
6565
else:
6666
idx = random.randint(n_samples, size=n_candidates)
6767
selected_samples = samples[:, col_order][idx, numpy.arange(n_candidates)]

0 commit comments

Comments
 (0)