/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>

#include <random>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

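// Functor that maps an index n to the n-th draw of a uniform real
// distribution on [min_, max_), using a minstd_rand engine seeded with
// seed_. Declared __host__ __device__ so it is usable from both host code
// and thrust device code.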
template <typename T>
struct UniformGenerator {
  T min_, max_;
  unsigned int seed_;

  __host__ __device__ UniformGenerator(T min, T max, unsigned int seed)
      : min_(min), max_(max), seed_(seed) {}

  __host__ __device__ T operator()(const unsigned int n) const {
    thrust::minstd_rand rng;
    rng.seed(seed_);
    thrust::uniform_real_distribution<T> dist(min_, max_);
    rng.discard(n);
    return dist(rng);
  }
};

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

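// CPU kernel of the sampling_id op: for every row of the input matrix X
// (batch_size x width), draw a uniform random number and select one entry
// of the row by cumulative (inverse-CDF) sampling.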
template <typename T>
class SamplingIdKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("X");
    const int batch_size = static_cast<int>(input->dims()[0]);
    const int width = static_cast<int>(input->dims()[1]);

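    // Copy the input probabilities into a host-side vector for sampling.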
    std::vector<T> ins_vector;
    framework::TensorToVector(*input, context.device_context(), &ins_vector);

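    // Use the "seed" attribute when it is nonzero; otherwise seed from the
    // hardware random device.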
    unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
    if (seed == 0) {
      std::random_device rd;
      seed = rd();
    }
    T min = static_cast<T>(context.Attr<float>("min"));
    T max = static_cast<T>(context.Attr<float>("max"));

    std::vector<T> ids(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
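      // Draw one uniform sample in [min, max) for this row; the row index
      // is passed as the sequence offset so each row gets a different draw
      // from the same seed.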
      T r = UniformGenerator<T>(min, max, seed)(i);
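      // Select the first column whose cumulative probability exceeds r,
      // falling back to the last column if rounding leaves r positive.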
      int idx = width - 1;
      for (int j = 0; j < width; ++j) {
        if ((r -= ins_vector[i * width + j]) < 0) {
          idx = j;
          break;
        }
      }
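      // Note: the selected entry's value (not its column index) is written
      // to the output.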
      ids[i] = ins_vector[i * width + idx];
    }

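    // The output is a 1-D tensor with one sampled value per input row.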
    std::vector<int64_t> out_dim;
    out_dim.push_back(static_cast<int64_t>(batch_size));

    Tensor* output = context.Output<Tensor>("Out");
    output->Resize(framework::make_ddim(out_dim));
    output->mutable_data<T>(context.GetPlace());
    framework::TensorFromVector(ids, context.device_context(), output);
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>,
                       paddle::operators::SamplingIdKernel<double>);