|
| 1 | +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +you may not use this file except in compliance with the License. |
| 4 | +You may obtain a copy of the License at |
| 5 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +Unless required by applicable law or agreed to in writing, software |
| 7 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +See the License for the specific language governing permissions and |
| 10 | +limitations under the License. */ |
| 11 | + |
| 12 | +#include "paddle/fluid/framework/op_registry.h" |
| 13 | +#include "paddle/fluid/operators/mish_op.h" |
| 14 | +#include "paddle/fluid/platform/cuda_primitives.h" |
| 15 | +#include "paddle/fluid/platform/gpu_launch_config.h" |
| 16 | + |
| 17 | +namespace paddle { |
| 18 | +namespace operators { |
| 19 | + |
| 20 | +using Tensor = framework::Tensor; |
| 21 | + |
| 22 | +template <typename T> |
| 23 | +__global__ void KeMishFw(const T* in, T* out, const int numel, |
| 24 | + const float threshold) { |
| 25 | + int tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 26 | + int stride = blockDim.x * gridDim.x; |
| 27 | + for (; tid < numel; tid += stride) { |
| 28 | + T x = in[tid]; |
| 29 | + T sp = CalcSoftplus<T>(x, threshold); |
| 30 | + out[tid] = x * tanh(sp); |
| 31 | + } |
| 32 | +} |
| 33 | + |
| 34 | +// expf instead of exp should be used for float type, complement |
| 35 | +// and register float kernel separatelly |
| 36 | +__global__ void KeMishFwFP32(const float* in, float* out, const int numel, |
| 37 | + const float threshold) { |
| 38 | + int tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 39 | + int stride = blockDim.x * gridDim.x; |
| 40 | + for (; tid < numel; tid += stride) { |
| 41 | + float x = in[tid]; |
| 42 | + float sp = CalcSoftplusFP32(x, threshold); |
| 43 | + out[tid] = x * tanhf(sp); |
| 44 | + } |
| 45 | +} |
| 46 | + |
| 47 | +template <typename T> |
| 48 | +__global__ void KeMishBw(const T* in, const T* dout, T* din, const int numel, |
| 49 | + const float threshold) { |
| 50 | + int tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 51 | + int stride = blockDim.x * gridDim.x; |
| 52 | + for (; tid < numel; tid += stride) { |
| 53 | + T x = in[tid]; |
| 54 | + T sp = CalcSoftplus<T>(x, threshold); |
| 55 | + T tsp = tanh(sp); |
| 56 | + T grad_sp = -expm1(-sp); |
| 57 | + T grad_tsp = (static_cast<T>(1) - tsp * tsp) * grad_sp; |
| 58 | + din[tid] = dout[tid] * (x * grad_tsp + tsp); |
| 59 | + } |
| 60 | +} |
| 61 | + |
| 62 | +__global__ void KeMishBwFP32(const float* in, const float* dout, float* din, |
| 63 | + const int numel, const float threshold) { |
| 64 | + int tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 65 | + int stride = blockDim.x * gridDim.x; |
| 66 | + for (; tid < numel; tid += stride) { |
| 67 | + float x = in[tid]; |
| 68 | + float sp = CalcSoftplusFP32(x, threshold); |
| 69 | + float tsp = tanhf(sp); |
| 70 | + float grad_sp = -expm1f(-sp); |
| 71 | + float grad_tsp = (static_cast<float>(1) - tsp * tsp) * grad_sp; |
| 72 | + din[tid] = dout[tid] * (x * grad_tsp + tsp); |
| 73 | + } |
| 74 | +} |
| 75 | + |
| 76 | +template <typename DeviceContext, typename T> |
| 77 | +class MishCUDAKernel : public framework::OpKernel<T> { |
| 78 | + public: |
| 79 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 80 | + auto* x = ctx.Input<Tensor>("X"); |
| 81 | + auto* out = ctx.Output<Tensor>("Out"); |
| 82 | + |
| 83 | + const float threshold = ctx.Attr<float>("threshold"); |
| 84 | + |
| 85 | + const T* x_data = x->data<T>(); |
| 86 | + T* out_data = out->mutable_data<T>(ctx.GetPlace()); |
| 87 | + |
| 88 | + const int numel = x->numel(); |
| 89 | + |
| 90 | + platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx); |
| 91 | + KeMishFw<T><<<config.blocks, config.threads, 0, |
| 92 | + ctx.cuda_device_context().stream()>>>(x_data, out_data, numel, |
| 93 | + threshold); |
| 94 | + } |
| 95 | +}; |
| 96 | + |
| 97 | +template <typename DeviceContext> |
| 98 | +class MishFP32CUDAKernel : public framework::OpKernel<float> { |
| 99 | + public: |
| 100 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 101 | + auto* x = ctx.Input<Tensor>("X"); |
| 102 | + auto* out = ctx.Output<Tensor>("Out"); |
| 103 | + |
| 104 | + const float threshold = ctx.Attr<float>("threshold"); |
| 105 | + |
| 106 | + const float* x_data = x->data<float>(); |
| 107 | + float* out_data = out->mutable_data<float>(ctx.GetPlace()); |
| 108 | + |
| 109 | + const int numel = x->numel(); |
| 110 | + |
| 111 | + platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx); |
| 112 | + KeMishFwFP32<<<config.blocks, config.threads, 0, |
| 113 | + ctx.cuda_device_context().stream()>>>(x_data, out_data, |
| 114 | + numel, threshold); |
| 115 | + } |
| 116 | +}; |
| 117 | + |
| 118 | +template <typename DeviceContext, typename T> |
| 119 | +class MishGradCUDAKernel : public framework::OpKernel<T> { |
| 120 | + public: |
| 121 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 122 | + auto* x = ctx.Input<Tensor>("X"); |
| 123 | + auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); |
| 124 | + auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); |
| 125 | + |
| 126 | + auto threshold = ctx.Attr<float>("threshold"); |
| 127 | + |
| 128 | + const T* x_data = x->data<T>(); |
| 129 | + const T* dout_data = dout->data<T>(); |
| 130 | + T* dx_data = dx->mutable_data<T>(ctx.GetPlace()); |
| 131 | + |
| 132 | + const int numel = x->numel(); |
| 133 | + |
| 134 | + platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx); |
| 135 | + KeMishBw<T><<<config.blocks, config.threads, 0, |
| 136 | + ctx.cuda_device_context().stream()>>>( |
| 137 | + x_data, dout_data, dx_data, numel, threshold); |
| 138 | + } |
| 139 | +}; |
| 140 | + |
| 141 | +template <typename DeviceContext> |
| 142 | +class MishGradFP32CUDAKernel : public framework::OpKernel<float> { |
| 143 | + public: |
| 144 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 145 | + auto* x = ctx.Input<Tensor>("X"); |
| 146 | + auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); |
| 147 | + auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); |
| 148 | + |
| 149 | + auto threshold = ctx.Attr<float>("threshold"); |
| 150 | + |
| 151 | + const float* x_data = x->data<float>(); |
| 152 | + const float* dout_data = dout->data<float>(); |
| 153 | + float* dx_data = dx->mutable_data<float>(ctx.GetPlace()); |
| 154 | + |
| 155 | + const int numel = x->numel(); |
| 156 | + |
| 157 | + platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx); |
| 158 | + KeMishBwFP32<<<config.blocks, config.threads, 0, |
| 159 | + ctx.cuda_device_context().stream()>>>( |
| 160 | + x_data, dout_data, dx_data, numel, threshold); |
| 161 | + } |
| 162 | +}; |
| 163 | + |
| 164 | +} // namespace operators |
| 165 | +} // namespace paddle |
| 166 | + |
| 167 | +namespace ops = paddle::operators; |
| 168 | +REGISTER_OP_CUDA_KERNEL( |
| 169 | + mish, ops::MishFP32CUDAKernel<paddle::platform::CUDADeviceContext>, |
| 170 | + ops::MishCUDAKernel<paddle::platform::CUDADeviceContext, double>) |
| 171 | +REGISTER_OP_CUDA_KERNEL( |
| 172 | + mish_grad, ops::MishGradFP32CUDAKernel<paddle::platform::CUDADeviceContext>, |
| 173 | + ops::MishGradCUDAKernel<paddle::platform::CUDADeviceContext, double>) |
0 commit comments