@@ -15,6 +15,7 @@ limitations under the License. */
15
15
#include < memory>
16
16
#include " paddle/fluid/operators/concat_op.h"
17
17
#include " paddle/fluid/platform/mkldnn_helper.h"
18
+ #include " paddle/fluid/platform/mkldnn_reuse.h"
18
19
19
20
namespace paddle {
20
21
namespace operators {
@@ -38,15 +39,20 @@ static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
38
39
}
39
40
40
41
static memory::primitive_desc CreateMemPrimDesc (const Tensor& input,
41
- const mkldnn::engine& engine) {
42
- constexpr auto data_type = mkldnn:: memory::f32 ;
42
+ const mkldnn::engine& engine,
43
+ const memory::data_type& dt) {
43
44
const auto dims = paddle::framework::vectorize2int (input.dims ());
44
45
const auto format = input.format ();
45
- auto description = memory::desc (dims, data_type , format);
46
+ auto description = memory::desc (dims, dt , format);
46
47
auto mem_prim_desc = memory::primitive_desc (description, engine);
47
48
return mem_prim_desc;
48
49
}
49
50
51
+ static mkldnn::memory::format GetDstMemFormat (
52
+ const concat::primitive_desc& concat_pd) {
53
+ return (memory::format)concat_pd.dst_primitive_desc ().desc ().data .format ;
54
+ }
55
+
50
56
static platform::CPUPlace GetCpuPlace (
51
57
const paddle::framework::ExecutionContext& ctx) {
52
58
auto place = ctx.GetPlace ();
@@ -61,14 +67,30 @@ static const mkldnn::engine& GetMKLDNNEngine(
61
67
return dev_ctx.GetEngine ();
62
68
}
63
69
70
+ std::string CreateKey (const paddle::framework::ExecutionContext& ctx,
71
+ const std::vector<const Tensor*> multi_input,
72
+ const int64_t & concat_axis, const memory::data_type& dt) {
73
+ std::string key;
74
+ key.reserve (platform::MKLDNNHandler::MaxKeyLength);
75
+ for (size_t i = 0 ; i < multi_input.size (); i++) {
76
+ platform::MKLDNNHandler::AppendKeyDims (
77
+ &key, paddle::framework::vectorize2int (multi_input[i]->dims ()));
78
+ }
79
+ platform::MKLDNNHandler::AppendKey (&key, std::to_string (concat_axis));
80
+ platform::MKLDNNHandler::AppendKey (&key, ctx.op ().Output (" Out" ));
81
+ platform::MKLDNNHandler::AppendKey (&key, std::to_string (dt));
82
+ return key;
83
+ }
84
+
64
85
template <typename T>
65
86
class ConcatPrimitiveFactory {
66
87
public:
67
88
concat::primitive_desc CreateConcatPrimDescriptor (
68
89
const std::vector<const Tensor*> multi_input, Tensor* output,
69
- int concat_axis, const mkldnn::engine& mkldnn_engine) {
70
- CreateSourcesDescriptors (multi_input, mkldnn_engine);
71
- auto dst_desc = CreateDstMemDescriptor (output);
90
+ int concat_axis, const mkldnn::engine& mkldnn_engine,
91
+ const memory::data_type& dt = memory::data_type::f32 ) {
92
+ CreateSourcesDescriptors (multi_input, mkldnn_engine, dt);
93
+ auto dst_desc = CreateDstMemDescriptor (output, dt);
72
94
return concat::primitive_desc (dst_desc, concat_axis, srcs_pd);
73
95
}
74
96
@@ -79,23 +101,39 @@ class ConcatPrimitiveFactory {
79
101
return concat (concat_pd, inputs, dst_mem.get ());
80
102
}
81
103
104
+ void SetSrcDataHandleByIndex (const std::vector<memory>& srcs, const size_t & i,
105
+ void * handler) {
106
+ srcs[i].set_data_handle (handler);
107
+ }
108
+
109
+ void SetDstDataHandle (const memory& dst_mem, void * handler) {
110
+ dst_mem.set_data_handle (handler);
111
+ }
112
+
113
+ std::vector<memory> GetSrcs () { return srcs; }
114
+
115
+ memory GetDst () { return dst_mem.get (); }
116
+
82
117
private:
83
- memory::desc CreateDstMemDescriptor (Tensor* output) {
118
+ memory::desc CreateDstMemDescriptor (Tensor* output,
119
+ const memory::data_type& dt) {
84
120
auto dst_dims = paddle::framework::vectorize2int (output->dims ());
85
- return memory::desc (dst_dims, platform::MKLDNNGetDataType<T>(),
86
- memory::format::any);
121
+ return memory::desc (dst_dims, dt, memory::format::any);
87
122
}
88
123
89
124
mkldnn::memory CreateDstMemory (const concat::primitive_desc& concat_pd,
90
- Tensor* output, platform::CPUPlace place) {
125
+ Tensor* output,
126
+ const platform::CPUPlace& place) {
91
127
return memory (concat_pd.dst_primitive_desc (),
92
128
output->mutable_data <T>(place));
93
129
}
94
130
95
131
void CreateSourcesDescriptors (const std::vector<const Tensor*> multi_input,
96
- const mkldnn::engine& mkldnn_engine) {
132
+ const mkldnn::engine& mkldnn_engine,
133
+ const memory::data_type& dt) {
97
134
for (size_t i = 0 ; i < multi_input.size (); i++) {
98
- auto mem_prim_desc = CreateMemPrimDesc (*multi_input[i], mkldnn_engine);
135
+ auto mem_prim_desc =
136
+ CreateMemPrimDesc (*multi_input[i], mkldnn_engine, dt);
99
137
srcs_pd.push_back (mem_prim_desc);
100
138
srcs.push_back (
101
139
memory (mem_prim_desc, to_void_cast (multi_input[i]->data <T>())));
@@ -120,21 +158,59 @@ template <typename T>
120
158
class ConcatMKLDNNOpKernel : public paddle ::framework::OpKernel<T> {
121
159
public:
122
160
void Compute (const paddle::framework::ExecutionContext& ctx) const override {
123
- auto place = GetCpuPlace (ctx);
124
- const auto & mkldnn_engine = GetMKLDNNEngine (ctx);
125
-
126
161
auto multi_input = ctx.MultiInput <Tensor>(" X" );
127
162
EnforceLayouts (multi_input);
128
163
Tensor* output = ctx.Output <Tensor>(" Out" );
129
164
int64_t concat_axis = static_cast <int64_t >(ctx.Attr <int >(" axis" ));
165
+ auto & dev_ctx =
166
+ ctx.template device_context <paddle::platform::MKLDNNDeviceContext>();
167
+ auto place = GetCpuPlace (ctx);
168
+
169
+ memory::data_type dt =
170
+ paddle::framework::ToMKLDNNDataType (multi_input[0 ]->type ());
130
171
131
172
ConcatPrimitiveFactory<T> prim_creator;
132
- auto concat_pd = prim_creator.CreateConcatPrimDescriptor (
133
- multi_input, output, static_cast <int >(concat_axis), mkldnn_engine);
134
- auto concat = prim_creator.CreateConcatPrimitive (concat_pd, output, place);
135
- stream (stream::kind::eager).submit ({concat}).wait ();
173
+ std::string key = CreateKey (ctx, multi_input, concat_axis, dt);
174
+ const std::string key_prim = key + " @concat_p" ;
175
+ const std::string key_concat_pd = key + " @concat_pd" ;
176
+ const std::string key_srcs = key + " @concat_srcs" ;
177
+ const std::string key_dst = key + " @concat_dst" ;
178
+
179
+ std::shared_ptr<concat::primitive_desc> concat_pd;
180
+ std::shared_ptr<std::vector<memory>> srcs;
181
+ std::shared_ptr<memory> dst_mem;
182
+ auto concat_p = std::static_pointer_cast<concat>(dev_ctx.GetBlob (key_prim));
183
+
184
+ if (concat_p == nullptr ) {
185
+ const auto & mkldnn_engine = dev_ctx.GetEngine ();
186
+ concat_pd = std::make_shared<concat::primitive_desc>(
187
+ prim_creator.CreateConcatPrimDescriptor (multi_input, output,
188
+ static_cast <int >(concat_axis),
189
+ mkldnn_engine, dt));
190
+ concat_p = std::make_shared<concat>(
191
+ prim_creator.CreateConcatPrimitive (*concat_pd, output, place));
192
+ srcs = std::make_shared<std::vector<memory>>(prim_creator.GetSrcs ());
193
+ dst_mem = std::make_shared<memory>(prim_creator.GetDst ());
194
+ dev_ctx.SetBlob (key_prim, concat_p);
195
+ dev_ctx.SetBlob (key_concat_pd, concat_pd);
196
+ dev_ctx.SetBlob (key_srcs, srcs);
197
+ dev_ctx.SetBlob (key_dst, dst_mem);
198
+ } else {
199
+ srcs = std::static_pointer_cast<std::vector<memory>>(
200
+ dev_ctx.GetBlob (key_srcs));
201
+ dst_mem = std::static_pointer_cast<memory>(dev_ctx.GetBlob (key_dst));
202
+ concat_pd = std::static_pointer_cast<concat::primitive_desc>(
203
+ dev_ctx.GetBlob (key_concat_pd));
204
+ for (size_t i = 0 ; i < multi_input.size (); i++) {
205
+ prim_creator.SetSrcDataHandleByIndex (
206
+ *srcs, i, to_void_cast<T>(multi_input[i]->data <T>()));
207
+ }
208
+ prim_creator.SetDstDataHandle (*dst_mem, output->mutable_data <T>(place));
209
+ }
210
+
211
+ stream (stream::kind::eager).submit ({*concat_p}).wait ();
136
212
137
- output->set_mkldnn_prim_desc (concat_pd. dst_primitive_desc ());
213
+ output->set_mkldnn_prim_desc (concat_pd-> dst_primitive_desc ());
138
214
}
139
215
};
140
216
} // namespace operators
@@ -143,4 +219,6 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
143
219
namespace ops = paddle::operators;
144
220
145
221
REGISTER_OP_KERNEL (concat, MKLDNN, ::paddle::platform::CPUPlace,
146
- ops::ConcatMKLDNNOpKernel<float >)
222
+ ops::ConcatMKLDNNOpKernel<float >,
223
+ ops::ConcatMKLDNNOpKernel<int8_t >,
224
+ ops::ConcatMKLDNNOpKernel<uint8_t >);
0 commit comments