@@ -41,39 +41,42 @@ class MultinomialSamplerTester : public MultinomialSampler {
41
41
TEST (MultinomialSampler, gen) {
42
42
int numGrids = 1024 * 1024 ;
43
43
int size = 1024 * 4 ;
44
-
45
44
default_random_engine reng;
46
- uniform_int_distribution<int > rand (1 , numGrids / size * 1.8 );
47
- vector<real> prob;
48
- int sum = 0 ;
49
- for (int i = 0 ; i < size; ++i) {
50
- prob.push_back (rand (reng));
51
- sum += prob.back ();
52
- }
53
- CHECK_LE (sum, numGrids);
54
- prob.back () += numGrids - sum;
55
45
56
- vector<int > counts (size);
57
- MultinomialSamplerTester sampler (&prob[0 ], size);
58
- counts.assign (size, 0 );
59
- {
60
- double s = (double )size / (double )numGrids;
61
- REGISTER_TIMER (" MultinomialSampler" );
62
- for (double i = 0 ; i < numGrids; ++i) {
63
- int ret = sampler.testGen ([i, s]() { return s * i; });
64
- if (ret < 0 || ret >= size) {
65
- EXPECT_GE (ret, 0 );
66
- EXPECT_LT (ret, size);
67
- break ;
46
+ for (size_t iter=0 ; iter < 256 ; ++iter) {
47
+ uniform_int_distribution<int > rand (1 , numGrids / size * 1.8 );
48
+ vector<real> prob;
49
+ int sum = 0 ;
50
+ for (int i = 0 ; i < size; ++i) {
51
+ prob.push_back (rand (reng));
52
+ sum += prob.back ();
53
+ }
54
+
55
+ CHECK_LE (sum, numGrids);
56
+ prob.back () += numGrids - sum;
57
+
58
+ vector<int > counts (size);
59
+ MultinomialSamplerTester sampler (&prob[0 ], size);
60
+ counts.assign (size, 0 );
61
+ {
62
+ double s = (double )size / (double )numGrids;
63
+ REGISTER_TIMER (" MultinomialSampler" );
64
+ for (double i = 0 ; i < numGrids; ++i) {
65
+ int ret = sampler.testGen ([i, s]() { return s * i; });
66
+ if (ret < 0 || ret >= size) {
67
+ EXPECT_GE (ret, 0 );
68
+ EXPECT_LT (ret, size);
69
+ break ;
70
+ }
71
+ ++counts[ret];
68
72
}
69
- ++counts[ret];
70
73
}
71
- }
72
- for ( int i = 0 ; i < size; ++i ) {
73
- if (prob[i] != counts[i]) {
74
- EXPECT_EQ (prob[i], counts[i]) ;
75
- LOG (INFO) << " i= " << i ;
76
- break ;
74
+ for ( int i = 0 ; i < size; ++i) {
75
+ if (prob[i] != counts[i] ) {
76
+ EXPECT_EQ (prob[i], counts[i]);
77
+ LOG (INFO) << iter ;
78
+ break ;
79
+ }
77
80
}
78
81
}
79
82
}
@@ -135,6 +138,7 @@ void benchmarkRandom() {
135
138
LOG (INFO) << " sum1=" << sum1;
136
139
}
137
140
141
+
138
142
int main (int argc, char ** argv) {
139
143
initMain (argc, argv);
140
144
testing::InitGoogleTest (&argc, argv);
0 commit comments