|
| 1 | +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
#include <algorithm>
#include <iterator>
#include <utility>
#include <vector>

#include "mkldnn.hpp"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
| 18 | + |
| 19 | +namespace paddle { |
| 20 | +namespace operators { |
| 21 | + |
using Tensor = framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNMemDesc;
using mkldnn::memory;

// 2-D Eigen array views over raw tensor buffers; used for the CPU-side
// statistics computation in the forward kernel's training path.
template <typename T>
using EigenArrayMap =
    Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using ConstEigenArrayMap =
    Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
// 1-D (column-vector) Eigen array views over raw tensor buffers.
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T>
using ConstEigenVectorArrayMap =
    Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
| 38 | + |
| 39 | +namespace { |
// Bundles the nested descriptor/primitive-descriptor types of an MKLDNN
// primitive T (e.g. mkldnn::batch_normalization_forward) under short names.
template <typename T>
struct bn_type_traits {
  using op_type = T;                              // the primitive itself
  using op_desc = typename op_type::desc;         // operation descriptor
  using op_prim = typename op_type::primitive_desc;  // primitive descriptor
};
| 46 | + |
// Concatenates the scale values followed by the shift/bias values into a
// single container, because MKLDNN expects both in one contiguous
// "weights" buffer of size 2 * C.
//
// T is an iterator/pointer type into the source data; Container is any
// sequence container supporting range insert (std::vector<T> at both call
// sites, pre-reserved and empty).
//
// Fix: the previous implementation saved `it = std::begin(*c)` and reused
// it (via std::next) after the first insertion; inserting into a vector
// invalidates iterators at or after the insertion point, so that reuse was
// undefined behavior. Appending both ranges at the end avoids the problem
// and produces the same [scale..., shift...] layout.
template <typename T, typename Container>
void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end,
                     Container *c) {
  c->insert(std::end(*c), scale_begin, scale_end);
  c->insert(std::end(*c), shift_begin, shift_end);
}
| 57 | + |
| 58 | +template <typename Op, typename... Args> |
| 59 | +void run_batch_norm_op(Args &&... args) { |
| 60 | + Op batch_norm_op{args...}; |
| 61 | + |
| 62 | + std::vector<mkldnn::primitive> pipeline; |
| 63 | + pipeline.push_back(batch_norm_op); |
| 64 | + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); |
| 65 | +} |
| 66 | + |
// Strips constness from a typed pointer and lets it decay to void*, as
// required by MKLDNN memory constructors that take non-const buffers.
// Callers must ensure the underlying object is not actually modified when
// the buffer is used as a primitive input.
template <typename T>
inline void *cast_const_to_void(const T *t) {
  return const_cast<T *>(t);
}
| 71 | +} // namespace |
| 72 | + |
// Forward batch normalization on CPU through MKLDNN.
//
// Inputs:  X, Scale, Bias, Mean, Variance (the latter two consumed only in
//          inference mode).
// Outputs: Y; in training mode additionally SavedMean/SavedVariance
//          (per-batch statistics) and MeanOut/VarianceOut (running
//          statistics updated with momentum).
template <typename T>
class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto data_layout_str = ctx.Attr<std::string>("data_layout");
    auto data_layout = framework::StringToDataLayout(data_layout_str);
    // Only NCHW is wired up below (memory::format::nchw descriptors).
    PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
                   "MKLDNN batch normalization handles only NCHW data layout");

    const float epsilon = ctx.Attr<float>("epsilon");
    const float momentum = ctx.Attr<float>("momentum");
    const bool is_test = ctx.Attr<bool>("is_test");

    const auto *x = ctx.Input<Tensor>("X");
    const auto *mean = ctx.Input<Tensor>("Mean");
    const auto *variance = ctx.Input<Tensor>("Variance");

    auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    auto mkldnn_engine = dev_ctx.GetEngine();

    auto *y = ctx.Output<Tensor>("Y");
    auto *mean_out = ctx.Output<Tensor>("MeanOut");
    auto *variance_out = ctx.Output<Tensor>("VarianceOut");
    auto *batch_mean = ctx.Output<Tensor>("SavedMean");
    auto *batch_variance = ctx.Output<Tensor>("SavedVariance");

    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *shift = ctx.Input<Tensor>("Bias");

    // Allocate outputs before handing their buffers to MKLDNN.
    y->mutable_data<T>(ctx.GetPlace());
    mean_out->mutable_data<T>(ctx.GetPlace());
    variance_out->mutable_data<T>(ctx.GetPlace());

    // Saved (per-batch) statistics are only produced while training.
    if (!is_test) {
      batch_mean->mutable_data<T>(ctx.GetPlace());
      batch_variance->mutable_data<T>(ctx.GetPlace());
    }

    // Inference uses the lighter "scoring" propagation kind.
    auto propagation = is_test == true ? mkldnn::prop_kind::forward_scoring
                                       : mkldnn::prop_kind::forward_training;

    auto dims = paddle::framework::vectorize2int(x->dims());

    auto src_md =
        MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
    auto dst_md =
        MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);

    auto src_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
    auto dst_pd = mkldnn::memory::primitive_desc{dst_md, mkldnn_engine};

    // MKLDNN memories wrap the existing tensor buffers (no copies).
    auto src = mkldnn::memory{src_pd, cast_const_to_void(x->data<T>())};
    auto dst = mkldnn::memory{dst_pd, y->data<T>()};

    // Scale/shift are always applied; in inference mode the externally
    // supplied (global) statistics are used instead of batch statistics.
    unsigned flags = mkldnn::use_scale_shift;
    if (is_test) flags |= mkldnn::use_global_stats;

    using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
    auto batch_norm_fwd_desc =
        bn_fwd_types::op_desc{propagation, src_md, epsilon, flags};
    auto batch_norm_fwd_pd =
        bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine};

    const unsigned int ic = dims[1];

    // MKLDNN requires a single piece of memory for scale and shift/bias data
    const size_t scaleshift_size = 2 * ic;
    std::vector<T> scaleshift_data;
    scaleshift_data.reserve(scaleshift_size);

    // Pack [scale..., shift...] into one contiguous weights buffer.
    copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
                    shift->data<T>() + ic, &scaleshift_data);

    auto scaleshift_memory = mkldnn::memory{
        batch_norm_fwd_pd.weights_primitive_desc(), scaleshift_data.data()};

    if (is_test) {
      // Inference: Mean/Variance are inputs; the primitive::at casts mark
      // them as such for the primitive's constructor.
      auto mean_memory = mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(),
                                        cast_const_to_void(mean->data<T>())};

      auto variance_memory =
          mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(),
                         cast_const_to_void(variance->data<T>())};

      run_batch_norm_op<typename bn_fwd_types::op_type>(
          batch_norm_fwd_pd, src, (const mkldnn::primitive::at &)mean_memory,
          (const mkldnn::primitive::at &)variance_memory, scaleshift_memory,
          dst);
    } else {
      // Training: the primitive computes batch statistics and writes them
      // into the SavedMean/SavedVariance buffers (outputs here).
      auto mean_memory =
          mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(),
                         cast_const_to_void(batch_mean->data<T>())};

      auto variance_memory =
          mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(),
                         cast_const_to_void(batch_variance->data<T>())};

      run_batch_norm_op<bn_fwd_types::op_type>(batch_norm_fwd_pd, src,
                                               scaleshift_memory, dst,
                                               mean_memory, variance_memory);
    }

    if (!is_test) {
      // CPU (Eigen) pass over X computing per-channel mean and biased
      // variance, then blending them into the running statistics.
      // NOTE(review): this overwrites the SavedMean/SavedVariance buffers
      // the MKLDNN primitive just filled above — confirm whether the
      // duplicate computation is intentional.
      const unsigned int in = dims[0];
      const unsigned int sample_size = x->numel() / in / ic;

      // saved_xx is use just in this batch of data
      EigenVectorArrayMap<T> saved_mean_e(
          batch_mean->mutable_data<T>(ctx.GetPlace()), ic);
      EigenVectorArrayMap<T> saved_variance_e(
          batch_variance->mutable_data<T>(ctx.GetPlace()), ic);
      saved_mean_e.setZero();
      saved_variance_e.setZero();

      // View X as (sample_size x N*C): column nc holds one channel plane of
      // one sample, so channel index is nc % ic.
      const unsigned int x_arr_size = in * ic;
      ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, x_arr_size);
      for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
        saved_mean_e(nc % ic) += x_arr.col(nc).sum();
      }
      saved_mean_e /= in * sample_size;
      for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
        saved_variance_e(nc % ic) +=
            (x_arr.col(nc) - saved_mean_e(nc % ic)).matrix().squaredNorm();
      }
      // Biased variance: divides by N (= in * sample_size), not N - 1.
      saved_variance_e /= in * sample_size;

      ConstEigenVectorArrayMap<T> mean_arr{mean->data<T>(), ic};
      ConstEigenVectorArrayMap<T> variance_arr{variance->data<T>(), ic};

      EigenVectorArrayMap<T> running_mean_arr(
          mean_out->mutable_data<T>(ctx.GetPlace()), ic);
      EigenVectorArrayMap<T> running_var_arr(
          variance_out->mutable_data<T>(ctx.GetPlace()), ic);

      // running = momentum * running + (1 - momentum) * batch
      auto one_minus_momentum = 1. - momentum;
      running_mean_arr =
          mean_arr * momentum + saved_mean_e * one_minus_momentum;
      running_var_arr =
          variance_arr * momentum + saved_variance_e * one_minus_momentum;
    }
  }
};
| 215 | + |
// Backward batch normalization on CPU through MKLDNN.
//
// Inputs:  X, Scale, Bias, SavedMean/SavedVariance (batch statistics from
//          the forward pass), and dY.
// Outputs: dX, dScale, dBias.
template <typename T>
class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
    auto data_layout_str = ctx.Attr<std::string>("data_layout");
    auto data_layout = framework::StringToDataLayout(data_layout_str);
    // Only NCHW is wired up below (memory::format::nchw descriptors).
    PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
                   "MKLDNN batch normalization handles only NCHW data layout");

    auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    auto mkldnn_engine = dev_ctx.GetEngine();

    const float epsilon = ctx.Attr<float>("epsilon");

    const auto *x = ctx.Input<Tensor>("X");
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *shift = ctx.Input<Tensor>("Bias");
    const auto *batch_mean = ctx.Input<Tensor>("SavedMean");
    const auto *batch_variance = ctx.Input<Tensor>("SavedVariance");

    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    // Allocate gradient outputs before handing buffers to MKLDNN.
    diff_x->mutable_data<T>(ctx.GetPlace());
    diff_scale->mutable_data<T>(ctx.GetPlace());
    diff_shift->mutable_data<T>(ctx.GetPlace());

    auto dims = paddle::framework::vectorize2int(x->dims());
    // NOTE(review): `!mkldnn::use_global_stats` is a logical-not and
    // evaluates to 0, so this is equivalent to plain use_scale_shift —
    // presumably the intent was simply "do not set use_global_stats";
    // confirm before changing.
    unsigned flags = mkldnn::use_scale_shift | !mkldnn::use_global_stats;

    auto src_md =
        MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
    auto dst_md =
        MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
    auto diff_src_md =
        MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
    auto diff_dst_md =
        MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);

    using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
    using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;

    // The backward primitive descriptor requires a matching forward
    // primitive descriptor as a hint.
    auto batch_norm_fwd_desc = bn_fwd_types::op_desc{
        mkldnn::prop_kind::forward_training, src_md, epsilon, flags};
    auto batch_norm_fwd_pd =
        bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine};

    auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
        mkldnn::prop_kind::backward, diff_dst_md, dst_md, epsilon, flags};
    auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
        batch_norm_bwd_desc, mkldnn_engine, batch_norm_fwd_pd};

    // Wrap existing tensor buffers as MKLDNN memories (no copies).
    auto src = mkldnn::memory{{src_md, mkldnn_engine},
                              cast_const_to_void(x->data<T>())};

    auto mean = mkldnn::memory{batch_norm_bwd_pd.mean_primitive_desc(),
                               cast_const_to_void(batch_mean->data<T>())};

    auto variance =
        mkldnn::memory{batch_norm_bwd_pd.variance_primitive_desc(),
                       cast_const_to_void(batch_variance->data<T>())};

    auto diff_dst = mkldnn::memory{{diff_dst_md, mkldnn_engine},
                                   cast_const_to_void(diff_y->data<T>())};

    const unsigned int ic = dims[1];

    const size_t scaleshift_size = 2 * ic;

    // Pack [scale..., shift...] into one contiguous weights buffer, as
    // required by MKLDNN.
    std::vector<T> scaleshift_data;
    scaleshift_data.reserve(scaleshift_size);
    copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
                    shift->data<T>() + ic, &scaleshift_data);

    auto scaleshift_memory = mkldnn::memory{
        batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data()};

    // Staging buffer for the diff-weights output; its initial contents are
    // irrelevant (it is written by the primitive), copy_to_weights is only
    // used here to size it to 2 * ic elements.
    std::vector<T> diff_scaleshift_data;
    diff_scaleshift_data.reserve(scaleshift_size);
    copy_to_weights(diff_scale->data<T>(), diff_scale->data<T>() + ic,
                    diff_shift->data<T>(), diff_shift->data<T>() + ic,
                    &diff_scaleshift_data);

    auto diff_scaleshift_memory =
        mkldnn::memory{batch_norm_bwd_pd.diff_weights_primitive_desc(),
                       diff_scaleshift_data.data()};

    auto diff_src = mkldnn::memory{{diff_src_md, mkldnn_engine},
                                   static_cast<void *>(diff_x->data<T>())};

    // Inputs: src, mean, variance, diff_dst, weights;
    // outputs: diff_src (dX) and diff_weights (packed dScale/dBias).
    run_batch_norm_op<bn_bwd_types::op_type>(
        batch_norm_bwd_pd, src, mean, variance, diff_dst, scaleshift_memory,
        diff_src, diff_scaleshift_memory);

    // Unpack the combined diff-weights buffer back into the separate
    // dScale and dBias output tensors.
    auto it = std::begin(diff_scaleshift_data);
    std::copy(it, std::next(it, ic), diff_scale->data<T>());
    std::copy(std::next(it, ic), std::end(diff_scaleshift_data),
              diff_shift->data<T>());
  }
};
| 318 | +} // namespace operators |
| 319 | +} // namespace paddle |
| 320 | + |
namespace ops = paddle::operators;
// Register the MKLDNN CPU kernels for batch_norm and its gradient.
// Only float is registered, matching the f32 memory descriptors used by
// the kernels in this file.
REGISTER_OP_KERNEL(batch_norm, MKLDNN, paddle::platform::CPUPlace,
                   ops::BatchNormMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(batch_norm_grad, MKLDNN, paddle::platform::CPUPlace,
                   ops::BatchNormMKLDNNGradOpKernel<float>);
0 commit comments