Skip to content

Commit 7e33f5a

Browse files
authored
Merge pull request #16930 from phlrain/pick_infer_shape_many
Merge pull request #16840 from phlrain/fix_shape_check_many
2 parents 01aa670 + 8454038 commit 7e33f5a

File tree

9 files changed

+144
-58
lines changed

9 files changed

+144
-58
lines changed

paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
8181
public:
8282
void operator()(framework::InferShapeContext *context) const override {
8383
PADDLE_ENFORCE(context->HasInput("I"), "Must set the subscript index");
84-
PADDLE_ENFORCE_EQ(framework::product(context->GetInputDim("I")), 1,
85-
"The number of element of subscript index must be 1");
84+
if (context->IsRuntime()) {
85+
PADDLE_ENFORCE_EQ(framework::product(context->GetInputDim("I")), 1,
86+
"The number of element of subscript index must be 1");
87+
}
8688
if (!context->HasInput("X")) {
8789
return;
8890
}

paddle/fluid/operators/data_norm_op.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/data_norm_op.h"
16+
#include <memory>
1617
#include <string>
1718
#include "paddle/fluid/framework/data_layout.h"
1819
#ifdef PADDLE_WITH_MKLDNN
@@ -65,9 +66,11 @@ class DataNormOp : public framework::OperatorWithKernel {
6566
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize").size(), 1UL);
6667
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum").size(), 1UL);
6768
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum").size(), 1UL);
68-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize")[0], C);
69-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum")[0], C);
70-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C);
69+
if (ctx->IsRuntime()) {
70+
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize")[0], C);
71+
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum")[0], C);
72+
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C);
73+
}
7174

7275
ctx->SetOutputDim("Y", x_dims);
7376
ctx->SetOutputDim("Means", {C});

paddle/fluid/operators/huber_loss_op.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,18 @@ class HuberLossOp : public framework::OperatorWithKernel {
2828
auto x_dims = ctx->GetInputDim("X");
2929
auto y_dims = ctx->GetInputDim("Y");
3030

31-
PADDLE_ENFORCE_EQ(x_dims, y_dims);
3231
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
3332
"The rank of Input(X) must be 2 and the shape is "
3433
"[batch_size, 1].");
35-
PADDLE_ENFORCE_EQ(x_dims[1], 1,
36-
"Each row of Input(X) contains a real value, "
37-
"so the 2nd dimension of Input(X) must be 1.");
34+
if (ctx->IsRuntime() ||
35+
(framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
36+
PADDLE_ENFORCE_EQ(x_dims, y_dims, "Shape of X and Y should be same");
37+
}
38+
if (ctx->IsRuntime()) {
39+
PADDLE_ENFORCE_EQ(x_dims[1], 1,
40+
"Each row of Input(X) contains a real value, "
41+
"so the 2nd dimension of Input(X) must be 1.");
42+
}
3843

3944
ctx->SetOutputDim("Residual", x_dims);
4045
ctx->SetOutputDim("Out", {x_dims[0], 1});

paddle/fluid/operators/layer_norm_op.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,18 @@ class LayerNormOp : public framework::OperatorWithKernel {
4646
int right = static_cast<int>(matrix_dim[1]);
4747
if (ctx->HasInput("Scale")) {
4848
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1);
49-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right);
49+
50+
if (ctx->IsRuntime()) {
51+
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right,
52+
"scale should with right");
53+
}
5054
}
5155
if (ctx->HasInput("Bias")) {
5256
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1);
53-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right);
57+
if (ctx->IsRuntime()) {
58+
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right,
59+
"bias should with right");
60+
}
5461
}
5562

5663
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));

paddle/fluid/operators/metrics/precision_recall_op.cc

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,30 +40,40 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
4040
auto max_probs_dims = ctx->GetInputDim("MaxProbs");
4141
auto labels_dims = ctx->GetInputDim("Labels");
4242

