Skip to content

Commit 674aa06

Browse files
authored
fix concat mkldnn in extream condition. test=develop test=release/1.7 (#22877)
1 parent 143023b commit 674aa06

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,12 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
142142
paddle::framework::ToMKLDNNDataType(multi_input[0]->type());
143143

144144
ConcatPrimitiveFactory<T> prim_creator;
145+
// If one of the multiple inputs of concat has an input size of 0, the
146+
// actual size of the multi_input will change
145147
std::string key = platform::CreateKey(
146148
paddle::framework::vectorize<int>(multi_input[0]->dims()),
147-
ctx.OutputName("Out"), dt, platform::ThreadIDasStr());
149+
multi_input.size(), ctx.OutputName("Out"), dt,
150+
platform::ThreadIDasStr());
148151

149152
const std::string key_prim = key + "@concat_p";
150153
const std::string key_concat_pd = key + "@concat_pd";

0 commit comments

Comments
 (0)