Skip to content

Commit 4661f55

Browse files
committed
random optimize
1 parent 478f73c commit 4661f55

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

paddle/fluid/operators/sampling_id_op.cc

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,19 @@ class SamplingIdKernel : public framework::OpKernel<T> {
3636
std::vector<T> ins_vector;
3737
framework::TensorToVector(*input, context.device_context(), &ins_vector);
3838

39+
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
40+
std::minstd_rand engine;
41+
if (seed == 0) {
42+
seed = std::random_device()();
43+
}
44+
engine.seed(seed);
45+
std::uniform_real_distribution<T> dist(
46+
static_cast<T>(ctx.Attr<float>("min")),
47+
static_cast<T>(ctx.Attr<float>("max")));
48+
3949
std::vector<T> ids(batch_size);
4050
for (size_t i = 0; i < batch_size; ++i) {
41-
double r = getRandReal();
51+
double r = dist(engine);
4252
int idx = width - 1;
4353
for (int j = 0; j < width; ++j) {
4454
if ((r -= ins_vector[i * width + j]) < 0) {
@@ -57,16 +67,6 @@ class SamplingIdKernel : public framework::OpKernel<T> {
5767
output->mutable_data<T>(context.GetPlace());
5868
framework::TensorFromVector(ids, context.device_context(), output);
5969
}
60-
61-
private:
62-
double getRandReal() const {
63-
std::random_device
64-
rd; // Will be used to obtain a seed for the random number engine
65-
std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with
66-
// rd()
67-
std::uniform_real_distribution<> dis(1.0, 2.0);
68-
return dis(gen);
69-
}
7070
};
7171

7272
class SamplingIdOp : public framework::OperatorWithKernel {
@@ -78,6 +78,9 @@ class SamplingIdOp : public framework::OperatorWithKernel {
7878
"Input(X) of SamplingIdOp should not be null.");
7979
PADDLE_ENFORCE(ctx->HasOutput("Out"),
8080
"Output(Out) of SamplingIdOp should not be null.");
81+
PADDLE_ENFORCE(
82+
ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
83+
"min must less then max");
8184

8285
auto input_dims = ctx->GetInputDim("X");
8386
PADDLE_ENFORCE(input_dims.size() == 2,
@@ -99,7 +102,17 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker {
99102
AddComment(R"DOC(
100103
SamplingId Operator.
101104
A layer for sampling id from multinomial distribution from the
102-
input layer. Sampling one id for one sample.)DOC");
105+
input. Sampling one id for one sample.)DOC");
106+
AddAttr<float>("min", "Minimum value of random. [default 0.0].")
107+
.SetDefault(0.0f);
108+
AddAttr<float>("max", "Maximun value of random. [default 1.0].")
109+
.SetDefault(1.0f);
110+
AddAttr<int>("seed",
111+
"Random seed used for the random number engine. "
112+
"0 means use a seed generated by the system."
113+
"Note that if seed is not 0, this operator will always "
114+
"generate the same random numbers every time. [default 0].")
115+
.SetDefault(0);
103116
}
104117
};
105118
} // namespace operators
@@ -109,8 +122,5 @@ namespace ops = paddle::operators;
109122
REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker,
110123
paddle::framework::EmptyGradOpMaker);
111124

112-
REGISTER_OP_CPU_KERNEL(
113-
sampling_id, ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int>,
114-
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int64_t>,
115-
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, float>,
116-
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, double>);
125+
REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>,
126+
paddle::operators::SamplingIdKernel<double>);

0 commit comments

Comments
 (0)