File tree Expand file tree Collapse file tree 1 file changed +11
-5
lines changed Expand file tree Collapse file tree 1 file changed +11
-5
lines changed Original file line number Diff line number Diff line change @@ -40,7 +40,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
40
40
41
41
std::vector<T> ids (batch_size);
42
42
for (size_t i = 0 ; i < batch_size; ++i) {
43
- double r = this ->get_rand ();
43
+ double r = this ->getRandReal ();
44
44
int idx = width - 1 ;
45
45
for (int j = 0 ; j < width; ++j) {
46
46
if ((r -= ins_vector[i * width + j]) < 0 ) {
@@ -60,17 +60,23 @@ class SamplingIdKernel : public framework::OpKernel<T> {
60
60
framework::TensorFromVector (ids, context.device_context (), output);
61
61
}
62
62
63
- double get_rand () const {
63
+ private:
64
+ double getRandReal () const {
65
+ std::call_once (init_flag_, &SamplingIdKernel::getRndInstance);
66
+ return rnd ();
67
+ }
68
+
69
+ static void getRndInstance () {
64
70
// Will be used to obtain a seed for the random number engine
65
71
std::random_device rd;
66
72
// Standard mersenne_twister_engine seeded with rd()
67
73
std::mt19937 gen (rd ());
68
74
std::uniform_real_distribution<> dis (0 , 1 );
69
- return dis ( gen);
75
+ rnd = std::bind (dis, gen);
70
76
}
71
77
72
- private:
73
- unsigned int defaultSeed = 0 ;
78
+ static std::once_flag init_flag_;
79
+ static std::function<> rnd ;
74
80
};
75
81
} // namespace operators
76
82
} // namespace paddle
You can’t perform that action at this time.
0 commit comments