/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/shuffle_channel_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

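// Caps the launch grid at kNumMaximumNumBlocks; the kernel below uses a
// grid-stride loop, so launching fewer blocks than elements is still correct.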
static inline int NumBlocks(const int N) {
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaximumNumBlocks);
}

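// Channel shuffle (ShuffleNet): viewing the C channels as a group_row x
// group_column matrix (group_row = group, group_column = C / group), the
// kernel copies input channel i * group_column + j to output channel
// j * group_row + i, moving len (= H * W) spatial elements per channel.
// For example, with C = 6 and group = 2 the channel order [0 1 2 3 4 5]
// becomes [0 3 1 4 2 5].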
template <typename T>
__global__ void ShuffleChannel(const int nthreads, const int feature_map_size,
                               T* output, const T* input, int group_row,
                               int group_column, int len) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int ii = index; ii < nthreads; ii += offset) {
    const int n = ii / group_row / group_column / len;
    const int i = (ii / group_column / len) % group_row;
    const int j = ii / len % group_column;
    const int k = ii - (n * feature_map_size + (i * group_column + j) * len);
    T* p_o = output + n * feature_map_size + (j * group_row + i) * len;
    p_o[k] = input[ii];
  }
}
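
// Forward pass: permutes the channels of X (NCHW) into Out.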
template <typename DeviceContext, typename T>
class ShuffleChannelOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<framework::Tensor>("X");
    auto* output = ctx.Output<framework::Tensor>("Out");
    int group = ctx.Attr<int>("group");

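    // X is laid out as NCHW; channel is expected to be divisible by group,
    // otherwise the integer division below truncates group_column.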
    auto input_dims = input->dims();
    auto num = input_dims[0];
    auto channel = input_dims[1];
    auto height = input_dims[2];
    auto width = input_dims[3];

    auto feature_map_size = channel * height * width;
    auto sp_sz = height * width;
    int group_row = group;
    int group_column = channel / group_row;
    // count equals N * C * H * W, i.e. input->numel().
    int count = num * group_column * group_row * sp_sz;

    int blocks = NumBlocks(output->numel());
    int threads = kNumCUDAThreads;

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());

    ShuffleChannel<
        T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
        count, feature_map_size, output_data, input_data, group_row,
        group_column, sp_sz);
  }
};

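// Backward pass: permutes the gradient of Out back into the gradient of X by
// applying the inverse channel shuffle.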
template <typename DeviceContext, typename T>
class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<framework::Tensor>("X");
    int group = ctx.Attr<int>("group");

    auto input_dims = input->dims();
    auto num = input_dims[0];
    auto channel = input_dims[1];
    auto height = input_dims[2];
    auto width = input_dims[3];
    auto feature_map_size = channel * height * width;
    auto sp_sz = height * width;

    int group_row = group;
    int group_column = channel / group_row;
    auto* output_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* input_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
    const T* output_grad_data = output_grad->data<T>();

    int blocks = NumBlocks(output_grad->numel());
    int threads = kNumCUDAThreads;
    int count = num * group_column * group_row * sp_sz;

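    // The gradient of a channel shuffle is the inverse shuffle, i.e. the same
    // permutation with group_row and group_column exchanged, so the two
    // arguments are passed in swapped order here.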
    ShuffleChannel<
        T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
        count, feature_map_size, input_grad_data, output_grad_data,
        group_column, group_row, sp_sz);
  }
};
}  // namespace operators
}  // namespace paddle

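// The kernels below are registered for float and double. From the Python side
// the op is normally reached through a fluid layer wrapper; a usage sketch,
// assuming a fluid.layers.shuffle_channel(x, group) wrapper is exposed in
// this Paddle version:
//   out = fluid.layers.shuffle_channel(x=data, group=2)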
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    shuffle_channel,
    ops::ShuffleChannelOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ShuffleChannelOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                    double>);
REGISTER_OP_CUDA_KERNEL(
    shuffle_channel_grad,
    ops::ShuffleChannelGradOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        float>,
    ops::ShuffleChannelGradOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        double>);