// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mkldnn.hpp"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/mkldnn_helper.h"

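// Note: sum_op.h also provides the LoDTensor/SelectedRows aliases and the
// EigenVector helper that this kernel relies on below.
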
namespace paddle {
namespace operators {

using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::CPUDeviceContext;
using framework::DataLayout;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::stream;
using mkldnn::sum;
using mkldnn::reorder;
using platform::to_void_cast;

template <typename T>
class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");
    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();
    auto in_vars = ctx.MultiInputVar("X");

    const int N = in_vars.size();
    auto out_var = ctx.OutputVar("Out");
    bool in_place = out_var == in_vars[0];

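    // Dispatch on the output variable type: dense LoDTensors go through the
    // MKL-DNN sum primitive; SelectedRows and LoDTensorArray fall back to
    // plain CPU code.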
    if (out_var->IsType<framework::LoDTensor>()) {
      LoDTensor* output = ctx.Output<LoDTensor>("Out");
      T* output_data = output->mutable_data<T>(ctx.GetPlace());

      std::vector<int> dst_tz = framework::vectorize2int(output->dims());
      auto src_tz = dst_tz;
      memory::format output_format{memory::format::format_undef};
      std::vector<float> scales;
      std::vector<memory::primitive_desc> srcs_mpd;
      std::vector<mkldnn::memory> srcs_mem;

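      // Every input must already carry an MKL-DNN layout; the first input's
      // format serves as the reference format for all source memories.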
      PADDLE_ENFORCE(in_vars[0]->IsType<LoDTensor>(),
                     "Input[0] must be a LoDTensor");
      auto& input0 = in_vars[0]->Get<LoDTensor>();
      PADDLE_ENFORCE(input0.layout() == DataLayout::kMKLDNN &&
                         input0.format() != memory::format::format_undef,
                     "Wrong layout/format for inputs[0]");

      memory::format input_format = input0.format();

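      // 1-D and 2-D tensors cannot keep a 4-D nchw/nhwc descriptor, so remap
      // them to the matching low-rank MKL-DNN formats (x and nc).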
      if (src_tz.size() == 1 && (input_format == memory::format::nchw ||
                                 input_format == memory::format::nhwc)) {
        input_format = memory::format::x;
      }
      if (src_tz.size() == 2 && (input_format == memory::format::nchw ||
                                 input_format == memory::format::nhwc)) {
        input_format = memory::format::nc;
      }

      for (int i = in_place ? 1 : 0; i < N; i++) {
        PADDLE_ENFORCE(in_vars[i]->IsType<LoDTensor>(),
                       "All inputs must be LoDTensors");
        auto& input = in_vars[i]->Get<LoDTensor>();
        PADDLE_ENFORCE(input.layout() == DataLayout::kMKLDNN &&
                           input.format() != memory::format::format_undef,
                       "Wrong layout/format for inputs");

        if (input.numel() == 0) {
          continue;
        }

        const T* input_data = input.data<T>();

        auto src_md =
            memory::desc(src_tz, memory::data_type::f32, input_format);
        auto src_mpd = memory::primitive_desc(src_md, mkldnn_engine);
        auto src_mem = memory(src_mpd, to_void_cast(input_data));
        srcs_mpd.push_back(src_mpd);
        srcs_mem.push_back(src_mem);
        scales.push_back(1.0);
      }

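      // Declare the destination with memory::format::any so the sum
      // primitive is free to choose the most efficient output layout.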
      auto dst_md =
          memory::desc(dst_tz, memory::data_type::f32, memory::format::any);

      auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);

      std::shared_ptr<memory> dst_mem;
      if (in_place) {
        dst_mem.reset(new memory(sum_pd.dst_primitive_desc()));
      } else {
        dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data));
      }
      std::vector<mkldnn::primitive::at> inputs;
      for (size_t i = 0; i < srcs_mem.size(); ++i) {
        inputs.push_back(srcs_mem[i]);
      }

      auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem);
      output_format = (memory::format)platform::GetMKLDNNFormat(sum_pd);

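      // In-place path: the sum primitive wrote into an internal scratch
      // buffer, so reorder the result back into output_data using the
      // original input format.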
      primitive reorder_prim;
      std::shared_ptr<memory> target_mem;
      if (in_place) {
        output_format = input_format;
        target_mem.reset(new memory(
            {{{src_tz}, memory::data_type::f32, output_format}, mkldnn_engine},
            output_data));
        reorder_prim = reorder(*dst_mem, *target_mem);
      }

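      // Execute the sum (plus the in-place reorder, if any) on an eager
      // stream and block until it finishes.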
      std::vector<primitive> pipeline;
      pipeline.push_back(sum_prim);
      if (in_place) pipeline.push_back(reorder_prim);
      stream(stream::kind::eager).submit(pipeline).wait();

      output->set_layout(DataLayout::kMKLDNN);
      output->set_format(output_format);
    } else if (out_var->IsType<framework::SelectedRows>()) {
      // TODO(@mozga-intel) Add MKLDNN SelectedRows support
      std::unique_ptr<framework::SelectedRows> in0;
      if (in_place) {
        // If in_place, save input[0] into in0, since Out aliases it and is
        // overwritten below.
        auto& in_sel0 = in_vars[0]->Get<SelectedRows>();
        auto& rows = in_sel0.rows();
        in0.reset(new framework::SelectedRows(rows, in_sel0.height()));
        in0->mutable_value()->ShareDataWith(in_sel0.value());
      }

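      // get_selected_row(i) returns the saved copy for input 0 when summing
      // in place, and the original input variable otherwise.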
      auto get_selected_row = [&](size_t i) -> const SelectedRows& {
        if (i == 0 && in0) {
          return *in0.get();
        } else {
          return in_vars[i]->Get<SelectedRows>();
        }
      };
      auto* out = ctx.Output<SelectedRows>("Out");
      out->mutable_rows()->clear();
      auto* out_value = out->mutable_value();

      // Runtime InferShape: the output row count is the sum of all input
      // row counts.
      size_t first_dim = 0;
      for (int i = 0; i < N; i++) {
        auto& sel_row = get_selected_row(i);
        first_dim += sel_row.rows().size();
      }
      auto in_dim =
          framework::vectorize(get_selected_row(N - 1).value().dims());
      in_dim[0] = static_cast<int64_t>(first_dim);

      out_value->Resize(framework::make_ddim(in_dim));

      // If all the input sparse vars are empty, there is nothing to merge.
      if (first_dim == 0UL) {
        return;
      }
      out_value->mutable_data<T>(ctx.GetPlace());
      math::SelectedRowsAddTo<CPUDeviceContext, T> functor;
      int64_t offset = 0;
      for (int i = 0; i < N; i++) {
        auto& sel_row = get_selected_row(i);
        if (sel_row.rows().size() == 0) {
          continue;
        }
        PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
        functor(ctx.template device_context<CPUDeviceContext>(), sel_row,
                offset, out);
        offset += sel_row.value().numel();
      }
    } else if (out_var->IsType<framework::LoDTensorArray>()) {
      // TODO(@mozga-intel) Add MKLDNN LoDTensorArray support
      auto& out_array = *out_var->GetMutable<framework::LoDTensorArray>();
      for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
        PADDLE_ENFORCE(in_vars[i]->IsType<framework::LoDTensorArray>(),
                       "All inputs must be LoDTensorArrays");
        auto& in_array = in_vars[i]->Get<framework::LoDTensorArray>();

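        // For each array slot: copy the first non-empty tensor, then
        // accumulate the rest element-wise via Eigen.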
        for (size_t j = 0; j < in_array.size(); ++j) {
          if (in_array[j].numel() != 0) {
            if (j >= out_array.size()) {
              out_array.resize(j + 1);
            }
            if (out_array[j].numel() == 0) {
              framework::TensorCopy(in_array[j], in_array[j].place(),
                                    ctx.device_context(), &out_array[j]);
              out_array[j].set_lod(in_array[j].lod());
            } else {
              PADDLE_ENFORCE(out_array[j].lod() == in_array[j].lod());
              auto in = EigenVector<T>::Flatten(in_array[j]);
              auto result = EigenVector<T>::Flatten(out_array[j]);
              result.device(*ctx.template device_context<MKLDNNDeviceContext>()
                                 .eigen_device()) = result + in;
            }
          }
        }
      }
    } else {
      PADDLE_THROW("Unexpected branch, output variable type is %s",
                   out_var->Type().name());
    }
  }
};

}  // namespace operators
}  // namespace paddle

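// Register the MKL-DNN float kernel of the sum op for CPUPlace.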
REGISTER_OP_KERNEL(sum, MKLDNN, ::paddle::platform::CPUPlace,
                   paddle::operators::SumMKLDNNOpKernel<float>);