|
| 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include <thrust/execution_policy.h> |
| 16 | +#include <thrust/sort.h> |
| 17 | +#include "paddle/fluid/framework/op_registry.h" |
| 18 | +#include "paddle/fluid/operators/argsort_op.h" |
| 19 | +#include "paddle/fluid/platform/assert.h" |
| 20 | +#include "paddle/fluid/platform/cuda_device_function.h" |
| 21 | +#include "paddle/fluid/platform/cuda_primitives.h" |
| 22 | + |
| 23 | +namespace paddle { |
| 24 | +namespace operators { |
| 25 | + |
| 26 | +using Tensor = framework::Tensor; |
| 27 | +using platform::PADDLE_CUDA_NUM_THREADS; |
| 28 | + |
| 29 | +const int kMaxRank = 9; // The max rank of a tensor allowed in Fluid |
| 30 | + |
| 31 | +__global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size, |
| 32 | + int axis, int64_t n, int64_t* trg_idx, |
| 33 | + int64_t* med_ids) { |
| 34 | + int64_t index = threadIdx.x + blockDim.x * blockIdx.x; |
| 35 | + if (index < n) { |
| 36 | + int64_t shape_out_axis[kMaxRank - 1] = {0}; |
| 37 | + int64_t dims_out_axis[kMaxRank - 1] = {0}; |
| 38 | + int64_t tmp = index; |
| 39 | + int64_t pos_in_axis = 0; |
| 40 | + int64_t i = dims_size - 2; |
| 41 | + int64_t dim_axis = 0; |
| 42 | + for (int64_t j = dims_size - 1; j >= 0; --j) { |
| 43 | + int64_t dim = in_dims[j]; |
| 44 | + if (j != axis) { |
| 45 | + shape_out_axis[i] = tmp % dim; |
| 46 | + dims_out_axis[i] = dim; |
| 47 | + i--; |
| 48 | + } else { |
| 49 | + dim_axis = dim; |
| 50 | + pos_in_axis = tmp % dim_axis; |
| 51 | + } |
| 52 | + tmp /= dim; |
| 53 | + } |
| 54 | + int64_t group = (dims_size > 1) ? shape_out_axis[0] : 0; |
| 55 | + for (int64_t j = 0; j < dims_size - 2; ++j) { |
| 56 | + group = group * dims_out_axis[j + 1] + shape_out_axis[j + 1]; |
| 57 | + } |
| 58 | + |
| 59 | + int64_t traget_idx = group * dim_axis + pos_in_axis; |
| 60 | + trg_idx[index] = traget_idx; |
| 61 | + med_ids[traget_idx] = pos_in_axis; |
| 62 | + } |
| 63 | +} |
| 64 | + |
| 65 | +template <typename T> |
| 66 | +__global__ void PermuteInData(const T* in, const int64_t* trg_idx, int64_t n, |
| 67 | + T* med_out) { |
| 68 | + int index = threadIdx.x + blockDim.x * blockIdx.x; |
| 69 | + if (index < n) { |
| 70 | + med_out[trg_idx[index]] = in[index]; |
| 71 | + } |
| 72 | +} |
| 73 | + |
| 74 | +template <typename T> |
| 75 | +__global__ void Sort(int64_t axis_dim, int64_t groups, T* med_out, |
| 76 | + int64_t* med_ids) { |
| 77 | + int index = threadIdx.x + blockDim.x * blockIdx.x; |
| 78 | + if (index < groups) { |
| 79 | + thrust::sort_by_key(thrust::device, med_out + index * axis_dim, |
| 80 | + med_out + axis_dim * (1 + index), |
| 81 | + med_ids + index * axis_dim); |
| 82 | + } |
| 83 | +} |
| 84 | + |
| 85 | +template <typename T> |
| 86 | +__global__ void PermuteMediateData(const T* med_out, const int64_t* med_ids, |
| 87 | + const int64_t* trg_idx, int64_t n, T* out, |
| 88 | + int64_t* indices) { |
| 89 | + int index = threadIdx.x + blockDim.x * blockIdx.x; |
| 90 | + if (index < n) { |
| 91 | + out[index] = med_out[trg_idx[index]]; |
| 92 | + indices[index] = med_ids[trg_idx[index]]; |
| 93 | + } |
| 94 | +} |
| 95 | + |
| 96 | +template <typename T> |
| 97 | +class ArgsortOpCUDAKernel : public framework::OpKernel<T> { |
| 98 | + public: |
| 99 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 100 | + auto* input = ctx.Input<Tensor>("X"); |
| 101 | + auto* output = ctx.Output<Tensor>("Out"); |
| 102 | + auto* indices = ctx.Output<Tensor>("Indices"); |
| 103 | + int axis = ctx.Attr<int>("axis"); |
| 104 | + |
| 105 | + auto in_dims = input->dims(); |
| 106 | + axis = (axis < 0) ? (in_dims.size() + axis) : axis; |
| 107 | + |
| 108 | + const T* in_data = input->data<T>(); |
| 109 | + T* out_data = output->mutable_data<T>(ctx.GetPlace()); |
| 110 | + int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace()); |
| 111 | + |
| 112 | + int64_t numel = input->numel(); |
| 113 | + int64_t groups = numel / in_dims[axis]; |
| 114 | + |
| 115 | + std::vector<int64_t> in_dims_vec = vectorize(in_dims); |
| 116 | + thrust::device_vector<int64_t> in_dims_dev(in_dims_vec.begin(), |
| 117 | + in_dims_vec.end()); |
| 118 | + int64_t* in_dims_data = thrust::raw_pointer_cast(in_dims_dev.data()); |
| 119 | + // Mediate tensor for sorting data and indices |
| 120 | + Tensor mediate_output, mediate_indices; |
| 121 | + T* med_out_data = |
| 122 | + mediate_output.mutable_data<T>(input->dims(), ctx.GetPlace()); |
| 123 | + int64_t* med_ids_data = |
| 124 | + mediate_indices.mutable_data<int64_t>(in_dims, ctx.GetPlace()); |
| 125 | + // Target index of each element along the given axis in the mediate tensors |
| 126 | + Tensor trg_idx_t; |
| 127 | + int64_t* trg_idx = trg_idx_t.mutable_data<int64_t>(in_dims, ctx.GetPlace()); |
| 128 | + |
| 129 | + auto stream = ctx.cuda_device_context().stream(); |
| 130 | + const int num_threads = PADDLE_CUDA_NUM_THREADS; |
| 131 | + |
| 132 | + ComputeTargetIdx<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>( |
| 133 | + in_dims_data, in_dims.size(), axis, numel, trg_idx, med_ids_data); |
| 134 | + |
| 135 | + PermuteInData<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>( |
| 136 | + in_data, trg_idx, numel, med_out_data); |
| 137 | + |
| 138 | + Sort<<<(groups - 1) / num_threads + 1, num_threads, 0, stream>>>( |
| 139 | + in_dims[axis], groups, med_out_data, med_ids_data); |
| 140 | + |
| 141 | + PermuteMediateData<<<(numel - 1) / num_threads + 1, num_threads, 0, |
| 142 | + stream>>>(med_out_data, med_ids_data, trg_idx, numel, |
| 143 | + out_data, ids_data); |
| 144 | + } |
| 145 | +}; |
| 146 | + |
| 147 | +} // namespace operators |
| 148 | +} // namespace paddle |
| 149 | + |
| 150 | +REGISTER_OP_CUDA_KERNEL(argsort, paddle::operators::ArgsortOpCUDAKernel<float>, |
| 151 | + paddle::operators::ArgsortOpCUDAKernel<double>); |
0 commit comments