11import numpy
22import pytest
33
4+ from . import exceptions
45from . import thompson
56
67
@@ -15,11 +16,12 @@ def test_sample_batch(self, batch_size, seed):
1516 low = [0 , 0 , - 1 ],
1617 high = [0.2 , 1 , 0 ],
1718 size = (S , C )
18- )
19- numpy .testing .assert_array_equal (samples .shape , (S , C ))
19+ ). T
20+ numpy .testing .assert_array_equal (samples .shape , (C , S ))
2021
2122 batch = thompson .sample_batch (
22- samples = samples , ids = ids ,
23+ candidate_samples = samples , ids = ids ,
24+ correlated = False ,
2325 batch_size = batch_size , seed = seed
2426 )
2527 assert len (batch ) == batch_size
@@ -29,26 +31,126 @@ def test_sample_batch(self, batch_size, seed):
2931 pass
3032
3133 def test_no_bias_on_sample_collisions (self ):
32- samples = numpy . array ( [
34+ samples = [
3335 [2 , 2 , 2 ],
36+ [2 , 2 ],
3437 [2 , 2 , 2 ],
35- ])
36- batch = thompson .sample_batch (samples , ids = ('A' , 'B' , 'C' ), batch_size = 100 , seed = 1234 )
38+ ]
39+ batch = thompson .sample_batch (samples , ids = ('A' , 'B' , 'C' ), correlated = False , batch_size = 100 , seed = 1234 )
3740 assert batch .count ('A' ) != 100
3841 assert batch .count ('C' ) != 0
3942 pass
4043
41- @pytest .mark .xfail (reason = 'Probabilities are currently computed by brute force and non-exact.' )
42- def test_get_probabilities_exact_on_identical (self ):
43- samples = numpy .array ([
44- [1 , 2 , 3 , 4 , 5 ],
45- [5 , 3 , 4 , 2 , 1 ],
46- [1 , 3 , 4 , 2 , 5 ]
47- ]).T
48- S , C = samples .shape
49- assert S == 5
50- assert C == 3
44+ def test_correlated_sampling (self ):
45+ samples = [
46+ [1 , 2 , 3 ],
47+ [1 , 1 , 1 ],
48+ [0 , 1 , 2 ],
49+ ]
50+ batch = thompson .sample_batch (samples , ids = ('A' , 'B' , 'C' ), correlated = True , batch_size = 100 , seed = 1234 )
51+ assert batch .count ('A' ) < 100
52+ assert batch .count ('B' ) < 100 / 3
53+ assert batch .count ('C' ) == 0
54+ pass
55+
56+
57+ class TestExceptions :
58+ def test_id_count (self ):
59+ with pytest .raises (exceptions .ShapeError , match = "candidate ids" ):
60+ thompson .sample_batch ([
61+ [1 ,2 ,3 ],
62+ [1 ,2 ],
63+ ],
64+ ids = ("A" , "B" , "C" ),
65+ correlated = False ,
66+ batch_size = 30 ,
67+ )
68+
69+ def test_correlated_sample_size_check (self ):
70+ with pytest .raises (exceptions .ShapeError , match = "number of samples" ):
71+ thompson .sample_batch ([
72+ [1 ,2 ,3 ],
73+ [1 ,2 ],
74+ ],
75+ ids = ("A" , "B" ),
76+ correlated = True ,
77+ batch_size = 30 ,
78+ )
79+
80+ with pytest .raises (exceptions .ShapeError ):
81+ thompson .sampling_probabilities ([
82+ [1 ,2 ,3 ],
83+ [1 ,2 ],
84+ ],
85+ correlated = True ,
86+ )
87+ pass
88+
89+
90+ class TestThompsonProbabilities :
91+ def test_sort_samples (self ):
92+ samples , sample_cols = thompson ._sort_samples ([
93+ [3 ,1 ,2 ],
94+ [4 ,- 1 ],
95+ [7 ],
96+ ])
97+ numpy .testing .assert_array_equal (samples , [- 1 , 1 , 2 , 3 , 4 , 7 ])
98+ numpy .testing .assert_array_equal (sample_cols , [1 , 0 , 0 , 0 , 1 , 2 ])
99+ pass
100+
101+ def test_win_draw_prob (self ):
102+ assert thompson ._win_draw_prob (numpy .array ([
103+ [1 , 0 , 0 ],
104+ [0 , 1 , 1 ],
105+ [0 , 0 , 0 ],
106+ ])) == 0.0
107+
108+ assert thompson ._win_draw_prob (numpy .array ([
109+ [0 , 0 , 0 ],
110+ [0 , 0 , 0 ],
111+ [1 , 1 , 1 ],
112+ ])) == 0.25
113+
114+ numpy .testing .assert_allclose (thompson ._win_draw_prob (numpy .array ([
115+ [0 , 0 ],
116+ [0.5 , 0.75 ],
117+ [0.5 , 0.25 ],
118+ ])), 0.041666666 )
119+ pass
120+
121+ def test_sampling_probability_uncorrelated (self ):
122+ numpy .testing .assert_array_equal (thompson .sampling_probabilities ([
123+ [0 , 1 , 2 ],
124+ [0 , 1 , 2 ],
125+ ], correlated = False ), [0.5 , 0.5 ])
126+
127+ numpy .testing .assert_array_equal (thompson .sampling_probabilities ([
128+ [0 , 1 , 2 ],
129+ [10 ],
130+ ], correlated = False ), [0 , 1 ])
131+
132+ numpy .testing .assert_array_equal (thompson .sampling_probabilities ([
133+ [0 , 1 , 2 ],
134+ [3 , 4 , 5 ],
135+ [5 , 4 , 3 ],
136+ ], correlated = False ), [0 , 0.5 , 0.5 ])
137+
138+ numpy .testing .assert_array_equal (thompson .sampling_probabilities ([
139+ [5 , 6 ],
140+ [0 , 0 , 10 , 20 ],
141+ [5 , 6 ],
142+ ], correlated = False ), [0.25 , 0.5 , 0.25 ])
143+ pass
144+
145+ def test_sampling_probability_correlated (self ):
146+ numpy .testing .assert_array_equal (thompson .sampling_probabilities ([
147+ [0 , 1 , 2 ],
148+ [0 , 1 , 2 ],
149+ ], correlated = True ), [0.5 , 0.5 ])
51150
52- probabilities = thompson .get_probabilities (samples )
53- numpy .testing .assert_array_equal (probabilities , [1 / C ]* C )
151+ numpy .testing .assert_array_equal (thompson .sampling_probabilities ([
152+ [0 , 4 , 2 ],
153+ [3 , 4 , 5 ],
154+ [5 , 1 , 6 ],
155+ ], correlated = True ), [0.5 / 3 , 0.5 / 3 , 2 / 3 ])
54156 pass
0 commit comments