
Commit 67a2b52

Add affine channel op to speed up and save memory for the Faster-RCNN model. (#13919)
* Add affine channel op.
* Update code and add Python API. test=develop
* Update API.spec test=develop
1 parent 30dfbde commit 67a2b52

File tree

6 files changed: +592 −0 lines changed


paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions

@@ -173,6 +173,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
 paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
 paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
 paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
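For reference, a minimal usage sketch of the new Python API (a hedged illustration assuming the Fluid 1.x interface recorded in API.spec above; the shapes and variable names are my own, not from the commit):

import paddle.fluid as fluid

# NCHW feature map with C = 64 channels; scale/bias are per-channel, shape [C].
x = fluid.layers.data(name='x', shape=[64, 32, 32], dtype='float32')
scale = fluid.layers.create_parameter(shape=[64], dtype='float32')
bias = fluid.layers.create_parameter(shape=[64], dtype='float32')
# out[n, c, h, w] = scale[c] * x[n, c, h, w] + bias[c]
out = fluid.layers.affine_channel(x, scale=scale, bias=bias, data_layout='NCHW')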

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -305,6 +305,7 @@ if (WITH_GPU)
     op_library(conv_op DEPS vol2col depthwise_conv im2col)
     op_library(layer_norm_op DEPS cub)
     op_library(reduce_mean_op DEPS cub)
+    op_library(affine_channel_op DEPS cub)
 else()
     op_library(conv_op DEPS vol2col im2col)
 endif()
paddle/fluid/operators/affine_channel_op.cc

Lines changed: 255 additions & 0 deletions

@@ -0,0 +1,255 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class AffineChannelOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor) Feature map input, a 4D tensor with layout NCHW or "
             "NHWC. It can also be a 2D tensor, in which case C is the "
             "second dimension.");
    AddInput("Scale",
             "(Tensor) 1D input of shape (C), the c-th element "
             "is the scale factor of the affine transformation "
             "for the c-th channel of the input.");
    AddInput("Bias",
             "(Tensor) 1D input of shape (C), the c-th element "
             "is the bias of the affine transformation for the "
             "c-th channel of the input.");
    AddAttr<std::string>(
        "data_layout",
        "(string, default \"AnyLayout\") An optional string from: "
        "\"NHWC\", \"NCHW\". Specifies the data layout of Input(X).")
        .SetDefault("AnyLayout");
    AddOutput("Out", "(Tensor) A tensor of the same shape and layout as X.");
    AddComment(R"DOC(

Applies a separate affine transformation to each channel of the input. Useful
for replacing spatial batch norm with its equivalent fixed transformation.
The input can also be a 2D tensor, in which case the affine transformation is
applied along the second dimension.

$$Out = Scale*X + Bias$$

)DOC");
  }
};
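To make the per-channel semantics concrete, here is a minimal NumPy sketch of what $$Out = Scale*X + Bias$$ computes for each supported input form (an illustration added for this review, not part of the commit):

import numpy as np

N, C, H, W = 2, 3, 4, 4
x = np.random.rand(N, C, H, W).astype('float32')
scale = np.random.rand(C).astype('float32')
bias = np.random.rand(C).astype('float32')

# NCHW: broadcast the per-channel scale/bias over the N, H, W axes.
out_nchw = x * scale.reshape(1, C, 1, 1) + bias.reshape(1, C, 1, 1)

# NHWC: the channel is the trailing axis, so plain broadcasting applies.
x_nhwc = x.transpose(0, 2, 3, 1)
out_nhwc = x_nhwc * scale + bias

# 2D input: C is the second dimension; same trailing-axis broadcast.
x_2d = np.random.rand(N, C).astype('float32')
out_2d = x_2d * scale + bias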
class AffineChannelOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of AffineChannelOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Scale"),
                   "Input(Scale) of AffineChannelOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Bias"),
                   "Input(Bias) of AffineChannelOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of AffineChannelOp should not be null.");
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", "Out");
  }
};

class AffineChannelOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null.");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      PADDLE_ENFORCE(ctx->HasInput("Scale"),
                     "Input(Scale) should not be null.");
      ctx->SetOutputDim(framework::GradVarName("X"),
                        ctx->GetInputDim(framework::GradVarName("Out")));
    }
    if (ctx->HasOutput(framework::GradVarName("Scale"))) {
      // Scale@GRAD and Bias@GRAD must exist at the same time.
      PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
                     "Output(Bias@GRAD) should not be null.");
      PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
      ctx->SetOutputDim(framework::GradVarName("Scale"),
                        ctx->GetInputDim("Scale"));
      ctx->SetOutputDim(framework::GradVarName("Bias"),
                        ctx->GetInputDim("Scale"));
    }
  }
};
template <typename T>
using EigenArrayMap =
    Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using ConstEigenArrayMap =
    Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T>
using ConstEigenVectorArrayMap =
    Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;

template <typename DeviceContext, typename T>
class AffineChannelKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::Tensor>("X");
    auto* scale = ctx.Input<framework::Tensor>("Scale");
    auto* bias = ctx.Input<framework::Tensor>("Bias");

    auto* y = ctx.Output<framework::Tensor>("Out");
    y->mutable_data<T>(ctx.GetPlace());

    const framework::DataLayout layout =
        framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));

    auto dims = x->dims();
    int N = dims[0];
    int C = layout == framework::DataLayout::kNCHW ? dims[1]
                                                   : dims[dims.size() - 1];
    int HxW = x->numel() / N / C;

    auto* scale_d = scale->data<T>();
    auto* bias_d = bias->data<T>();
    ConstEigenVectorArrayMap<T> a_e(scale_d, C);
    ConstEigenVectorArrayMap<T> b_e(bias_d, C);

    auto* x_d = x->data<T>();
    auto* y_d = y->data<T>();
    if (layout == framework::DataLayout::kNCHW) {
      // NCHW: map each sample as a column-major (HxW, C) array, so that
      // column c is the contiguous plane of channel c; the per-channel
      // scale/bias then becomes a row-wise broadcast.
      int stride = C * HxW;
      for (int i = 0; i < N; i++) {
        ConstEigenArrayMap<T> x_e(x_d, HxW, C);
        EigenArrayMap<T> y_e(y_d, HxW, C);
        y_e = (x_e.rowwise() * a_e.transpose()).rowwise() + b_e.transpose();
        x_d += stride;
        y_d += stride;
      }
    } else {
      // NHWC: map the whole batch as a column-major (C, N*H*W) array, so
      // each column holds one pixel's channel vector.
      int num = N * HxW;
      ConstEigenArrayMap<T> x_e(x_d, C, num);
      EigenArrayMap<T> y_e(y_d, C, num);
      y_e = (x_e.colwise() * a_e).colwise() + b_e;
    }
  }
};
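A small NumPy analogue of the mapping trick above (my illustration, not part of the commit): Eigen's Map is column-major, so a raw NCHW sample viewed as an (HxW, C) array has channel c in column c, and the per-channel scale/bias reduces to a broadcast over columns.

import numpy as np

N, C, H, W = 2, 3, 4, 5
x = np.random.rand(N, C, H, W).astype('float32')
scale = np.random.rand(C).astype('float32')
bias = np.random.rand(C).astype('float32')

# Column-major (HxW, C) view of one NCHW sample: column c is channel c.
sample = x[0]
x_e = sample.reshape(C, H * W).T      # shape (HxW, C)
y_e = x_e * scale + bias              # row-wise scale/shift, as in the kernel

# Matches the direct per-channel broadcast.
ref = sample * scale.reshape(C, 1, 1) + bias.reshape(C, 1, 1)
assert np.allclose(y_e.T.reshape(C, H, W), ref)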
template <typename DeviceContext, typename T>
class AffineChannelGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::Tensor>("X");
    auto* scale = ctx.Input<framework::Tensor>("Scale");
    auto* dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));

    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* dscale =
        ctx.Output<framework::Tensor>(framework::GradVarName("Scale"));
    auto* dbias = ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));

    const framework::DataLayout layout =
        framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));

    auto dims = x->dims();
    int N = dims[0];
    int C = layout == framework::DataLayout::kNCHW ? dims[1]
                                                   : dims[dims.size() - 1];
    int HxW = x->numel() / N / C;

    auto* x_d = x->data<T>();
    auto* dy_d = dy->data<T>();
    auto* scale_d = scale->data<T>();
    ConstEigenVectorArrayMap<T> scale_e(scale_d, C);

    T* dx_d = dx ? dx->mutable_data<T>(ctx.GetPlace()) : nullptr;
    T* dscale_d = dscale ? dscale->mutable_data<T>(ctx.GetPlace()) : nullptr;
    T* dbias_d = dbias ? dbias->mutable_data<T>(ctx.GetPlace()) : nullptr;
    EigenVectorArrayMap<T> dscale_e(dscale_d, C);
    EigenVectorArrayMap<T> dbias_e(dbias_d, C);

    if (layout == framework::DataLayout::kNCHW) {
      // compute dx
      int stride = C * HxW;
      if (dx) {
        for (int i = 0; i < N; i++) {
          ConstEigenArrayMap<T> dy_e(dy_d, HxW, C);
          EigenArrayMap<T> dx_e(dx_d, HxW, C);
          dx_e = dy_e.rowwise() * scale_e.transpose();
          dy_d += stride;
          dx_d += stride;
        }
      }
      // compute dscale and dbias
      if (dscale && dbias) {
        dy_d = dy->data<T>();
        for (int i = 0; i < N; i++) {
          ConstEigenArrayMap<T> x_e(x_d, HxW, C);
          ConstEigenArrayMap<T> dy_e(dy_d, HxW, C);
          if (i == 0) {
            dscale_e = (x_e * dy_e).colwise().sum();
          } else {
            dscale_e += (x_e * dy_e).colwise().sum();
          }
          if (i == 0) {
            dbias_e = dy_e.colwise().sum();
          } else {
            dbias_e += dy_e.colwise().sum();
          }
          x_d += stride;
          dy_d += stride;
        }
      }
    } else {
      int num = N * HxW;
      ConstEigenArrayMap<T> dy_e(dy_d, C, num);
      // compute dx
      if (dx) {
        EigenArrayMap<T> dx_e(dx_d, C, num);
        dx_e = dy_e.colwise() * scale_e;
      }
      // compute dscale and dbias
      if (dscale && dbias) {
        ConstEigenArrayMap<T> x_e(x_d, C, num);
        dscale_e = (x_e * dy_e).rowwise().sum();
        dbias_e = dy_e.rowwise().sum();
      }
    }
  }
};
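For reference, the gradients this kernel computes, written out in NumPy for an NCHW input (my sketch, not part of the commit): since Out = Scale*X + Bias acts per channel, dX scales dOut by scale[c], while dScale and dBias reduce over the N, H, W axes.

import numpy as np

N, C, H, W = 2, 3, 4, 4
x = np.random.rand(N, C, H, W).astype('float32')
dy = np.random.rand(N, C, H, W).astype('float32')   # gradient of Out
scale = np.random.rand(C).astype('float32')

dx = dy * scale.reshape(1, C, 1, 1)       # dX = dOut * Scale (per channel)
dscale = (x * dy).sum(axis=(0, 2, 3))     # dScale[c] = sum over N,H,W of X*dOut
dbias = dy.sum(axis=(0, 2, 3))            # dBias[c] = sum over N,H,W of dOut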
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

REGISTER_OPERATOR(affine_channel, ops::AffineChannelOp,
                  ops::AffineChannelOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelOpGrad);

REGISTER_OP_CPU_KERNEL(affine_channel, ops::AffineChannelKernel<CPU, float>,
                       ops::AffineChannelKernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(affine_channel_grad,
                       ops::AffineChannelGradKernel<CPU, float>,
                       ops::AffineChannelGradKernel<CPU, double>);
