Skip to content

Commit d44d173

Browse files
author
Wojciech Uss
authored
fix cache key in concat oneDNN kernel (#31820) (#31837)
* fix cache key in concat oneDNN kernel
* key simplified
1 parent aa731e6 commit d44d173

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@ static const std::vector<const Tensor*> ReduceMultiInput(
7171
return reduced;
7272
}
7373

74+
// Builds the list of dimensions that the concat kernel feeds into its cache
// key (passed to platform::CreateKey in ConcatMKLDNNOpKernel::Compute).
//
// The result holds every dimension of inputs[0], followed by the leading
// dimension (dims()[0]) of each remaining input, in input order.
//
// NOTE(review): inputs[0] is read unconditionally, so `inputs` must be
// non-empty; presumably the caller guarantees this -- confirm.
// NOTE(review): for inputs after the first, only dims()[0] is recorded, so two
// input sets that differ solely in another dimension of a later input yield
// the same vector. Confirm this is acceptable when the concat axis is not 0.
static const std::vector<int> GetDimsForKey(
75+
const std::vector<const Tensor*>& inputs) {
76+
// Start from the full shape of the first input.
auto dims_key = paddle::framework::vectorize<int>(inputs[0]->dims());
77+
// Skip the first input (already covered above) and walk the rest.
for (auto it = std::next(inputs.begin()); it != inputs.end(); ++it) {
78+
dims_key.push_back((*it)->dims()[0]);
79+
}
80+
return dims_key;
81+
}
82+
7483
template <typename T>
7584
class ConcatPrimitiveFactory {
7685
public:
@@ -134,6 +143,8 @@ template <typename T>
134143
class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
135144
public:
136145
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
146+
// If any of the multiple inputs of concat has an input size of 0, the
147+
// actual size of the multi_input will change
137148
auto multi_input = ReduceMultiInput(ctx.MultiInput<Tensor>("X"));
138149
EnforceLayouts(multi_input);
139150
Tensor* output = ctx.Output<Tensor>("Out");
@@ -156,12 +167,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
156167
paddle::framework::ToMKLDNNDataType(multi_input[0]->type());
157168

158169
ConcatPrimitiveFactory<T> prim_creator;
159-
// If one of the multiple inputs of concat has an input size of 0, the
160-
// actual size of the multi_input will change
161-
std::string key = platform::CreateKey(
162-
dev_ctx, paddle::framework::vectorize<int>(multi_input[0]->dims()),
163-
multi_input.size(), ctx.OutputName("Out"), dt,
164-
platform::ThreadIDasStr());
170+
std::string key =
171+
platform::CreateKey(dev_ctx, GetDimsForKey(multi_input),
172+
multi_input.size(), ctx.OutputName("Out"), dt);
165173
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
166174

167175
const std::string key_prim = key + "@concat_p";

0 commit comments

Comments
 (0)