Skip to content

Commit ac4f06b

Browse files
Merge pull request #24 from michaelosthege/fix-3
Correlated & uncorrelated TS with exact probability calculation
2 parents eae0dbf + e322ceb commit ac4f06b

File tree

5 files changed

+348
-52
lines changed

5 files changed

+348
-52
lines changed

.github/workflows/pipeline.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
export NUMBA_DISABLE_JIT=1
3434
pytest --cov=./pyrff --cov-report xml --cov-report term-missing pyrff/
3535
- name: Upload coverage
36-
uses: codecov/codecov-action@v1.0.7
36+
uses: codecov/codecov-action@v1
3737
if: matrix.python-version == 3.8
3838
with:
3939
file: ./coverage.xml

pyrff/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from . exceptions import DtypeError, ShapeError
22
from . rff import sample_rff, save_rffs, load_rffs
3-
from . thompson import sample_batch, get_probabilities
3+
from . thompson import sample_batch, sampling_probabilities
44
from . utils import multi_start_fmin
55

6-
__version__ = '1.0.1'
6+
__version__ = '2.0.0'

pyrff/test_thompson.py

Lines changed: 120 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy
22
import pytest
33

4+
from . import exceptions
45
from . import thompson
56

67

@@ -15,11 +16,12 @@ def test_sample_batch(self, batch_size, seed):
1516
low=[0, 0, -1],
1617
high=[0.2, 1, 0],
1718
size=(S, C)
18-
)
19-
numpy.testing.assert_array_equal(samples.shape, (S, C))
19+
).T
20+
numpy.testing.assert_array_equal(samples.shape, (C, S))
2021

2122
batch = thompson.sample_batch(
22-
samples=samples, ids=ids,
23+
candidate_samples=samples, ids=ids,
24+
correlated=False,
2325
batch_size=batch_size, seed=seed
2426
)
2527
assert len(batch) == batch_size
@@ -29,26 +31,126 @@ def test_sample_batch(self, batch_size, seed):
2931
pass
3032

3133
def test_no_bias_on_sample_collisions(self):
32-
samples = numpy.array([
34+
samples = [
3335
[2, 2, 2],
36+
[2, 2],
3437
[2, 2, 2],
35-
])
36-
batch = thompson.sample_batch(samples, ids=('A', 'B', 'C'), batch_size=100, seed=1234)
38+
]
39+
batch = thompson.sample_batch(samples, ids=('A', 'B', 'C'), correlated=False, batch_size=100, seed=1234)
3740
assert batch.count('A') != 100
3841
assert batch.count('C') != 0
3942
pass
4043

41-
@pytest.mark.xfail(reason='Probabilities are currently computed by brute force and non-exact.')
42-
def test_get_probabilities_exact_on_identical(self):
43-
samples = numpy.array([
44-
[1, 2, 3, 4, 5],
45-
[5, 3, 4, 2, 1],
46-
[1, 3, 4, 2, 5]
47-
]).T
48-
S, C = samples.shape
49-
assert S == 5
50-
assert C == 3
44+
def test_correlated_sampling(self):
45+
samples = [
46+
[1, 2, 3],
47+
[1, 1, 1],
48+
[0, 1, 2],
49+
]
50+
batch = thompson.sample_batch(samples, ids=('A', 'B', 'C'), correlated=True, batch_size=100, seed=1234)
51+
assert batch.count('A') < 100
52+
assert batch.count('B') < 100 / 3
53+
assert batch.count('C') == 0
54+
pass
55+
56+
57+
class TestExceptions:
58+
def test_id_count(self):
59+
with pytest.raises(exceptions.ShapeError, match="candidate ids"):
60+
thompson.sample_batch([
61+
[1,2,3],
62+
[1,2],
63+
],
64+
ids=("A", "B", "C"),
65+
correlated=False,
66+
batch_size=30,
67+
)
68+
69+
def test_correlated_sample_size_check(self):
70+
with pytest.raises(exceptions.ShapeError, match="number of samples"):
71+
thompson.sample_batch([
72+
[1,2,3],
73+
[1,2],
74+
],
75+
ids=("A", "B"),
76+
correlated=True,
77+
batch_size=30,
78+
)
79+
80+
with pytest.raises(exceptions.ShapeError):
81+
thompson.sampling_probabilities([
82+
[1,2,3],
83+
[1,2],
84+
],
85+
correlated=True,
86+
)
87+
pass
88+
89+
90+
class TestThompsonProbabilities:
91+
def test_sort_samples(self):
92+
samples, sample_cols = thompson._sort_samples([
93+
[3,1,2],
94+
[4,-1],
95+
[7],
96+
])
97+
numpy.testing.assert_array_equal(samples, [-1, 1, 2, 3, 4, 7])
98+
numpy.testing.assert_array_equal(sample_cols, [1, 0, 0, 0, 1, 2])
99+
pass
100+
101+
def test_win_draw_prob(self):
102+
assert thompson._win_draw_prob(numpy.array([
103+
[1, 0, 0],
104+
[0, 1, 1],
105+
[0, 0, 0],
106+
])) == 0.0
107+
108+
assert thompson._win_draw_prob(numpy.array([
109+
[0, 0, 0],
110+
[0, 0, 0],
111+
[1, 1, 1],
112+
])) == 0.25
113+
114+
numpy.testing.assert_allclose(thompson._win_draw_prob(numpy.array([
115+
[0, 0],
116+
[0.5, 0.75],
117+
[0.5, 0.25],
118+
])), 0.041666666)
119+
pass
120+
121+
def test_sampling_probability_uncorrelated(self):
122+
numpy.testing.assert_array_equal(thompson.sampling_probabilities([
123+
[0, 1, 2],
124+
[0, 1, 2],
125+
], correlated=False), [0.5, 0.5])
126+
127+
numpy.testing.assert_array_equal(thompson.sampling_probabilities([
128+
[0, 1, 2],
129+
[10],
130+
], correlated=False), [0, 1])
131+
132+
numpy.testing.assert_array_equal(thompson.sampling_probabilities([
133+
[0, 1, 2],
134+
[3, 4, 5],
135+
[5, 4, 3],
136+
], correlated=False), [0, 0.5, 0.5])
137+
138+
numpy.testing.assert_array_equal(thompson.sampling_probabilities([
139+
[5, 6],
140+
[0, 0, 10, 20],
141+
[5, 6],
142+
], correlated=False), [0.25, 0.5, 0.25])
143+
pass
144+
145+
def test_sampling_probability_correlated(self):
146+
numpy.testing.assert_array_equal(thompson.sampling_probabilities([
147+
[0, 1, 2],
148+
[0, 1, 2],
149+
], correlated=True), [0.5, 0.5])
51150

52-
probabilities = thompson.get_probabilities(samples)
53-
numpy.testing.assert_array_equal(probabilities, [1/C]*C)
151+
numpy.testing.assert_array_equal(thompson.sampling_probabilities([
152+
[0, 4, 2],
153+
[3, 4, 5],
154+
[5, 1, 6],
155+
], correlated=True), [0.5/3, 0.5/3, 2/3])
54156
pass

0 commit comments

Comments
 (0)