@@ -15,30 +15,31 @@ limitations under the License. */
15
15
16
16
#include < random>
17
17
#include < vector>
18
+ #include " paddle/fluid/framework/lod_tensor.h"
18
19
#include " paddle/fluid/framework/op_registry.h"
19
20
20
21
namespace paddle {
21
22
namespace operators {
22
23
24
+ using Tensor = framework::Tensor;
25
+
23
26
template <typename DeviceContext, typename T>
24
27
class SamplingIdKernel : public framework ::OpKernel<T> {
25
- // / Produces random floating-point values, uniformly distributed on [0, 1).
26
- std::uniform_real_distribution<double > rand1_;
27
-
28
28
public:
29
29
void Compute (const framework::ExecutionContext& context) const override {
30
30
const Tensor* input = context.Input <Tensor>(" X" );
31
31
const int batch_size = static_cast <int >(input->dims ()[0 ]);
32
32
const int width = static_cast <int >(input->dims ()[1 ]);
33
33
34
- std::vector<int > ids (batchSize) ;
35
- auto & reng = get ( );
34
+ std::vector<T> ins_vector ;
35
+ framework::TensorToVector (*input, context. device_context (), &ins_vector );
36
36
37
- for (size_t i = 0 ; i < batchSize; ++i) {
38
- double r = rand1_ (reng);
39
- int id = dim - 1 ;
40
- for (int j = 0 ; j < dim; ++j) {
41
- if ((r -= buf[i * dim + j]) < 0 ) {
37
+ std::vector<int > ids (batch_size);
38
+ for (size_t i = 0 ; i < batch_size; ++i) {
39
+ double r = this ->get_rand ();
40
+ int id = width - 1 ;
41
+ for (int j = 0 ; j < width; ++j) {
42
+ if ((r -= ins_vector[i * width + j]) < 0 ) {
42
43
id = j;
43
44
break ;
44
45
}
@@ -50,19 +51,22 @@ class SamplingIdKernel : public framework::OpKernel<T> {
50
51
out_dim.push_back (static_cast <int64_t >(batch_size));
51
52
52
53
Tensor* output = context.Output <Tensor>(" Output" );
53
- output->Resize (framework::make_ddim (in_dim ));
54
+ output->Resize (framework::make_ddim (out_dim ));
54
55
output->mutable_data <T>(context.GetPlace ());
55
56
framework::TensorFromVector (ids, context.device_context (), output);
56
57
}
57
58
58
- std::default_random_engine& get () {
59
- auto engine = new std::default_random_engine;
60
- engine->seed (defaultSeed);
61
- return *engine;
59
+ double get_rand () const {
60
+ // Will be used to obtain a seed for the random number engine
61
+ std::random_device rd;
62
+ // Standard mersenne_twister_engine seeded with rd()
63
+ std::mt19937 gen (rd ());
64
+ std::uniform_real_distribution<> dis (0 , 1 );
65
+ return dis (gen);
62
66
}
63
67
64
68
private:
65
69
unsigned int defaultSeed = 0 ;
66
- }
70
+ };
67
71
} // namespace operators
68
72
} // namespace paddle
0 commit comments