Skip to content

Commit 6e1d0ef

Browse files
authored
fix concat shape error (#25414) (#25438)
* fix concat shape error test=develop
1 parent 5c84eac commit 6e1d0ef

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

paddle/fluid/operators/concat_op.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,19 @@ static inline framework::DDim ComputeAndCheckShape(
3131
auto out_dims = inputs_dims[0];
3232
size_t in_zero_dims_size = out_dims.size();
3333
for (size_t i = 1; i < n; i++) {
34+
PADDLE_ENFORCE_EQ(inputs_dims[i].size(), out_dims.size(),
35+
platform::errors::InvalidArgument(
36+
"The shape of input[0] and input[%d] "
37+
"is expected to be equal."
38+
"But received input[0]'s shape = "
39+
"[%s], input[%d]'s shape = [%s].",
40+
i, inputs_dims[0], i, inputs_dims[i]));
3441
for (size_t j = 0; j < in_zero_dims_size; j++) {
3542
if (j == axis) {
3643
if (is_runtime) {
3744
out_dims[axis] += inputs_dims[i][j];
3845
} else {
39-
if (inputs_dims[i][j] == -1) {
46+
if (inputs_dims[i][j] == -1 || out_dims[j] == -1) {
4047
out_dims[axis] = -1;
4148
} else {
4249
out_dims[axis] += inputs_dims[i][j];
@@ -55,6 +62,9 @@ static inline framework::DDim ComputeAndCheckShape(
5562
"[%s], input[%d]'s shape = [%s].",
5663
j, i, inputs_dims[0], i, inputs_dims[i]));
5764
}
65+
if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0) {
66+
out_dims[j] = inputs_dims[i][j];
67+
}
5868
}
5969
}
6070
}

0 commit comments

Comments
 (0)