Skip to content

Commit 0cc6354

Browse files
committed
merge develop
2 parents b8f7fa9 + 4a497b8 commit 0cc6354

File tree

18 files changed

+886
-159
lines changed

18 files changed

+886
-159
lines changed

paddle/cuda/src/hl_top_k.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
244244
if (--beamSize == 0) break;
245245
__syncthreads();
246246

247-
// temporary solution
247+
// NOTE(zcd): temporary solution
248248
unsigned mask = 0u;
249249
CREATE_SHFL_MASK(mask, true);
250250

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "mkldnn.hpp"
16+
#include "paddle/fluid/operators/batch_norm_op.h"
17+
#include "paddle/fluid/platform/mkldnn_helper.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
using Tensor = framework::Tensor;
23+
using paddle::platform::MKLDNNDeviceContext;
24+
using paddle::platform::MKLDNNMemDesc;
25+
using mkldnn::memory;
26+
27+
template <typename T>
28+
using EigenArrayMap =
29+
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
30+
template <typename T>
31+
using ConstEigenArrayMap =
32+
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
33+
template <typename T>
34+
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
35+
template <typename T>
36+
using ConstEigenVectorArrayMap =
37+
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
38+
39+
namespace {
40+
// Bundles the types that go with one primitive type `T`:
//   op_type - the primitive type itself,
//   op_desc - its nested `desc` type,
//   op_prim - its nested `primitive_desc` type.
template <typename T>
struct bn_type_traits {
  using op_type = T;
  using op_desc = typename T::desc;
  using op_prim = typename T::primitive_desc;
};
46+
47+
// Packs two ranges back to back at the front of `*c`:
// [scale_begin, scale_end) first, immediately followed by
// [shift_begin, shift_end).
//
// Params:
//   scale_begin, scale_end - first range to copy.
//   shift_begin, shift_end - second range, placed right after the first.
//   c                      - destination sequence container; must not be null.
template <typename T, typename Container>
void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end,
                     Container *c) {
  // Range insert returns an iterator to the first inserted element that is
  // valid *after* the insertion. An iterator obtained before inserting must
  // not be reused: std::vector invalidates every iterator at or after the
  // insertion point (and all of them on reallocation).
  auto scale_pos = c->insert(std::begin(*c), scale_begin, scale_end);

  c->insert(std::next(scale_pos, std::distance(scale_begin, scale_end)),
            shift_begin, shift_end);
}
57+
58+
template <typename Op, typename... Args>
59+
void run_batch_norm_op(Args &&... args) {
60+
Op batch_norm_op{args...};
61+
62+
std::vector<mkldnn::primitive> pipeline;
63+
pipeline.push_back(batch_norm_op);
64+
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
65+
}
66+
67+
// Returns the address held by `t` as a non-const `void *`.
// The const qualifier is removed; the pointer value itself is unchanged.
template <typename T>
inline void *cast_const_to_void(const T *t) {
  T *writable = const_cast<T *>(t);
  return writable;  // T* converts to void* implicitly
}
71+
} // namespace
72+
73+
template <typename T>
74+
class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
75+
public:
76+
void Compute(const framework::ExecutionContext &ctx) const override {
77+
auto data_layout_str = ctx.Attr<std::string>("data_layout");
78+
auto data_layout = framework::StringToDataLayout(data_layout_str);
79+
PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
80+
"MKLDNN batch normalization handles only NCHW data layout");
81+
82+
const float epsilon = ctx.Attr<float>("epsilon");
83+
const float momentum = ctx.Attr<float>("momentum");
84+
const bool is_test = ctx.Attr<bool>("is_test");
85+
86+
const auto *x = ctx.Input<Tensor>("X");
87+
const auto *mean = ctx.Input<Tensor>("Mean");
88+
const auto *variance = ctx.Input<Tensor>("Variance");
89+
90+
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
91+
auto mkldnn_engine = dev_ctx.GetEngine();
92+
93+
auto *y = ctx.Output<Tensor>("Y");
94+
auto *mean_out = ctx.Output<Tensor>("MeanOut");
95+
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
96+
auto *batch_mean = ctx.Output<Tensor>("SavedMean");
97+
auto *batch_variance = ctx.Output<Tensor>("SavedVariance");
98+
99+
const auto *scale = ctx.Input<Tensor>("Scale");
100+
const auto *shift = ctx.Input<Tensor>("Bias");
101+
102+
y->mutable_data<T>(ctx.GetPlace());
103+
mean_out->mutable_data<T>(ctx.GetPlace());
104+
variance_out->mutable_data<T>(ctx.GetPlace());
105+
106+
if (!is_test) {
107+
batch_mean->mutable_data<T>(ctx.GetPlace());
108+
batch_variance->mutable_data<T>(ctx.GetPlace());
109+
}
110+
111+
auto propagation = is_test == true ? mkldnn::prop_kind::forward_scoring
112+
: mkldnn::prop_kind::forward_training;
113+
114+
auto dims = paddle::framework::vectorize2int(x->dims());
115+
116+
auto src_md =
117+
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
118+
auto dst_md =
119+
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
120+
121+
auto src_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
122+
auto dst_pd = mkldnn::memory::primitive_desc{dst_md, mkldnn_engine};
123+
124+
auto src = mkldnn::memory{src_pd, cast_const_to_void(x->data<T>())};
125+
auto dst = mkldnn::memory{dst_pd, y->data<T>()};
126+
127+
unsigned flags = mkldnn::use_scale_shift;
128+
if (is_test) flags |= mkldnn::use_global_stats;
129+
130+
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
131+
auto batch_norm_fwd_desc =
132+
bn_fwd_types::op_desc{propagation, src_md, epsilon, flags};
133+
auto batch_norm_fwd_pd =
134+
bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine};
135+
136+
const unsigned int ic = dims[1];
137+
138+
// MKLDNN requires a single piece of memory for scale and shift/bias data
139+
const size_t scaleshift_size = 2 * ic;
140+
std::vector<T> scaleshift_data;
141+
scaleshift_data.reserve(scaleshift_size);
142+
143+
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
144+
shift->data<T>() + ic, &scaleshift_data);
145+
146+
auto scaleshift_memory = mkldnn::memory{
147+
batch_norm_fwd_pd.weights_primitive_desc(), scaleshift_data.data()};
148+
149+
if (is_test) {
150+
auto mean_memory = mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(),
151+
cast_const_to_void(mean->data<T>())};
152+
153+
auto variance_memory =
154+
mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(),
155+
cast_const_to_void(variance->data<T>())};
156+
157+
run_batch_norm_op<typename bn_fwd_types::op_type>(
158+
batch_norm_fwd_pd, src, (const mkldnn::primitive::at &)mean_memory,
159+
(const mkldnn::primitive::at &)variance_memory, scaleshift_memory,
160+
dst);
161+
} else {
162+
auto mean_memory =
163+
mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(),
164+
cast_const_to_void(batch_mean->data<T>())};
165+
166+
auto variance_memory =
167+
mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(),
168+
cast_const_to_void(batch_variance->data<T>())};
169+
170+
run_batch_norm_op<bn_fwd_types::op_type>(batch_norm_fwd_pd, src,
171+
scaleshift_memory, dst,
172+
mean_memory, variance_memory);
173+
}
174+
175+
if (!is_test) {
176+
const unsigned int in = dims[0];
177+
const unsigned int sample_size = x->numel() / in / ic;
178+
179+
// saved_xx is use just in this batch of data
180+
EigenVectorArrayMap<T> saved_mean_e(
181+
batch_mean->mutable_data<T>(ctx.GetPlace()), ic);
182+
EigenVectorArrayMap<T> saved_variance_e(
183+
batch_variance->mutable_data<T>(ctx.GetPlace()), ic);
184+
saved_mean_e.setZero();
185+
saved_variance_e.setZero();
186+
187+
const unsigned int x_arr_size = in * ic;
188+
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, x_arr_size);
189+
for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
190+
saved_mean_e(nc % ic) += x_arr.col(nc).sum();
191+
}
192+
saved_mean_e /= in * sample_size;
193+
for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
194+
saved_variance_e(nc % ic) +=
195+
(x_arr.col(nc) - saved_mean_e(nc % ic)).matrix().squaredNorm();
196+
}
197+
saved_variance_e /= in * sample_size;
198+
199+
ConstEigenVectorArrayMap<T> mean_arr{mean->data<T>(), ic};
200+
ConstEigenVectorArrayMap<T> variance_arr{variance->data<T>(), ic};
201+
202+
EigenVectorArrayMap<T> running_mean_arr(
203+
mean_out->mutable_data<T>(ctx.GetPlace()), ic);
204+
EigenVectorArrayMap<T> running_var_arr(
205+
variance_out->mutable_data<T>(ctx.GetPlace()), ic);
206+
207+
auto one_minus_momentum = 1. - momentum;
208+
running_mean_arr =
209+
mean_arr * momentum + saved_mean_e * one_minus_momentum;
210+
running_var_arr =
211+
variance_arr * momentum + saved_variance_e * one_minus_momentum;
212+
}
213+
}
214+
};
215+
216+
template <typename T>
217+
class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
218+
public:
219+
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
220+
auto data_layout_str = ctx.Attr<std::string>("data_layout");
221+
auto data_layout = framework::StringToDataLayout(data_layout_str);
222+
PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
223+
"MKLDNN batch normalization handles only NCHW data layout");
224+
225+
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
226+
auto mkldnn_engine = dev_ctx.GetEngine();
227+
228+
const float epsilon = ctx.Attr<float>("epsilon");
229+
230+
const auto *x = ctx.Input<Tensor>("X");
231+
const auto *scale = ctx.Input<Tensor>("Scale");
232+
const auto *shift = ctx.Input<Tensor>("Bias");
233+
const auto *batch_mean = ctx.Input<Tensor>("SavedMean");
234+
const auto *batch_variance = ctx.Input<Tensor>("SavedVariance");
235+
236+
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
237+
auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
238+
auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
239+
auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
240+
241+
diff_x->mutable_data<T>(ctx.GetPlace());
242+
diff_scale->mutable_data<T>(ctx.GetPlace());
243+
diff_shift->mutable_data<T>(ctx.GetPlace());
244+
245+
auto dims = paddle::framework::vectorize2int(x->dims());
246+
unsigned flags = mkldnn::use_scale_shift | !mkldnn::use_global_stats;
247+
248+
auto src_md =
249+
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
250+
auto dst_md =
251+
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
252+
auto diff_src_md =
253+
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
254+
auto diff_dst_md =
255+
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
256+
257+
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
258+
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
259+
260+
auto batch_norm_fwd_desc = bn_fwd_types::op_desc{
261+
mkldnn::prop_kind::forward_training, src_md, epsilon, flags};
262+
auto batch_norm_fwd_pd =
263+
bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine};
264+
265+
auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
266+
mkldnn::prop_kind::backward, diff_dst_md, dst_md, epsilon, flags};
267+
auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
268+
batch_norm_bwd_desc, mkldnn_engine, batch_norm_fwd_pd};
269+
270+
auto src = mkldnn::memory{{src_md, mkldnn_engine},
271+
cast_const_to_void(x->data<T>())};
272+
273+
auto mean = mkldnn::memory{batch_norm_bwd_pd.mean_primitive_desc(),
274+
cast_const_to_void(batch_mean->data<T>())};
275+
276+
auto variance =
277+
mkldnn::memory{batch_norm_bwd_pd.variance_primitive_desc(),
278+
cast_const_to_void(batch_variance->data<T>())};
279+
280+
auto diff_dst = mkldnn::memory{{diff_dst_md, mkldnn_engine},
281+
cast_const_to_void(diff_y->data<T>())};
282+
283+
const unsigned int ic = dims[1];
284+
285+
const size_t scaleshift_size = 2 * ic;
286+
287+
std::vector<T> scaleshift_data;
288+
scaleshift_data.reserve(scaleshift_size);
289+
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
290+
shift->data<T>() + ic, &scaleshift_data);
291+
292+
auto scaleshift_memory = mkldnn::memory{
293+
batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data()};
294+
295+
std::vector<T> diff_scaleshift_data;
296+
diff_scaleshift_data.reserve(scaleshift_size);
297+
copy_to_weights(diff_scale->data<T>(), diff_scale->data<T>() + ic,
298+
diff_shift->data<T>(), diff_shift->data<T>() + ic,
299+
&diff_scaleshift_data);
300+
301+
auto diff_scaleshift_memory =
302+
mkldnn::memory{batch_norm_bwd_pd.diff_weights_primitive_desc(),
303+
diff_scaleshift_data.data()};
304+
305+
auto diff_src = mkldnn::memory{{diff_src_md, mkldnn_engine},
306+
static_cast<void *>(diff_x->data<T>())};
307+
308+
run_batch_norm_op<bn_bwd_types::op_type>(
309+
batch_norm_bwd_pd, src, mean, variance, diff_dst, scaleshift_memory,
310+
diff_src, diff_scaleshift_memory);
311+
312+
auto it = std::begin(diff_scaleshift_data);
313+
std::copy(it, std::next(it, ic), diff_scale->data<T>());
314+
std::copy(std::next(it, ic), std::end(diff_scaleshift_data),
315+
diff_shift->data<T>());
316+
}
317+
};
318+
} // namespace operators
319+
} // namespace paddle
320+
321+
namespace ops = paddle::operators;
322+
REGISTER_OP_KERNEL(batch_norm, MKLDNN, paddle::platform::CPUPlace,
323+
ops::BatchNormMKLDNNOpKernel<float>);
324+
REGISTER_OP_KERNEL(batch_norm_grad, MKLDNN, paddle::platform::CPUPlace,
325+
ops::BatchNormMKLDNNGradOpKernel<float>);

0 commit comments

Comments
 (0)