Skip to content

Commit 7eb29f2

Browse files
reyoung emailweixu
authored and committed
Try to fix MultinomialSampler (#102)
* Also refine the unit test to run multiple iterations, so that it cannot pass merely because of a lucky random number.
1 parent 8e957df commit 7eb29f2

File tree

2 files changed

+36
-31
lines changed

2 files changed

+36
-31
lines changed

paddle/gserver/layers/MultinomialSampler.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace paddle {
1919

2020
MultinomialSampler::MultinomialSampler(const real* prob, int size)
2121
: rand_(0.0, size) {
22-
intervals_.reserve(size + 1);
22+
intervals_.resize(size + 1);
2323
double sum = 0;
2424
for (int i = 0; i < size; ++i) {
2525
sum += prob[i];
@@ -50,12 +50,13 @@ MultinomialSampler::MultinomialSampler(const real* prob, int size)
5050
int bigPos = nextBigPos(0);
5151

5252
auto fillIntervals = [&]() {
53-
while (bigPos < size && smallPos < size) {
53+
while (bigPos < size) {
5454
while (intervals_[bigPos].thresh > 1 && smallPos < size) {
5555
intervals_[smallPos].otherId = bigPos;
5656
intervals_[bigPos].thresh -= 1 - intervals_[smallPos].thresh;
5757
smallPos = nextSmallPos(smallPos + 1);
5858
}
59+
if (smallPos >= size) break;
5960
bigPos = nextBigPos(bigPos + 1);
6061
// If intervals_[bigPos].thresh < 1, it becomes a small interval
6162
}

paddle/gserver/tests/test_MultinomialSampler.cpp

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -41,39 +41,42 @@ class MultinomialSamplerTester : public MultinomialSampler {
4141
TEST(MultinomialSampler, gen) {
4242
int numGrids = 1024 * 1024;
4343
int size = 1024 * 4;
44-
4544
default_random_engine reng;
46-
uniform_int_distribution<int> rand(1, numGrids / size * 1.8);
47-
vector<real> prob;
48-
int sum = 0;
49-
for (int i = 0; i < size; ++i) {
50-
prob.push_back(rand(reng));
51-
sum += prob.back();
52-
}
53-
CHECK_LE(sum, numGrids);
54-
prob.back() += numGrids - sum;
5545

56-
vector<int> counts(size);
57-
MultinomialSamplerTester sampler(&prob[0], size);
58-
counts.assign(size, 0);
59-
{
60-
double s = (double)size / (double)numGrids;
61-
REGISTER_TIMER("MultinomialSampler");
62-
for (double i = 0; i < numGrids; ++i) {
63-
int ret = sampler.testGen([i, s]() { return s * i; });
64-
if (ret < 0 || ret >= size) {
65-
EXPECT_GE(ret, 0);
66-
EXPECT_LT(ret, size);
67-
break;
46+
for (size_t iter=0; iter < 256; ++iter) {
47+
uniform_int_distribution<int> rand(1, numGrids / size * 1.8);
48+
vector<real> prob;
49+
int sum = 0;
50+
for (int i = 0; i < size; ++i) {
51+
prob.push_back(rand(reng));
52+
sum += prob.back();
53+
}
54+
55+
CHECK_LE(sum, numGrids);
56+
prob.back() += numGrids - sum;
57+
58+
vector<int> counts(size);
59+
MultinomialSamplerTester sampler(&prob[0], size);
60+
counts.assign(size, 0);
61+
{
62+
double s = (double)size / (double)numGrids;
63+
REGISTER_TIMER("MultinomialSampler");
64+
for (double i = 0; i < numGrids; ++i) {
65+
int ret = sampler.testGen([i, s]() { return s * i; });
66+
if (ret < 0 || ret >= size) {
67+
EXPECT_GE(ret, 0);
68+
EXPECT_LT(ret, size);
69+
break;
70+
}
71+
++counts[ret];
6872
}
69-
++counts[ret];
7073
}
71-
}
72-
for (int i = 0; i < size; ++i) {
73-
if (prob[i] != counts[i]) {
74-
EXPECT_EQ(prob[i], counts[i]);
75-
LOG(INFO) << "i=" << i;
76-
break;
74+
for (int i = 0; i < size; ++i) {
75+
if (prob[i] != counts[i]) {
76+
EXPECT_EQ(prob[i], counts[i]);
77+
LOG(INFO) << iter;
78+
break;
79+
}
7780
}
7881
}
7982
}
@@ -135,6 +138,7 @@ void benchmarkRandom() {
135138
LOG(INFO) << "sum1=" << sum1;
136139
}
137140

141+
138142
int main(int argc, char** argv) {
139143
initMain(argc, argv);
140144
testing::InitGoogleTest(&argc, argv);

0 commit comments

Comments
 (0)