
Commit 2632327

jczaja authored and tensor-tang committed
[MKL-DNN] Tensor modifications revert (#16462)
* Revert "[MKL-DNN] Fix to crash of Transformer when mkldnn is to be used (#16233)" This reverts commit 13816dd. Apart from enabling transformer for MKL-DNN * Revert "- MKL-DNN pooling updated to set_prim_desc" This reverts commit c63f6b2. Conflicts: paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc * Revert "[MKL-DNN] MKL-DNN specific Tensor modification (#15429)" test=develop This reverts commit dec9cf5. * - concat compilation fix - lint test=develop - Lint fixes test=develop - Lint fixes test=develop - Fix Transpose MKLDNN op test=develop
1 parent 4143a1c commit 2632327

17 files changed (+172, −274 lines)

paddle/fluid/framework/data_layout_transform.cc

Lines changed: 16 additions & 7 deletions
```diff
@@ -134,6 +134,11 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
   out_layout =
       out_layout == DataLayout::kAnyLayout ? DataLayout::kNCHW : out_layout;
 
+  auto& pool = platform::DeviceContextPool::Instance();
+  auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(
+      pool.Get(expected_kernel_type.place_));
+  auto& cpu_engine = dev_ctx->GetEngine();
+
   std::vector<int> in_tz = paddle::framework::vectorize2int(in.dims());
   std::vector<int> out_tz = in_tz;
 
@@ -142,25 +147,29 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
                  "Input tensor type is not supported: %s", in.type());
   memory::data_type out_type = in_type;
 
+  auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
+  auto out_format =
+      platform::MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));
+
   // output tensor has the same dims as input. Reorder don't change dims
   out->Resize(in.dims());
 
-  // tempory mem pd fr out , to make reorder
-  auto out_mem_pd = paddle::platform::create_prim_desc_from_dims(
-      paddle::framework::vectorize2int(out->dims()),
-      mkldnn::memory::format::blocked, out_type);
-  if (in.get_mkldnn_prim_desc() != out_mem_pd) {
+  if (in_format != out_format) {
     void* in_data = GetDataFromTensor(in, in_type);
     auto out_data = out->mutable_data(expected_kernel_type.place_, in.type());
 
-    auto in_memory = memory(in.get_mkldnn_prim_desc(), in_data);
-    auto out_memory = memory(out_mem_pd, out_data);
+    auto in_memory =
+        memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
+    auto out_memory =
+        memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data);
 
     platform::Reorder(in_memory, out_memory);
   } else {
    out->ShareDataWith(in);
   }
   out->set_layout(out_layout);
+  // reset format since the out tensor will be feed to non-MKLDNN OPkernel
+  out->set_format(memory::format::format_undef);
 #endif
 }
```
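For context, this hunk reverts to building MKL-DNN memories from explicit dims/type/format triples bound to an engine, then reordering between them. Below is a minimal standalone sketch of that pattern against the MKL-DNN v0.x C++ API used in the hunk; the shape, formats, and buffers are illustrative assumptions, not Paddle code, and `platform::Reorder` presumably wraps the same reorder/submit/wait sequence.

```cpp
#include <mkldnn.hpp>
#include <vector>

int main() {
  using namespace mkldnn;
  engine cpu_engine(engine::kind::cpu, 0);

  // Same logical 1x3x4x4 tensor described under two layouts.
  memory::dims tz = {1, 3, 4, 4};
  std::vector<float> src_buf(1 * 3 * 4 * 4, 1.f);
  std::vector<float> dst_buf(src_buf.size());

  auto src_memory = memory(
      {{{tz}, memory::data_type::f32, memory::format::nchw}, cpu_engine},
      src_buf.data());
  auto dst_memory = memory(
      {{{tz}, memory::data_type::f32, memory::format::nhwc}, cpu_engine},
      dst_buf.data());

  // A reorder rewrites the elements into the destination layout; dims stay
  // the same, which is why out->Resize(in.dims()) suffices in the hunk.
  std::vector<primitive> pipeline;
  pipeline.push_back(reorder(src_memory, dst_memory));
  stream(stream::kind::eager).submit(pipeline).wait();
  return 0;
}
```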

paddle/fluid/framework/data_transform.cc

Lines changed: 6 additions & 24 deletions
```diff
@@ -51,31 +51,13 @@ void TransformData(const OpKernelType &expected_kernel_type,
 #ifdef PADDLE_WITH_MKLDNN
     // Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
     // Just set layout/format. No real transform occur
+
+    auto out_format = platform::MKLDNNFormatForSize(in.dims().size(),
+                                                    ToMKLDNNFormat(lin));
+
     out.ShareDataWith(input_tensor);
-    // TODO(jczaja): Remove that once all mkldnn ops
-    // are modified to work with mkldnn_blocked
-    auto mkldnn_fmt = [&](int rank) {
-      switch (rank) {
-        case 5:
-          return mkldnn::memory::format::ncdhw;
-        case 4:
-          return mkldnn::memory::format::nchw;
-        case 3:
-          return mkldnn::memory::format::ncw;
-        case 2:
-          return mkldnn::memory::format::nc;
-        case 1:
-          return mkldnn::memory::format::x;
-        default:
-          return mkldnn::memory::format::blocked;
-      }
-    };
-
-    auto out_mem_pd = paddle::platform::create_prim_desc_from_dims(
-        paddle::framework::vectorize2int(out.dims()),
-        mkldnn_fmt(out.dims().size()));
-
-    out.set_mkldnn_prim_desc(out_mem_pd);
+    out.set_layout(DataLayout::kMKLDNN);
+    out.set_format(out_format);
 #endif
   } else {
     // Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel
```
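The deleted lambda spelled out the rank-to-format mapping that `platform::MKLDNNFormatForSize` now provides. For reference, here is that mapping reconstructed as a free function from the removed code; note the real helper also takes the incoming format as a fallback argument, so treat this as a sketch rather than the Paddle API.

```cpp
#include <mkldnn.hpp>

// Rank-to-plain-format mapping, reconstructed from the lambda removed above.
mkldnn::memory::format format_for_rank(int rank) {
  switch (rank) {
    case 5: return mkldnn::memory::format::ncdhw;
    case 4: return mkldnn::memory::format::nchw;
    case 3: return mkldnn::memory::format::ncw;
    case 2: return mkldnn::memory::format::nc;
    case 1: return mkldnn::memory::format::x;
    default: return mkldnn::memory::format::blocked;
  }
}
```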

paddle/fluid/framework/tensor.h

Lines changed: 9 additions & 33 deletions
```diff
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <cstring>
 #include <memory>
 #include <typeindex>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/ddim.h"
@@ -27,10 +28,6 @@ limitations under the License. */
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/place.h"
 
-#ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_utils.h"
-#endif
-
 namespace paddle {
 
 namespace framework {
@@ -41,44 +38,23 @@ class Tensor {
 #ifdef PADDLE_WITH_MKLDNN
 
  public:
-  // TODO(jczaja): This is depracted and will be removed
-  inline mkldnn::memory::format format() const {
-    if (layout_ == DataLayout::kMKLDNN) {
-      return static_cast<mkldnn::memory::format>(mem_pd_.desc().data.format);
-    } else {
-      return mkldnn::memory::format::format_undef;
-    }
-  }
+  inline mkldnn::memory::format format() const { return format_; }
 
-  // TODO(jczaja): This is depracted and will be removed
-  inline void set_format(
-      const mkldnn::memory::format fmt,
-      mkldnn::memory::data_type data_type = mkldnn::memory::f32) {
-    mem_pd_ = paddle::platform::create_prim_desc_from_format(
-        paddle::framework::vectorize2int(dims()), fmt, data_type);
-    layout_ = DataLayout::kMKLDNN;
-  }
-
-  inline mkldnn::memory::primitive_desc get_mkldnn_prim_desc() const {
-    return mem_pd_;
-  }
-
-  inline void set_mkldnn_prim_desc(
-      const mkldnn::memory::primitive_desc& mem_pd) {
-    // Internally MKL-DNN is just copying (increasing reference counter)
-    // to shared_ptr. So asignment should be quite cheap
-    mem_pd_ = mem_pd;
-    layout_ = DataLayout::kMKLDNN;
+  inline void set_format(const mkldnn::memory::format format) {
+    format_ = format;
   }
 
  protected:
   /**
    * @brief the detail format of memory block which have layout as kMKLDNN
    *
    * @note MKLDNN lib support various memory format like nchw, nhwc, nChw8C,
-   *       nChw16c, etc. For a MKLDNN memory block, we store memory descriptor
+   *       nChw16c, etc. For a MKLDNN memory block, layout will be set as
+   *       DataLayout::kMKLDNN meanwhile detail memory format will be kept in
+   *       this field.
    */
-  mutable mkldnn::memory::primitive_desc mem_pd_;
+
+  mkldnn::memory::format format_ = mkldnn::memory::format::format_undef;
 #endif
 
  public:
```
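After this revert a tensor's MKL-DNN state is two plain fields again: `layout_` flags the data as MKL-DNN and `format_` names the concrete memory format. A stripped-down stand-in (not the real `paddle::framework::Tensor`) showing how producers tag and consumers read that state:

```cpp
#include <mkldnn.hpp>

enum class DataLayout { kNCHW, kMKLDNN };

// Stand-in for the MKL-DNN-related slice of paddle::framework::Tensor.
struct MiniTensor {
  DataLayout layout_ = DataLayout::kNCHW;
  mkldnn::memory::format format_ = mkldnn::memory::format::format_undef;

  void set_layout(DataLayout layout) { layout_ = layout; }
  void set_format(mkldnn::memory::format format) { format_ = format; }
  DataLayout layout() const { return layout_; }
  mkldnn::memory::format format() const { return format_; }
};

// Producer side: after an MKL-DNN primitive writes t, tag both fields,
// mirroring the z->set_layout(...) / z->set_format(...) pairs in this commit.
void tag_as_mkldnn(MiniTensor& t, mkldnn::memory::format fmt) {
  t.set_layout(DataLayout::kMKLDNN);
  t.set_format(fmt);
}
```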

paddle/fluid/framework/tensor_util.cc

Lines changed: 0 additions & 5 deletions
```diff
@@ -44,11 +44,6 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
             << dst_place;
     return;
   }
-#ifdef PADDLE_WITH_MKLDNN
-  if (src.layout() == DataLayout::kMKLDNN) {
-    dst->set_mkldnn_prim_desc(src.get_mkldnn_prim_desc());
-  }
-#endif
   memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
                boost::get<platform::CPUPlace>(src_place), src_ptr, size);
 }
```

paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc

Lines changed: 13 additions & 6 deletions
```diff
@@ -77,7 +77,8 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
       } else {
         functor.RunMidWise(n, pre, post);
       }
-      z->set_mkldnn_prim_desc(x->get_mkldnn_prim_desc());
+      z->set_layout(DataLayout::kMKLDNN);
+      z->set_format(x->format());
     } else {
       PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
                          x->format() != memory::format::format_undef,
@@ -115,8 +116,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
       auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd);
 
       // create mkldnn memory for dst
-      auto dst_mem_pd = sum_pd.dst_primitive_desc();
-      memory dst_memory = memory(dst_mem_pd, z_data);
+      memory dst_memory = memory(sum_pd.dst_primitive_desc(), z_data);
 
       std::vector<primitive::at> inputs;
       inputs.push_back(srcs[0]);
@@ -129,7 +129,9 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
       pipeline.push_back(sum_prim);
       stream(stream::kind::eager).submit(pipeline).wait();
 
-      z->set_mkldnn_prim_desc(dst_mem_pd);
+      z->set_layout(DataLayout::kMKLDNN);
+      z->set_format(
+          (memory::format)dst_memory.get_primitive_desc().desc().data.format);
     }
   }
 };
@@ -150,19 +152,24 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
     auto* out = dout;
     auto *x = dout, *y = dout;
 
+    auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
+      in->set_layout(DataLayout::kMKLDNN);
+      in->set_format(out->format());
+    };
+
     if (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()) {
       if (dx->dims() == dy->dims()) {
         auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
         if (dx) {
           blas.VCOPY(dout->numel(), dout->data<T>(),
                      dx->mutable_data<T>(ctx.GetPlace()));
-          dx->set_mkldnn_prim_desc(dout->get_mkldnn_prim_desc());
+          set_mkldnn_format(dx, dout);
         }
 
         if (dy) {
           blas.VCOPY(dout->numel(), dout->data<T>(),
                      dy->mutable_data<T>(ctx.GetPlace()));
-          dy->set_mkldnn_prim_desc(dout->get_mkldnn_prim_desc());
+          set_mkldnn_format(dy, dout);
         }
       }
     } else {
```
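The sum path above keeps the MKL-DNN v0.x flow: build a `sum::primitive_desc` from a destination desc, per-input scales, and the input primitive descs, then bind the output buffer to `dst_primitive_desc()`. A self-contained sketch of that flow; shapes and values are illustrative assumptions, not the Paddle kernel:

```cpp
#include <mkldnn.hpp>
#include <vector>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);

  memory::dims tz = {2, 3};
  std::vector<float> a(6, 1.f), b(6, 2.f), z(6, 0.f);

  auto md = memory::desc(tz, memory::data_type::f32, memory::format::nc);
  auto a_memory = memory({md, eng}, a.data());
  auto b_memory = memory({md, eng}, b.data());

  std::vector<memory::primitive_desc> srcs_pd = {
      a_memory.get_primitive_desc(), b_memory.get_primitive_desc()};
  std::vector<float> scales = {1.0f, 1.0f};
  auto sum_pd = sum::primitive_desc(md, scales, srcs_pd);

  // Bind the output buffer to the primitive's chosen destination desc,
  // as the kernel above does with z_data.
  memory dst_memory = memory(sum_pd.dst_primitive_desc(), z.data());

  std::vector<primitive::at> inputs = {a_memory, b_memory};
  std::vector<primitive> pipeline = {sum(sum_pd, inputs, dst_memory)};
  stream(stream::kind::eager).submit(pipeline).wait();  // z[i] = a[i] + b[i]
  return 0;
}
```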

paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc

Lines changed: 17 additions & 7 deletions
```diff
@@ -96,7 +96,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
 
   std::vector<int> src_tz = framework::vectorize2int(x->dims());
 
-  auto src_format = x->format();
+  auto src_format =
+      src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();
 
   const std::string key = gethash(src_tz, algorithm);
   const std::string key_src_data =
@@ -126,8 +127,10 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
 
   if (p_fwd == nullptr) {
     // create mkldnn memory for input X
+    auto src_md = platform::MKLDNNMemDesc(
+        src_tz, platform::MKLDNNGetDataType<T>(), src_format);
     auto src_memory = std::shared_ptr<memory>(
-        new memory(x->get_mkldnn_prim_desc(), to_void_cast(x_data)));
+        new memory({src_md, mkldnn_engine}, to_void_cast(x_data)));
     // save src_memory to be referred in backward path
     dev_ctx.SetBlob(key_src_mem, src_memory);
 
@@ -174,7 +177,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   pipeline.push_back(*p_fwd);
   stream(stream::kind::eager).submit(pipeline).wait();
 
-  y->set_mkldnn_prim_desc(dst_memory->get_primitive_desc());
+  y->set_layout(DataLayout::kMKLDNN);
+  y->set_format(GetMKLDNNFormat(*dst_memory));
 }
 
 template <typename T>
@@ -192,6 +196,9 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
 
   std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());
 
+  auto diff_y_format =
+      diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();
+
   const std::string key = gethash(diff_dst_tz, algorithm);
   const std::string key_src_data =
       key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
@@ -203,8 +210,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
       key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem";
   const std::string key_fwd_pd =
       key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd";
-  const std::string key_with_layouts = key + std::to_string(*p_src_layout) +
-                                       "-" + std::to_string(diff_y->format());
+  const std::string key_with_layouts =
+      key + std::to_string(*p_src_layout) + "-" + std::to_string(diff_y_format);
   const std::string key_diff_src_mem =
       key_with_layouts + "@eltwise_diff_src_mem";
   const std::string key_diff_dst_mem =
@@ -227,8 +234,10 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
 
   if (p_grad == nullptr) {
     // create mkldnn memory for input diff_y
+    auto diff_dst_md = platform::MKLDNNMemDesc(
+        diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
     auto diff_dst_memory = std::shared_ptr<memory>(
-        new memory(diff_y->get_mkldnn_prim_desc(), to_void_cast(diff_y_data)));
+        new memory({diff_dst_md, mkldnn_engine}, to_void_cast(diff_y_data)));
     dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory);
 
     // retrieve eltwise primitive desc from device context
@@ -272,7 +281,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   pipeline.push_back(*p_grad);
   stream(stream::kind::eager).submit(pipeline).wait();
 
-  diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
+  diff_x->set_layout(DataLayout::kMKLDNN);
+  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory));
 }
 
 template <typename T, mkldnn::algorithm algorithm>
```
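Both the forward and grad paths above add the same guard: a rank-2 tensor is always described as plain `nc` before an MKL-DNN memory descriptor is built from it, regardless of the format tag it carries. A sketch of that guard in isolation; the `f32` data type and the helper name are assumptions for illustration:

```cpp
#include <mkldnn.hpp>
#include <vector>

// Rank-2 fallback used by the activation hunks: 2-D data is described as
// `nc` whatever format the tensor reports; higher ranks keep their format.
mkldnn::memory::desc make_src_desc(const std::vector<int>& tz,
                                   mkldnn::memory::format stored_format) {
  auto fmt = tz.size() == 2 ? mkldnn::memory::format::nc : stored_format;
  return mkldnn::memory::desc(tz, mkldnn::memory::data_type::f32, fmt);
}
```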
