Skip to content

Commit 9c63fef

Browse files
committed
random optimize
1 parent 5b9716d commit 9c63fef

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

paddle/fluid/operators/sampling_id_op.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
4040

4141
std::vector<T> ids(batch_size);
4242
for (size_t i = 0; i < batch_size; ++i) {
43-
double r = this->get_rand();
43+
double r = this->getRandReal();
4444
int idx = width - 1;
4545
for (int j = 0; j < width; ++j) {
4646
if ((r -= ins_vector[i * width + j]) < 0) {
@@ -60,17 +60,23 @@ class SamplingIdKernel : public framework::OpKernel<T> {
6060
framework::TensorFromVector(ids, context.device_context(), output);
6161
}
6262

63-
double get_rand() const {
63+
private:
64+
double getRandReal() const {
65+
std::call_once(init_flag_, &SamplingIdKernel::getRndInstance);
66+
return rnd();
67+
}
68+
69+
static void getRndInstance() {
6470
// Will be used to obtain a seed for the random number engine
6571
std::random_device rd;
6672
// Standard mersenne_twister_engine seeded with rd()
6773
std::mt19937 gen(rd());
6874
std::uniform_real_distribution<> dis(0, 1);
69-
return dis(gen);
75+
rnd = std::bind(dis, gen);
7076
}
7177

72-
private:
73-
unsigned int defaultSeed = 0;
78+
static std::once_flag init_flag_;
79+
static std::function<> rnd;
7480
};
7581
} // namespace operators
7682
} // namespace paddle

0 commit comments

Comments
 (0)