@@ -71,6 +71,15 @@ static const std::vector<const Tensor*> ReduceMultiInput(
71
71
return reduced;
72
72
}
73
73
74
+ static const std::vector<int > GetDimsForKey (
75
+ const std::vector<const Tensor*>& inputs) {
76
+ auto dims_key = paddle::framework::vectorize<int >(inputs[0 ]->dims ());
77
+ for (auto it = std::next (inputs.begin ()); it != inputs.end (); ++it) {
78
+ dims_key.push_back ((*it)->dims ()[0 ]);
79
+ }
80
+ return dims_key;
81
+ }
82
+
74
83
template <typename T>
75
84
class ConcatPrimitiveFactory {
76
85
public:
@@ -134,6 +143,8 @@ template <typename T>
134
143
class ConcatMKLDNNOpKernel : public paddle ::framework::OpKernel<T> {
135
144
public:
136
145
void Compute (const paddle::framework::ExecutionContext& ctx) const override {
146
+ // If any of the multiple inputs of concat has an input size of 0, the
147
+ // actual size of the multi_input will change
137
148
auto multi_input = ReduceMultiInput (ctx.MultiInput <Tensor>(" X" ));
138
149
EnforceLayouts (multi_input);
139
150
Tensor* output = ctx.Output <Tensor>(" Out" );
@@ -156,12 +167,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
156
167
paddle::framework::ToMKLDNNDataType (multi_input[0 ]->type ());
157
168
158
169
ConcatPrimitiveFactory<T> prim_creator;
159
- // If one of the multiple inputs of concat has an input size of 0, the
160
- // actual size of the multi_input will change
161
- std::string key = platform::CreateKey (
162
- dev_ctx, paddle::framework::vectorize<int >(multi_input[0 ]->dims ()),
163
- multi_input.size (), ctx.OutputName (" Out" ), dt,
164
- platform::ThreadIDasStr ());
170
+ std::string key =
171
+ platform::CreateKey (dev_ctx, GetDimsForKey (multi_input),
172
+ multi_input.size (), ctx.OutputName (" Out" ), dt);
165
173
key = platform::ExtendKeyWithThreadInfoIfNeeded (dev_ctx, key);
166
174
167
175
const std::string key_prim = key + " @concat_p" ;
0 commit comments