Commit 0749c88

Merge pull request #12556 from seiriosPlus/samplingIdOp

Sampling id op

2 parents 0abfbd1 + 822496f commit 0749c88

4 files changed: 226 additions and 0 deletions
paddle/fluid/operators/sampling_id_op.cc
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sampling_id_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class SamplingIdOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of SamplingIdOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of SamplingIdOp should not be null.");
    PADDLE_ENFORCE(
        ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
        "min must be less than max");

    auto input_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE(input_dims.size() == 2,
                   "Input(X) should be a 2-D tensor.");

    framework::DDim dims = input_dims;
    ctx->SetOutputDim("Out", dims);
    ctx->ShareLoD("X", "Out");
  }
};

class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input probability tensor (e.g. softmax output). "
             "2-D with shape [batch_size, input_feature_dimensions].");
    AddOutput("Out", "SamplingId data tensor.");
    AddComment(R"DOC(
SamplingId Operator.
A layer for sampling an id from the multinomial distribution given by the
input. One id is sampled per input row.)DOC");
    AddAttr<float>("min", "Minimum value of random. [default 0.0].")
        .SetDefault(0.0f);
    AddAttr<float>("max", "Maximum value of random. [default 1.0].")
        .SetDefault(1.0f);
    AddAttr<int>("seed",
                 "Random seed used for the random number engine. "
                 "0 means use a seed generated by the system. "
                 "Note that if seed is not 0, this operator will always "
                 "generate the same random numbers every time. [default 0].")
        .SetDefault(0);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker,
                  paddle::framework::EmptyGradOpMaker);

REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>,
                       paddle::operators::SamplingIdKernel<double>);
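The DOC comment and the min/max/seed attributes above define the op's contract: one value is sampled per input row by drawing r from U(min, max) and walking that row's probabilities. A minimal NumPy sketch of that behaviour (illustrative only; the helper name and the final print are not part of this commit, and it mirrors the CPU kernel defined in sampling_id_op.h below):

import numpy as np

def sampling_id_reference(x, low=0.0, high=1.0, seed=1):
    # x: [batch_size, width] row-wise probabilities (e.g. softmax output).
    # low/high/seed play the roles of the op's min/max/seed attributes.
    rng = np.random.RandomState(seed)
    batch_size, width = x.shape
    out = np.empty(batch_size, dtype=x.dtype)
    for i in range(batch_size):
        r = rng.uniform(low, high)   # one uniform draw per row
        idx = width - 1              # fall back to the last column
        for j in range(width):
            r -= x[i, j]
            if r < 0:
                idx = j
                break
        out[i] = x[i, idx]           # the kernel stores the selected value
    return out

probs = np.full((8, 4), 0.25, dtype=np.float32)
print(sampling_id_reference(probs, seed=1))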
paddle/fluid/operators/sampling_id_op.cu
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sampling_id_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>,
                        paddle::operators::SamplingIdKernel<double>);
paddle/fluid/operators/sampling_id_op.h
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <iostream>
#include <iterator>
#include <random>
#include <sstream>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class SamplingIdKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("X");
    const int batch_size = static_cast<int>(input->dims()[0]);
    const int width = static_cast<int>(input->dims()[1]);

    PADDLE_ENFORCE_GE(batch_size, 0,
                      "batch_size(dims[0]) must be nonnegative.");
    PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative.");

    // Copy the input probabilities to host memory as a flat vector.
    std::vector<T> ins_vector;
    framework::TensorToVector(*input, context.device_context(), &ins_vector);

    // Seed the engine; a seed of 0 means draw one from std::random_device.
    unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
    std::minstd_rand engine;
    if (seed == 0) {
      seed = std::random_device()();
    }
    engine.seed(seed);
    std::uniform_real_distribution<T> dist(
        static_cast<T>(context.Attr<float>("min")),
        static_cast<T>(context.Attr<float>("max")));

    // For each row, draw r and subtract the row's probabilities in order
    // until r drops below zero; that column is the sampled index, and its
    // value is written to the output.
    std::vector<T> ids(batch_size);
    for (int i = 0; i < batch_size; ++i) {
      T r = dist(engine);
      int idx = width - 1;
      for (int j = 0; j < width; ++j) {
        if ((r -= ins_vector[i * width + j]) < 0) {
          idx = j;
          break;
        }
      }
      ids[i] = ins_vector[i * width + idx];
    }

    // The output holds one sampled entry per row: shape [batch_size].
    std::vector<int64_t> out_dim;
    out_dim.push_back(static_cast<int64_t>(batch_size));

    Tensor* output = context.Output<Tensor>("Out");
    output->Resize(framework::make_ddim(out_dim));
    output->mutable_data<T>(context.GetPlace());
    framework::TensorFromVector(ids, context.device_context(), output);
  }
};

}  // namespace operators
}  // namespace paddle
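As a concrete trace of the sampling loop in SamplingIdKernel::Compute: for a row [0.1, 0.2, 0.3, 0.4] and a draw r = 0.55 from U(0, 1), subtracting 0.1 leaves 0.45, subtracting 0.2 leaves 0.25, and subtracting 0.3 drives r below zero, so idx = 2 and the value 0.3 (the selected probability, not the index) is written to the output for that row.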
python/paddle/fluid/tests/unittests/test_sampling_id_op.py
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest

import paddle.fluid.core as core
from paddle.fluid.op import Operator


class TestSamplingIdOp(OpTest):
    def setUp(self):
        self.op_type = "sampling_id"
        self.use_mkldnn = False
        self.init_kernel_type()
        self.X = np.random.random((8, 4)).astype('float32')
        self.inputs = {"X": self.X}
        self.Y = np.random.random(8).astype('float32')
        self.outputs = {'Out': self.Y}
        self.attrs = {'max': 1.0, 'min': 0.0, 'seed': 1}

    def test_check_output(self):
        # With a fixed non-zero seed, two runs must produce identical samples.
        self.check_output_customized(self.verify_output)
        y1 = self.out
        self.check_output_customized(self.verify_output)
        y2 = self.out
        self.assertTrue(np.array_equal(y1, y2))
        self.assertEqual(len(y1), len(self.Y))

    def verify_output(self, outs):
        out = np.array(outs[0])
        self.out = out

    def init_kernel_type(self):
        pass


if __name__ == "__main__":
    unittest.main()
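The test imports core and Operator but exercises the op only through OpTest. A rough sketch of driving the registered CPU kernel directly through the low-level Operator interface, in the style of other fluid unit tests of this era (the variable names and exact call pattern are assumptions, not part of this commit):

import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator

scope = core.Scope()
place = core.CPUPlace()

# Feed an 8 x 4 probability matrix as the op's input variable "X".
x_np = np.random.random((8, 4)).astype("float32")
scope.var("X").get_tensor().set(x_np, place)
scope.var("Out").get_tensor()

# Attribute names follow the OpMaker above; a non-zero seed makes runs repeatable.
op = Operator("sampling_id", X="X", Out="Out", min=0.0, max=1.0, seed=1)
op.run(scope, place)

out = np.array(scope.find_var("Out").get_tensor())
print(out.shape)  # expected: (8,), one sampled value per input row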