43-
PADDLE_ENFORCE_EQ(max_probs_dims[1], 1,
44-
"Each instance contains one max probability, so the "
45-
"shape of Input(MaxProbs) should be [batch_size, 1].");
46-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Indices"), max_probs_dims,
47-
"The shape of Input(Indices) should be [batch_size, 1].");
48-
PADDLE_ENFORCE_EQ(max_probs_dims[0], labels_dims[0],
49-
"The 1st dimension of Input(MaxProbs) and "
50-
"Input(Labels) both are batch_size and the shape should "
51-
"be the same.");
52-
PADDLE_ENFORCE_EQ(labels_dims[1], 1,
53-
"The 2nd dimension of Input(Labels) contains instance "
54-
"label and the shape should be equal to 1.");
43+
if (ctx->IsRuntime()) {
44+
PADDLE_ENFORCE_EQ(max_probs_dims[1], 1,
45+
"Each instance contains one max probability, so the "
46+
"shape of Input(MaxProbs) should be [batch_size, 1].");
47+
PADDLE_ENFORCE_EQ(
48+
ctx->GetInputDim("Indices"), max_probs_dims,
49+
"The shape of Input(Indices) should bes same with max_probs_dims");
50+
PADDLE_ENFORCE_EQ(
51+
max_probs_dims[0], labels_dims[0],
52+
"The 1st dimension of Input(MaxProbs) and "
53+
"Input(Labels) both are batch_size and the shape should "
54+
"be the same.");
55+
PADDLE_ENFORCE_EQ(labels_dims[1], 1,
56+
"The 2nd dimension of Input(Labels) contains instance "
57+
"label and the shape should be equal to 1.");
58+
}
5559
if (ctx->HasInput("Weights")) {
5660
auto weights_dims = ctx->GetInputDim("Weights");
57-
PADDLE_ENFORCE_EQ(weights_dims,
58-
framework::make_ddim({max_probs_dims[0], 1}),
59-
"The shape of Input(Weights) should be "
60-
"[batch_size, 1].");
61+
62+
if (ctx->IsRuntime()) {
63+
PADDLE_ENFORCE_EQ(weights_dims,
64+
framework::make_ddim({max_probs_dims[0], 1}),
65+
"The shape of Input(Weights) should be "
66+
"[batch_size, 1].");
67+
}
6168
}
6269
if (ctx->HasInput("StatesInfo")) {
6370
auto states_dims = ctx->GetInputDim("StatesInfo");
64-
PADDLE_ENFORCE_EQ(states_dims, framework::make_ddim({cls_num, 4}),
65-
"The shape of Input(StatesInfo) should be "
66-
"[class_number, 4].");
71+
72+
if (ctx->IsRuntime()) {
73+
PADDLE_ENFORCE_EQ(states_dims, framework::make_ddim({cls_num, 4}),
74+
"The shape of Input(StatesInfo) should be "
75+
"[class_number, 4].");
76+
}
6777
}
6878

6979
// Layouts of BatchMetrics and AccumMetrics both are:

paddle/fluid/operators/minus_op.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/minus_op.h"
1616

17+
#include <memory>
1718
#include <string>
1819
#include <vector>
1920

@@ -38,9 +39,12 @@ class MinusOp : public framework::OperatorWithKernel {
3839
auto x_dims = ctx->GetInputDim("X");
3940
auto y_dims = ctx->GetInputDim("Y");
4041

41-
PADDLE_ENFORCE_EQ(
42-
x_dims, y_dims,
43-
"Minus operator must take two tensor with same num of elements");
42+
if (ctx->IsRuntime() ||
43+
(framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
44+
PADDLE_ENFORCE_EQ(
45+
x_dims, y_dims,
46+
"Minus operator must take two tensor with same num of elements");
47+
}
4448
ctx->SetOutputDim("Out", x_dims);
4549
ctx->ShareLoD("X", /*->*/ "Out");
4650
}

paddle/fluid/operators/modified_huber_loss_op.cc

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,16 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
2828
auto x_dims = ctx->GetInputDim("X");
2929
auto y_dims = ctx->GetInputDim("Y");
3030

31-
PADDLE_ENFORCE_EQ(x_dims, y_dims, "The shape of X and Y must be the same.");
3231
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The tensor rank of X must be 2.");
33-
PADDLE_ENFORCE_EQ(x_dims[1], 1, "The 2nd dimension of X must be 1.");
32+
if (ctx->IsRuntime() ||
33+
(framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
34+
PADDLE_ENFORCE_EQ(x_dims, y_dims,
35+
"The shape of X and Y must be the same.");
36+
}
37+
38+
if (ctx->IsRuntime()) {
39+
PADDLE_ENFORCE_EQ(x_dims[1], 1, "The 2nd dimension of X must be 1.");
40+
}
3441

3542
ctx->SetOutputDim("IntermediateVal", x_dims);
3643
ctx->SetOutputDim("Out", {x_dims[0], 1});
@@ -90,11 +97,13 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
9097
auto intermediate_dims = ctx->GetInputDim("IntermediateVal");
9198
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
9299

93-
PADDLE_ENFORCE_EQ(
94-
intermediate_dims, x_dims,
95-
"The shape of X and intermediate value must be the same.");
96-
PADDLE_ENFORCE_EQ(out_grad_dims, x_dims,
97-
"The shape of Input(Out@Grad) and X must be the same.");
100+
if (ctx->IsRuntime()) {
101+
PADDLE_ENFORCE_EQ(
102+
intermediate_dims, x_dims,
103+
"The shape of X and intermediate value must be the same.");
104+
PADDLE_ENFORCE_EQ(out_grad_dims, x_dims,
105+
"The shape of Input(Out@Grad) and X must be the same.");
106+
}
98107

99108
if (ctx->HasOutput(framework::GradVarName("X"))) {
100109
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);

paddle/fluid/operators/space_to_depth_op.cc

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,44 @@ class SpaceToDepthOp : public framework::OperatorWithKernel {
3434
auto blocksize = ctx->Attrs().Get<int64_t>("blocksize");
3535

3636
PADDLE_ENFORCE_GT(blocksize, 1, "The blocksize should be Greater than 1");
37-
PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0");
38-
PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0");
39-
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
40-
41-
PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0,
42-
"input channel should be divisible of the square of "
43-
"SpaceToDepthOp blocksize");
44-
PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0,
45-
"input Height should be divisible of the square of "
46-
"SpaceToDepthOp blocksize");
47-
PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0,
48-
"input Width should be divisible of the square of "
49-
"SpaceToDepthOp blocksize");
37+
if (ctx->IsRuntime()) {
38+
PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0");
39+
PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0");
40+
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
41+
42+
PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0,
43+
"input channel should be divisible of the square of "
44+
"SpaceToDepthOp blocksize");
45+
PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0,
46+
"input Height should be divisible of the square of "
47+
"SpaceToDepthOp blocksize");
48+
PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0,
49+
"input Width should be divisible of the square of "
50+
"SpaceToDepthOp blocksize");
51+
} else {
52+
if (x_dims[1] != -1) {
53+
PADDLE_ENFORCE_GT(x_dims[1], 0,
54+
"input channel should be Greater than 0");
55+
PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0,
56+
"input channel should be divisible of the square of "
57+
"SpaceToDepthOp blocksize");
58+
}
59+
if (x_dims[2] != -1) {
60+
PADDLE_ENFORCE_GT(x_dims[2], 0,
61+
"input Height should be Greater than 0");
62+
PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0,
63+
"input Height should be divisible of the square of "
64+
"SpaceToDepthOp blocksize");
65+
}
66+
67+
if (x_dims[3] != -1) {
68+
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
69+
70+
PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0,
71+
"input Width should be divisible of the square of "
72+
"SpaceToDepthOp blocksize");
73+
}
74+
}
5075

