@@ -31,12 +31,19 @@ static inline framework::DDim ComputeAndCheckShape(
31
31
auto out_dims = inputs_dims[0 ];
32
32
size_t in_zero_dims_size = out_dims.size ();
33
33
for (size_t i = 1 ; i < n; i++) {
34
+ PADDLE_ENFORCE_EQ (inputs_dims[i].size (), out_dims.size (),
35
+ platform::errors::InvalidArgument (
36
+ " The shape of input[0] and input[%d] "
37
+ " is expected to be equal."
38
+ " But received input[0]'s shape = "
39
+ " [%s], input[%d]'s shape = [%s]." ,
40
+ i, inputs_dims[0 ], i, inputs_dims[i]));
34
41
for (size_t j = 0 ; j < in_zero_dims_size; j++) {
35
42
if (j == axis) {
36
43
if (is_runtime) {
37
44
out_dims[axis] += inputs_dims[i][j];
38
45
} else {
39
- if (inputs_dims[i][j] == -1 ) {
46
+ if (inputs_dims[i][j] == -1 || out_dims[j] == - 1 ) {
40
47
out_dims[axis] = -1 ;
41
48
} else {
42
49
out_dims[axis] += inputs_dims[i][j];
@@ -55,6 +62,9 @@ static inline framework::DDim ComputeAndCheckShape(
55
62
" [%s], input[%d]'s shape = [%s]." ,
56
63
j, i, inputs_dims[0 ], i, inputs_dims[i]));
57
64
}
65
+ if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0 ) {
66
+ out_dims[j] = inputs_dims[i][j];
67
+ }
58
68
}
59
69
}
60
70
}
0 commit comments