
Commit bddfa21

add mish op. test=develop (#25341)
1 parent 38f9b71 commit bddfa21

File tree

7 files changed (+617, -0 lines changed)


paddle/fluid/operators/mish_op.cc

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/mish_op.h"
#include <memory>
#include <string>

namespace paddle {
namespace operators {

class MishOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mish");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "mish");

    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
  }
};

class MishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Mish operator");
    AddOutput("Out", "Output of Mish operator");
    AddAttr<float>(
        "threshold",
        "Constant threshold of softplus in Mish operator. Approximate value "
        "of softplus will be used if absolute value of input is greater than "
        ":attr:`threshold`")
        .SetDefault(20.f);
    AddComment(R"DOC(
Mish Activation Operator.

.. math::
        softplus = \begin{cases}
                x, \text{if } x > \text{threshold} \\
                e^{x}, \text{if } x < -\text{threshold} \\
                \ln(1 + e^{x}), \text{otherwise}
            \end{cases}

        out = x * \tanh(softplus)

)DOC");
  }
};

// The operator to calculate gradients of the mish operator.
class MishGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mish");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "mish");

    auto x_grad_name = framework::GradVarName("X");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
  }
};

template <typename T>
class MishGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("mish_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(mish, ops::MishOp, ops::MishOpMaker,
                  ops::MishGradOpMaker<paddle::framework::OpDesc>,
                  ops::MishGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(mish_grad, ops::MishGradOp);
REGISTER_OP_CPU_KERNEL(
    mish, ops::MishFP32CPUKernel<paddle::platform::CPUDeviceContext>,
    ops::MishCPUKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    mish_grad, ops::MishGradFP32CPUKernel<paddle::platform::CPUDeviceContext>,
    ops::MishGradCPUKernel<paddle::platform::CPUDeviceContext, double>);
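Note: the CPU kernels registered above (MishCPUKernel, MishFP32CPUKernel) and the CalcSoftplus helper they call are defined in paddle/fluid/operators/mish_op.h, which is not part of this excerpt. A minimal standalone sketch of the computation the doc comment describes (names below are hypothetical, not the file's actual API): with the default threshold of 20, mish(1.0) = 1.0 * tanh(ln(1 + e)) ≈ 0.8651, and the forward pass is an elementwise loop.

#include <cmath>
#include <cstddef>

// Sketch of the thresholded softplus from the doc comment above; the real
// helper (CalcSoftplus in mish_op.h) is not shown in this diff.
template <typename T>
T SoftplusSketch(T x, float threshold) {
  if (threshold > 0 && x > static_cast<T>(threshold)) {
    return x;                      // softplus(x) ~ x for large x
  } else if (threshold > 0 && x < static_cast<T>(-threshold)) {
    return std::exp(x);            // softplus(x) ~ e^x for very negative x
  }
  return std::log1p(std::exp(x));  // exact branch: ln(1 + e^x)
}

// Elementwise forward pass a CPU mish kernel would perform.
template <typename T>
void MishForwardSketch(const T* x, T* out, std::size_t numel, float threshold) {
  for (std::size_t i = 0; i < numel; ++i) {
    out[i] = x[i] * std::tanh(SoftplusSketch(x[i], threshold));
  }
}

The upper branch avoids overflow in exp(x) for large inputs; the lower branch uses the first-order approximation ln(1 + e^x) ≈ e^x when e^x is tiny.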

paddle/fluid/operators/mish_op.cu

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mish_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
__global__ void KeMishFw(const T* in, T* out, const int numel,
                         const float threshold) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  for (; tid < numel; tid += stride) {
    T x = in[tid];
    T sp = CalcSoftplus<T>(x, threshold);
    out[tid] = x * tanh(sp);
  }
}

// expf instead of exp should be used for the float type, so implement
// and register the float kernel separately.
__global__ void KeMishFwFP32(const float* in, float* out, const int numel,
                             const float threshold) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  for (; tid < numel; tid += stride) {
    float x = in[tid];
    float sp = CalcSoftplusFP32(x, threshold);
    out[tid] = x * tanhf(sp);
  }
}

template <typename T>
__global__ void KeMishBw(const T* in, const T* dout, T* din, const int numel,
                         const float threshold) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  for (; tid < numel; tid += stride) {
    T x = in[tid];
    T sp = CalcSoftplus<T>(x, threshold);
    T tsp = tanh(sp);
    T grad_sp = -expm1(-sp);
    T grad_tsp = (static_cast<T>(1) - tsp * tsp) * grad_sp;
    din[tid] = dout[tid] * (x * grad_tsp + tsp);
  }
}

__global__ void KeMishBwFP32(const float* in, const float* dout, float* din,
                             const int numel, const float threshold) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  for (; tid < numel; tid += stride) {
    float x = in[tid];
    float sp = CalcSoftplusFP32(x, threshold);
    float tsp = tanhf(sp);
    float grad_sp = -expm1f(-sp);
    float grad_tsp = (static_cast<float>(1) - tsp * tsp) * grad_sp;
    din[tid] = dout[tid] * (x * grad_tsp + tsp);
  }
}

template <typename DeviceContext, typename T>
class MishCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");

    const float threshold = ctx.Attr<float>("threshold");

    const T* x_data = x->data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    const int numel = x->numel();

    platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
    KeMishFw<T><<<config.blocks, config.threads, 0,
                  ctx.cuda_device_context().stream()>>>(x_data, out_data,
                                                        numel, threshold);
  }
};

template <typename DeviceContext>
class MishFP32CUDAKernel : public framework::OpKernel<float> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");

    const float threshold = ctx.Attr<float>("threshold");

    const float* x_data = x->data<float>();
    float* out_data = out->mutable_data<float>(ctx.GetPlace());

    const int numel = x->numel();

    platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
    KeMishFwFP32<<<config.blocks, config.threads, 0,
                   ctx.cuda_device_context().stream()>>>(x_data, out_data,
                                                         numel, threshold);
  }
};

template <typename DeviceContext, typename T>
class MishGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto threshold = ctx.Attr<float>("threshold");

    const T* x_data = x->data<T>();
    const T* dout_data = dout->data<T>();
    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());

    const int numel = x->numel();

    platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
    KeMishBw<T><<<config.blocks, config.threads, 0,
                  ctx.cuda_device_context().stream()>>>(
        x_data, dout_data, dx_data, numel, threshold);
  }
};

template <typename DeviceContext>
class MishGradFP32CUDAKernel : public framework::OpKernel<float> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto threshold = ctx.Attr<float>("threshold");

    const float* x_data = x->data<float>();
    const float* dout_data = dout->data<float>();
    float* dx_data = dx->mutable_data<float>(ctx.GetPlace());

    const int numel = x->numel();

    platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(numel, ctx);
    KeMishBwFP32<<<config.blocks, config.threads, 0,
                   ctx.cuda_device_context().stream()>>>(
        x_data, dout_data, dx_data, numel, threshold);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    mish, ops::MishFP32CUDAKernel<paddle::platform::CUDADeviceContext>,
    ops::MishCUDAKernel<paddle::platform::CUDADeviceContext, double>)
REGISTER_OP_CUDA_KERNEL(
    mish_grad, ops::MishGradFP32CUDAKernel<paddle::platform::CUDADeviceContext>,
    ops::MishGradCUDAKernel<paddle::platform::CUDADeviceContext, double>)
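For reference, KeMishBw and KeMishBwFP32 above encode the chain rule for out = x * tanh(sp(x)), where sp denotes the thresholded softplus and σ the logistic sigmoid:

.. math::

    \frac{\partial out}{\partial x} = \tanh(sp) + x \left(1 - \tanh^{2}(sp)\right) \frac{\partial sp}{\partial x},
    \qquad
    \frac{\partial sp}{\partial x} = \sigma(x) = \frac{e^{x}}{1 + e^{x}} = 1 - e^{-sp} = -\operatorname{expm1}(-sp)

The last equality uses the exact branch sp = ln(1 + e^x); on the approximate branches beyond the threshold the same expression holds up to the approximation error. This maps line by line onto the kernel: grad_sp = -expm1(-sp), grad_tsp = (1 - tsp * tsp) * grad_sp, and din = dout * (x * grad_tsp + tsp). Using expm1 rather than 1 - exp(-sp) avoids cancellation when sp is close to zero.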
