Skip to content

Commit 4973e07

Browse files
committed
sampling op optimize
1 parent 3206970 commit 4973e07

File tree

3 files changed

+34
-30
lines changed

3 files changed

+34
-30
lines changed

paddle/fluid/operators/sampling_id_op.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,11 @@ SamplingId Operator.
5757
} // namespace paddle
5858

5959
namespace ops = paddle::operators;
60-
REGISTER_OP_CUDA_KERNEL(
61-
sampling_id,
62-
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, float>,
63-
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, double>,
64-
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, int>,
65-
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, int64_t>);
60+
REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker,
61+
paddle::framework::EmptyGradOpMaker);
62+
63+
REGISTER_OP_CPU_KERNEL(
64+
sampling_id, ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int>,
65+
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int64_t>,
66+
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, float>,
67+
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, double>);

paddle/fluid/operators/sampling_id_op.cu

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,9 @@ class SamplingIdOp : public framework::OperatorWithKernel {
3030
} // namespace paddle
3131

3232
namespace ops = paddle::operators;
33-
REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker,
34-
paddle::framework::EmptyGradOpMaker);
35-
36-
REGISTER_OP_CPU_KERNEL(
37-
sampling_id, ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int>,
38-
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int64_t>,
39-
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, float>,
40-
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, double>);
33+
REGISTER_OP_CUDA_KERNEL(
34+
sampling_id,
35+
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, float>,
36+
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, double>,
37+
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, int>,
38+
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, int64_t>);

paddle/fluid/operators/sampling_id_op.h

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,31 @@ limitations under the License. */
1515

1616
#include <random>
1717
#include <vector>
18+
#include "paddle/fluid/framework/lod_tensor.h"
1819
#include "paddle/fluid/framework/op_registry.h"
1920

2021
namespace paddle {
2122
namespace operators {
2223

24+
using Tensor = framework::Tensor;
25+
2326
template <typename DeviceContext, typename T>
2427
class SamplingIdKernel : public framework::OpKernel<T> {
25-
/// Produces random floating-point values, uniformly distributed on [0, 1).
26-
std::uniform_real_distribution<double> rand1_;
27-
2828
public:
2929
void Compute(const framework::ExecutionContext& context) const override {
3030
const Tensor* input = context.Input<Tensor>("X");
3131
const int batch_size = static_cast<int>(input->dims()[0]);
3232
const int width = static_cast<int>(input->dims()[1]);
3333

34-
std::vector<int> ids(batchSize);
35-
auto& reng = get();
34+
std::vector<T> ins_vector;
35+
framework::TensorToVector(*input, context.device_context(), &ins_vector);
3636

37-
for (size_t i = 0; i < batchSize; ++i) {
38-
double r = rand1_(reng);
39-
int id = dim - 1;
40-
for (int j = 0; j < dim; ++j) {
41-
if ((r -= buf[i * dim + j]) < 0) {
37+
std::vector<int> ids(batch_size);
38+
for (size_t i = 0; i < batch_size; ++i) {
39+
double r = this->get_rand();
40+
int id = width - 1;
41+
for (int j = 0; j < width; ++j) {
42+
if ((r -= ins_vector[i * width + j]) < 0) {
4243
id = j;
4344
break;
4445
}
@@ -50,19 +51,22 @@ class SamplingIdKernel : public framework::OpKernel<T> {
5051
out_dim.push_back(static_cast<int64_t>(batch_size));
5152

5253
Tensor* output = context.Output<Tensor>("Output");
53-
output->Resize(framework::make_ddim(in_dim));
54+
output->Resize(framework::make_ddim(out_dim));
5455
output->mutable_data<T>(context.GetPlace());
5556
framework::TensorFromVector(ids, context.device_context(), output);
5657
}
5758

58-
std::default_random_engine& get() {
59-
auto engine = new std::default_random_engine;
60-
engine->seed(defaultSeed);
61-
return *engine;
59+
double get_rand() const {
60+
// Will be used to obtain a seed for the random number engine
61+
std::random_device rd;
62+
// Standard mersenne_twister_engine seeded with rd()
63+
std::mt19937 gen(rd());
64+
std::uniform_real_distribution<> dis(0, 1);
65+
return dis(gen);
6266
}
6367

6468
private:
6569
unsigned int defaultSeed = 0;
66-
}
70+
};
6771
} // namespace operators
6872
} // namespace paddle

0 commit comments

Comments
 (0)