@@ -36,9 +36,19 @@ class SamplingIdKernel : public framework::OpKernel<T> {
36
36
std::vector<T> ins_vector;
37
37
framework::TensorToVector (*input, context.device_context (), &ins_vector);
38
38
39
+ unsigned int seed = static_cast <unsigned int >(ctx.Attr <int >(" seed" ));
40
+ std::minstd_rand engine;
41
+ if (seed == 0 ) {
42
+ seed = std::random_device ()();
43
+ }
44
+ engine.seed (seed);
45
+ std::uniform_real_distribution<T> dist (
46
+ static_cast <T>(ctx.Attr <float >(" min" )),
47
+ static_cast <T>(ctx.Attr <float >(" max" )));
48
+
39
49
std::vector<T> ids (batch_size);
40
50
for (size_t i = 0 ; i < batch_size; ++i) {
41
- double r = getRandReal ( );
51
+ double r = dist (engine );
42
52
int idx = width - 1 ;
43
53
for (int j = 0 ; j < width; ++j) {
44
54
if ((r -= ins_vector[i * width + j]) < 0 ) {
@@ -57,16 +67,6 @@ class SamplingIdKernel : public framework::OpKernel<T> {
57
67
output->mutable_data <T>(context.GetPlace ());
58
68
framework::TensorFromVector (ids, context.device_context (), output);
59
69
}
60
-
61
- private:
62
- double getRandReal () const {
63
- std::random_device
64
- rd; // Will be used to obtain a seed for the random number engine
65
- std::mt19937 gen (rd ()); // Standard mersenne_twister_engine seeded with
66
- // rd()
67
- std::uniform_real_distribution<> dis (1.0 , 2.0 );
68
- return dis (gen);
69
- }
70
70
};
71
71
72
72
class SamplingIdOp : public framework ::OperatorWithKernel {
@@ -78,6 +78,9 @@ class SamplingIdOp : public framework::OperatorWithKernel {
78
78
" Input(X) of SamplingIdOp should not be null." );
79
79
PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
80
80
" Output(Out) of SamplingIdOp should not be null." );
81
+ PADDLE_ENFORCE (
82
+ ctx->Attrs ().Get <float >(" min" ) < ctx->Attrs ().Get <float >(" max" ),
83
+ " min must less then max" );
81
84
82
85
auto input_dims = ctx->GetInputDim (" X" );
83
86
PADDLE_ENFORCE (input_dims.size () == 2 ,
@@ -99,7 +102,17 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker {
99
102
AddComment (R"DOC(
100
103
SamplingId Operator.
101
104
A layer for sampling id from multinomial distribution from the
102
- input layer. Sampling one id for one sample.)DOC" );
105
+ input. Sampling one id for one sample.)DOC" );
106
+ AddAttr<float >(" min" , " Minimum value of random. [default 0.0]." )
107
+ .SetDefault (0 .0f );
108
+ AddAttr<float >(" max" , " Maximun value of random. [default 1.0]." )
109
+ .SetDefault (1 .0f );
110
+ AddAttr<int >(" seed" ,
111
+ " Random seed used for the random number engine. "
112
+ " 0 means use a seed generated by the system."
113
+ " Note that if seed is not 0, this operator will always "
114
+ " generate the same random numbers every time. [default 0]." )
115
+ .SetDefault (0 );
103
116
}
104
117
};
105
118
} // namespace operators
@@ -109,8 +122,5 @@ namespace ops = paddle::operators;
109
122
REGISTER_OPERATOR (sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker,
110
123
paddle::framework::EmptyGradOpMaker);
111
124
112
- REGISTER_OP_CPU_KERNEL (
113
- sampling_id, ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int >,
114
- ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int64_t >,
115
- ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, float >,
116
- ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, double >);
125
+ REGISTER_OP_CPU_KERNEL (sampling_id, paddle::operators::SamplingIdKernel<float >,
126
+ paddle::operators::SamplingIdKernel<double >);
0 commit comments