Skip to content

Commit e0ab2f7

Browse files
committed
new sampling op
1 parent 0964de1 commit e0ab2f7

File tree

3 files changed

+172
-0
lines changed

3 files changed

+172
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/sampling_id_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
using Tensor = framework::Tensor;
21+
22+
class SamplingIdOp : public framework::OperatorWithKernel {
23+
public:
24+
using framework::OperatorWithKernel::OperatorWithKernel;
25+
26+
void InferShape(framework::InferShapeContext *ctx) const override {
27+
PADDLE_ENFORCE(ctx->HasInput("X"),
28+
"Input(X) of RowConvOp should not be null.");
29+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
30+
"Output(Out) of RowConvOp should not be null.");
31+
32+
auto input_dims = ctx->GetInputDim("X");
33+
34+
framework::DDim dims = input_dims;
35+
ctx->SetOutputDim("Out", dims);
36+
ctx->ShareLoD("X", "Out");
37+
}
38+
};
39+
40+
class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the op's proto: one 2-D probability input, one sampled-id output.
  void Make() override {
    // NOTE(review): the original docs said "input tensor of softmax" and
    // "Sliced data tensor" — copy-paste from the softmax/slice ops.
    AddInput("X",
             "The input tensor of SamplingId op. "
             "2-D with shape [batch_size, input_feature_dimensions]; each row "
             "holds a multinomial probability distribution to sample from.");
    AddOutput("Out", "The sampled id tensor, one id per input row.");

    AddComment(R"DOC(
SamplingId Operator.
@brief A layer for sampling id from multinomial distribution from the
input layer. Sampling one id for one sample. The result is stored in
output_.ids.
)DOC");
  }
};
56+
} // namespace operators
57+
} // namespace paddle
58+
59+
namespace ops = paddle::operators;
60+
REGISTER_OP_CUDA_KERNEL(
61+
slice, ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, float>,
62+
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, double>,
63+
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, int>,
64+
ops::SamplingIdKernel<paddle::platform::CUDADeviceContext, int64_t>);
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <algorithm>
16+
#include <vector>
17+
#include "paddle/fluid/operators/sampling_id_op.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
using Tensor = framework::Tensor;
23+
24+
class SamplingIdOp : public framework::OperatorWithKernel {
25+
public:
26+
using framework::OperatorWithKernel::OperatorWithKernel;
27+
void InferShape(framework::InferShapeContext *ctx) const override {}
28+
}
29+
} // namespace operators
30+
} // namespace paddle
31+
32+
namespace ops = paddle::operators;
33+
REGISTER_OPERATOR(samplingid, ops::SamplingIdOp, ops::SamplingIdOpMaker,
34+
paddle::framework::EmptyGradOpMaker);
35+
36+
REGISTER_OP_CPU_KERNEL(
37+
slice, ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int>,
38+
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int64_t>,
39+
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, float>,
40+
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, double>);
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
#pragma once
15+
16+
#include <random>
17+
#include <vector>
18+
#include "paddle/fluid/framework/op_registry.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
template <typename DeviceContext, typename T>
24+
class SamplingIdKernel : public framework::OpKernel<T> {
25+
/// Produces random floating-point values, uniformly distributed on [0, 1).
26+
std::uniform_real_distribution<double> rand1_;
27+
28+
public:
29+
void Compute(const framework::ExecutionContext& context) const override {
30+
const Tensor* input = context.Input<Tensor>("X");
31+
const int batch_size = static_cast<int>(input->dims()[0]);
32+
const int width = static_cast<int>(input->dims()[1]);
33+
34+
std::vector<int> ids(batchSize);
35+
auto& reng = get();
36+
37+
for (size_t i = 0; i < batchSize; ++i) {
38+
double r = rand1_(reng);
39+
int id = dim - 1;
40+
for (int j = 0; j < dim; ++j) {
41+
if ((r -= buf[i * dim + j]) < 0) {
42+
id = j;
43+
break;
44+
}
45+
}
46+
ids[i] = id;
47+
}
48+
49+
std::vector<int64_t> out_dim;
50+
out_dim.push_back(static_cast<int64_t>(batch_size));
51+
52+
Tensor* output = context.Output<Tensor>("Output");
53+
output->Resize(framework::make_ddim(in_dim));
54+
output->mutable_data<T>(context.GetPlace());
55+
framework::TensorFromVector(ids, context.device_context(), output);
56+
}
57+
58+
std::default_random_engine& get() {
59+
auto engine = new std::default_random_engine;
60+
engine->seed(defaultSeed);
61+
return *engine;
62+
}
63+
64+
private:
65+
unsigned int defaultSeed = 0;
66+
}
67+
} // namespace operators
68+
} // namespace paddle

0 commit comments

Comments
 (0)