5176
VLOG(3) << "SpaceToDepthOp operator x.shape=" << x_dims
5277
<< "Attribute blocksize" << blocksize << std::endl;

paddle/fluid/operators/tree_conv_op.cc

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,38 @@ class TreeConvOp : public framework::OperatorWithKernel {
6262
auto edge_dims = ctx->GetInputDim("EdgeSet");
6363
auto vector_dims = ctx->GetInputDim("NodesVector");
6464
auto filter_dims = ctx->GetInputDim("Filter");
65-
PADDLE_ENFORCE_EQ(edge_dims[2], 2, "Input(EdgeSet) dim[2] should be 2");
65+
66+
if (ctx->IsRuntime()) {
67+
PADDLE_ENFORCE_EQ(edge_dims[2], 2, "Input(EdgeSet) dim[2] should be 2");
68+
} else {
69+
if (edge_dims[2] != -1) {
70+
PADDLE_ENFORCE_EQ(edge_dims[2], 2, "Input(EdgeSet) dim[2] should be 2");
71+
}
72+
}
6673
PADDLE_ENFORCE_EQ(edge_dims.size(), 3,
6774
"The dimension of EdgeSet Tensor should be 3");
6875
PADDLE_ENFORCE_EQ(vector_dims.size(), 3,
6976
"The dimension of NodesVector Tensor should be 3");
7077
PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
7178
"The dimension of Filter Tensor should be 4");
72-
PADDLE_ENFORCE_EQ(filter_dims[1], 3, "Input(Filter) dim[1] should be 3");
73-
PADDLE_ENFORCE_EQ(
74-
filter_dims[0], vector_dims[2],
75-
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]");
79+
80+
if (ctx->IsRuntime()) {
81+
PADDLE_ENFORCE_EQ(filter_dims[1], 3, "Input(Filter) dim[1] should be 3");
82+
PADDLE_ENFORCE_EQ(
83+
filter_dims[0], vector_dims[2],
84+
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]");
85+
} else {
86+
if (filter_dims[1] != -1) {
87+
PADDLE_ENFORCE_EQ(filter_dims[1], 3,
88+
"Input(Filter) dim[1] should be 3");
89+
}
90+
91+
if (filter_dims[0] != -1 && vector_dims[2] != -1) {
92+
PADDLE_ENFORCE_EQ(
93+
filter_dims[0], vector_dims[2],
94+
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]");
95+
}
96+
}
7697
auto output_dims = framework::make_ddim(
7798
{vector_dims[0], vector_dims[1], filter_dims[2], filter_dims[3]});
7899
ctx->SetOutputDim("Out", output_dims);

0 commit comments

Comments
 (0)