
Commit c108376

jerrywgz authored and qingqing01 committed
Add three modes for prelu_op (#12630)
* Add three modes for prelu_op.
1 parent d068493 commit c108376

File tree

7 files changed: +237 -101 lines changed


paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaul
 paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))

paddle/fluid/operators/prelu_op.cc

Lines changed: 52 additions & 13 deletions
@@ -1,11 +1,8 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
 http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -26,14 +23,40 @@ class PReluOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
+    std::string mode = ctx->Attrs().Get<std::string>("mode");
+
+    auto x_dim = ctx->GetInputDim("X");
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
     PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null");
-    PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1,
-                   "Size of weight Alpha must be one.");
+
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
-    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    if (mode == "all") {
+      PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1,
+                     "For mode 'all', size of weight Alpha must be one.");
+    } else if (mode == "channel") {
+      PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == x_dim[1],
+                     "For channel-wise mode, size of weight Alpha must be "
+                     "equal to the number of channels, should be %d",
+                     x_dim[1]);
+    } else if (mode == "element") {
+      PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == product(x_dim),
+                     "For element-wise mode, size of weight Alpha must be "
+                     "equal to the number of input, should be %d",
+                     product(x_dim));
+    } else {
+      PADDLE_THROW("Unkown mode %s", mode);
+    }
+    ctx->SetOutputDim("Out", x_dim);
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+        platform::CPUPlace());
+  }
 };
 
 class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -44,21 +67,23 @@ class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out", "The output tensor of prelu operator.");
     AddComment(R"DOC(
 PRelu Operator.
-
 The equation is:
-
 $$
 f(x) =
 \begin{cases}
 \alpha * x, \quad \text{if} \ x < 0 \\
 x, \qquad \text{if} \ x >= 0
 \end{cases}
 $$
-
 The input `X` can carry the LoD (Level of Details) information,
 or not. And the output shares the LoD information with input `X`.
-
+There are modes:
+  all: all elements share same weight
+  channel: elements in a channel share same weight
+  element: each element has a weight
 )DOC");
+    AddAttr<std::string>("mode", "The mode for inputs to share weights.")
+        .SetDefault("all");
   }
 };
 
@@ -71,9 +96,23 @@ class PReluGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) should not be null");
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
-    ctx->SetOutputDim(framework::GradVarName("Alpha"),
-                      ctx->GetInputDim("Alpha"));
+    auto x_grad_name = framework::GradVarName("X");
+    auto alpha_grad_name = framework::GradVarName("Alpha");
+
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
+    }
+    if (ctx->HasOutput(alpha_grad_name)) {
+      ctx->SetOutputDim(alpha_grad_name, ctx->GetInputDim("Alpha"));
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+        platform::CPUPlace());
   }
 };

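For reference, the shape checks above map directly onto how Alpha is broadcast in each mode. The following is a minimal NumPy sketch of the forward rule (illustrative only, not part of this commit); it assumes NCHW input, and the function name prelu_reference is hypothetical:

    import numpy as np

    def prelu_reference(x, alpha, mode):
        """f(x) = x for x > 0, alpha * x otherwise; x is assumed NCHW."""
        if mode == "all":          # one weight shared by every element
            assert alpha.size == 1
            a = alpha.reshape(1)
        elif mode == "channel":    # one weight per channel, C = x.shape[1]
            assert alpha.size == x.shape[1]
            a = alpha.reshape(1, x.shape[1], 1, 1)
        elif mode == "element":    # one weight per input element
            assert alpha.size == x.size
            a = alpha.reshape(x.shape)
        else:
            raise ValueError("unknown mode %s" % mode)
        return np.where(x > 0, x, a * x)
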
paddle/fluid/operators/prelu_op.cu

Lines changed: 0 additions & 22 deletions
This file was deleted.

paddle/fluid/operators/prelu_op.h

Lines changed: 73 additions & 52 deletions
@@ -1,44 +1,25 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
 http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/transform.h"
-
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
 using platform::Transform;
 
-template <typename T>
-class PReluFunctor {
- public:
-  explicit PReluFunctor(const T* alpha) : alpha_(alpha) {}
-
-  HOSTDEVICE T operator()(const T& x) const {
-    if (x > 0)
-      return x;
-    else
-      return x * (*alpha_);
-  }
-
- private:
-  const T* alpha_;
-};
-
 template <typename DeviceContext, typename T>
 class PReluKernel : public framework::OpKernel<T> {
  public:
@@ -50,53 +31,93 @@ class PReluKernel : public framework::OpKernel<T> {
     const T* x_ptr = x->data<T>();
     T* o_ptr = out->mutable_data<T>(context.GetPlace());
 
-    auto* alpha_ptr = alpha->data<T>();
+    const T* alpha_ptr = alpha->data<T>();
+    std::string mode = context.Attr<std::string>("mode");
 
     int numel = x->numel();
-
-    Transform<DeviceContext> trans;
-    trans(context.template device_context<DeviceContext>(), x_ptr,
-          x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_ptr));
-  }
-};
-
-template <typename T>
-class PReluGradFunctor {
- public:
-  explicit PReluGradFunctor(const T* alpha) : alpha_(alpha) {}
-
-  HOSTDEVICE T operator()(const T& out, const T& dout) const {
-    if (out > 0)
-      return dout;
-    else
-      return dout * (*alpha_);
+    auto dim = x->dims();
+    int index = 0;
+    int i = 0;
+    int temp = 0;
+    if (mode == "channel") {
+      for (i = 0; i < numel; i++) {
+        temp = numel / (dim[0] * dim[1]);
+        index = (i / temp) % dim[1];
+        o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
+      }
+    } else if (mode == "element") {
+      for (i = 0; i < numel; i++) {
+        o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[i] * x_ptr[i];
+      }
+    } else {
+      for (i = 0; i < numel; i++) {
+        o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[0] * x_ptr[i];
+      }
+    }
   }
-
- private:
-  const T* alpha_;
 };
 
 template <typename DeviceContext, typename T>
 class PReluGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<Tensor>("X");
     auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
     auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
-
+    auto* dalpha = context.Output<Tensor>(framework::GradVarName("Alpha"));
     auto* out = context.Input<Tensor>("Out");
     auto* alpha = context.Input<Tensor>("Alpha");
-    auto* alpha_ptr = alpha->data<T>();
-
-    T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
+    const T* alpha_ptr = alpha->data<T>();
+    const T* x_ptr = x->data<T>();
     const T* dout_ptr = dout->data<T>();
     const T* out_ptr = out->data<T>();
-    int numel = dx->numel();
-
-    Transform<DeviceContext> trans;
-    trans(context.template device_context<DeviceContext>(), out_ptr,
-          out_ptr + numel, dout_ptr, dx_ptr, PReluGradFunctor<T>(alpha_ptr));
-
-    // TODO(Zhuoyuan): add dalpha upgrade when GPU kernels ready
+    std::string mode = context.Attr<std::string>("mode");
+    int numel = x->numel();
+    auto dim = x->dims();
+    int index = 0;
+    int i = 0;
+    int temp = 0;
+    if (dx) {
+      T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
+      if (mode == "channel") {
+        for (i = 0; i < numel; i++) {
+          temp = numel / (dim[0] * dim[1]);
+          index = (i / temp) % dim[1];
+          dx_ptr[i] =
+              out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
+        }
+      } else if (mode == "element") {
+        for (i = 0; i < numel; i++) {
+          dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[i] * dout_ptr[i];
+        }
+      } else {
+        for (i = 0; i < numel; i++) {
+          dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[0] * dout_ptr[i];
+        }
+      }
+    }
+
+    index = 0;
+    if (dalpha) {
+      T* dalpha_ptr = dalpha->mutable_data<T>(context.GetPlace());
+      if (mode == "channel") {
+        for (i = 0; i < numel; i++) {
+          temp = numel / (dim[0] * dim[1]);
+          index = (i / temp) % dim[1];
+          dalpha_ptr[index] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
+        }
+      } else if (mode == "element") {
+        for (i = 0; i < numel; i++) {
+          dalpha_ptr[i] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
+        }
+      } else {
+        for (i = 0; i < numel; i++) {
+          dalpha_ptr[0] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
+        }
+      }
+    }
+
+    // TODO(Guanzhong): add GPU kernels
   }
 };

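In the channel-wise branches above, numel / (dim[0] * dim[1]) is the number of elements per channel, so index = (i / temp) % dim[1] recovers the channel of the flat offset i for NCHW data. A NumPy sketch of the gradients computed here (illustrative only, not the committed kernel; it tests x > 0 rather than out > 0, which is equivalent for positive alpha):

    import numpy as np

    def prelu_grad_reference(x, dout, alpha, mode):
        """dx reuses the forward mask; dalpha accumulates x * dout where x <= 0."""
        if mode == "channel":            # alpha broadcast as (1, C, 1, 1)
            a = alpha.reshape(1, x.shape[1], 1, 1)
            reduce_axes = (0, 2, 3)      # sum over N, H, W per channel
        elif mode == "element":          # one weight per element, no reduction
            a = alpha.reshape(x.shape)
            reduce_axes = ()
        else:                            # "all": a single shared weight
            a = alpha.reshape(1)
            reduce_axes = tuple(range(x.ndim))
        positive = x > 0
        dx = np.where(positive, dout, a * dout)
        dalpha = np.where(positive, 0.0, x * dout)
        if reduce_axes:
            dalpha = dalpha.sum(axis=reduce_axes)
        return dx, dalpha.reshape(alpha.shape)
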
python/paddle/fluid/layers/nn.py

Lines changed: 54 additions & 0 deletions
@@ -112,6 +112,7 @@
     'log',
     'crop',
     'rank_loss',
+    'prelu',
     'flatten',
 ]
 
@@ -5364,6 +5365,59 @@ def rank_loss(label, left, right, name=None):
     return out
 
 
+def prelu(x, mode, param_attr=None, name=None):
+    """
+    Equation:
+
+        y = \max(0, x) + alpha \min(0, x)
+
+    Args:
+        x (Variable): The input tensor.
+        param_attr(ParamAttr|None): The parameter attribute for the learnable
+            weight (alpha).
+        mode (string): The mode for weight sharing
+            all: all elements share same weight
+            channel: elements in a channel share same weight
+            element: each element has a weight
+        name(str|None): A name for this layer (optional). If set None, the
+            layer will be named automatically.
+
+    Returns:
+        Variable: The output tensor with the same shape as input.
+
+    Examples:
+
+        .. code-block:: python
+
+            x = fluid.layers.data(name="x", shape=[10, 10], dtype="float32")
+            mode = 'channel'
+            output = fluid.layers.prelu(x, mode)
+    """
+    helper = LayerHelper('prelu', **locals())
+    if mode not in ['all', 'channel', 'element']:
+        raise ValueError('mode should be one of all, channel, element.')
+    alpha_shape = [1]
+    if mode == 'channel':
+        alpha_shape = [1, x.shape[1], 1, 1]
+    elif mode == 'element':
+        alpha_shape = x.shape
+    dtype = helper.input_dtype(input_param_name='x')
+    alpha = helper.create_parameter(
+        attr=param_attr,
+        shape=alpha_shape,
+        dtype='float32',
+        is_bias=False,
+        default_initializer=Constant(1.0))
+    out = helper.create_tmp_variable(dtype)
+    helper.append_op(
+        type="prelu",
+        inputs={"X": x,
+                'Alpha': alpha},
+        attrs={"mode": mode},
+        outputs={"Out": out})
+    return out
+
+
 def flatten(x, axis=1, name=None):
     """
     **Flatten layer**
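
The alpha parameter created by the layer therefore has one element for 'all', one per channel for 'channel', and one per input element for 'element'. A hedged usage sketch (the input shape below is hypothetical; fluid data layers prepend a batch dimension by default, so x.shape[1] is the channel count):

    import paddle.fluid as fluid

    # x has shape [-1, 4, 8, 8] once the batch dimension is prepended, so C = 4.
    x = fluid.layers.data(name="x", shape=[4, 8, 8], dtype="float32")

    y_all = fluid.layers.prelu(x, mode='all')      # creates alpha of shape [1]
    y_chn = fluid.layers.prelu(x, mode='channel')  # creates alpha of shape [1, 4, 1, 1]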
