Skip to content

Commit 01aa670

Browse files
authored
Merge pull request #16933 from phlrain/pick_many_infer_fix
Pick many infer fix
2 parents 591f087 + dc202c2 commit 01aa670

9 files changed

+98
-56
lines changed

paddle/fluid/operators/attention_lstm_op.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,19 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
6464

6565
auto c_dims = ctx->GetInputDim("C0");
6666
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
67-
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
67+
if (ctx->IsRuntime()) {
68+
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
69+
}
70+
6871
if (ctx->HasInput("H0")) {
6972
auto h_dims = ctx->GetInputDim("H0");
70-
PADDLE_ENFORCE(h_dims == c_dims,
71-
"The dimension of Input(H0) and Input(C0) "
72-
"should be the same.");
73+
PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2.");
74+
if (ctx->IsRuntime() ||
75+
(framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
76+
PADDLE_ENFORCE(h_dims == c_dims,
77+
"The dimension of Input(H0) and Input(C0) "
78+
"should be the same.");
79+
}
7380
}
7481

7582
auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
@@ -79,6 +86,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
7986
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
8087
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
8188
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
89+
8290
if (ctx->HasInput("AttentionBias")) {
8391
auto atten_b_dims = ctx->GetInputDim("AttentionBias");
8492
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,

paddle/fluid/operators/bpr_loss_op.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,14 @@ class BprLossOp : public framework::OperatorWithKernel {
3232
int rank = x_dims.size();
3333
PADDLE_ENFORCE_EQ(rank, label_dims.size(),
3434
"Input(X) and Input(Label) shall have the same rank.");
35-
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
36-
framework::slice_ddim(label_dims, 0, rank - 1),
37-
"Input(X) and Input(Label) shall have the same shape "
38-
"except the last dimension.");
35+
36+
if (ctx->IsRuntime() || (framework::product(x_dims) > 0 &&
37+
framework::product(label_dims) > 0)) {
38+
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
39+
framework::slice_ddim(label_dims, 0, rank - 1),
40+
"Input(X) and Input(Label) shall have the same shape "
41+
"except the last dimension.");
42+
}
3943

4044
auto y_dims = x_dims;
4145
y_dims[rank - 1] = 1;

paddle/fluid/operators/conv_shift_op.cc

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,17 @@ class ConvShiftOp : public framework::OperatorWithKernel {
3636
auto y_dims = ctx->GetInputDim("Y");
3737
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
3838
PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2.");
39-
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
40-
"The 1st dimension of Input(X) and Input(Y) should "
41-
"be equal.");
42-
PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
43-
"The 2nd dimension of Input(Y) should be odd.");
44-
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
45-
"The 2nd dimension of Input(Y) should be less than or "
46-
"equal to the 2nd dimension of Input(X).");
39+
if (ctx->IsRuntime() || (x_dims[0] > 0 && y_dims[0] > 0))
40+
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
41+
"The 1st dimension of Input(X) and Input(Y) should "
42+
"be equal.");
43+
if (ctx->IsRuntime() || y_dims[1] > 0)
44+
PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
45+
"The 2nd dimension of Input(Y) should be odd.");
46+
if (ctx->IsRuntime() || (x_dims[1] > 0 && y_dims[1] > 0))
47+
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
48+
"The 2nd dimension of Input(Y) should be less than or "
49+
"equal to the 2nd dimension of Input(X).");
4750
ctx->ShareDim("X", /*->*/ "Out");
4851
ctx->ShareLoD("X", /*->*/ "Out");
4952
}

paddle/fluid/operators/merge_lod_tensor_op.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,9 @@ class MergeLoDTensorInferShape : public framework::InferShapeBase {
164164

165165
auto mask_dim = context->GetInputDim("Mask");
166166
PADDLE_ENFORCE_EQ(mask_dim.size(), 2);
167-
PADDLE_ENFORCE_EQ(mask_dim[1], 1);
167+
if (context->IsRuntime() || mask_dim[1] > 0) {
168+
PADDLE_ENFORCE_EQ(mask_dim[1], 1);
169+
}
168170

169171
context->SetOutputDim("Out", context->GetInputDim("InTrue"));
170172
}

paddle/fluid/operators/positive_negative_pair_op.cc

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,23 +61,31 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
6161
auto query_dim = ctx->GetInputDim("QueryID");
6262
PADDLE_ENFORCE_EQ(score_dim.size(), 2, "Score should be a 2-D tensor.");
6363
PADDLE_ENFORCE_EQ(label_dim.size(), 2, "Label should be a 2-D tensor.");
64-
PADDLE_ENFORCE_EQ(
65-
label_dim[0], score_dim[0],
66-
"Tensor Score and Label should have the same height (batch size).");
67-
PADDLE_ENFORCE_EQ(label_dim[1], 1,
68-
"The width of Label should be 1, i.e. each item should "
69-
"have a scalar label.");
70-
PADDLE_ENFORCE(query_dim == label_dim,
71-
"QueryID should have the same shape as Label.");
72-
if (ctx->HasInput("Weight")) {
73-
PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim,
74-
"Weight should have the same shape as Label.");
64+
65+
if (ctx->IsRuntime() ||
66+
(score_dim[0] > 0 && label_dim[0] > 0 && query_dim[0] > 0)) {
67+
PADDLE_ENFORCE_EQ(
68+
label_dim[0], score_dim[0],
69+
"Tensor Score and Label should have the same height (batch size).");
70+
71+
PADDLE_ENFORCE_EQ(label_dim[1], 1,
72+
"The width of Label should be 1, i.e. each item should "
73+
"have a scalar label.");
74+
75+
PADDLE_ENFORCE(query_dim == label_dim,
76+
"QueryID should have the same shape as Label.");
77+
78+
if (ctx->HasInput("Weight")) {
79+
PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim,
80+
"Weight should have the same shape as Label.");
81+
}
82+
83+
int column = ctx->Attrs().Get<int>("column");
84+
auto depth = score_dim[1];
85+
PADDLE_ENFORCE(column < depth && column >= -depth,
86+
"Attribute column should be in the range of [-%l, %l)",
87+
depth, depth);
7588
}
76-
int column = ctx->Attrs().Get<int>("column");
77-
auto depth = score_dim[1];
78-
PADDLE_ENFORCE(column < depth && column >= -depth,
79-
"Attribute column should be in the range of [-%l, %l)",
80-
depth, depth);
8189

8290
ctx->SetOutputDim("PositivePair", scalar_dim);
8391
ctx->SetOutputDim("NegativePair", scalar_dim);

paddle/fluid/operators/scatter_op.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@ class ScatterOp : public framework::OperatorWithKernel {
4242
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0],
4343
ctx->GetInputDim("Ids")[0],
4444
"Updates and Ids should have same batch-size.");
45-
framework::DDim data_dim(updates_dims);
46-
for (int i = 1; i < data_dim.size(); ++i) {
47-
PADDLE_ENFORCE_EQ(data_dim[i], updates_dims[i]);
48-
}
4945
ctx->SetOutputDim("Out", ref_dims);
5046
}
5147

paddle/fluid/operators/split_lod_tensor_op.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,9 @@ class SplitLoDTensorInferShape : public framework::InferShapeBase {
157157

158158
auto mask_dim = context->GetInputDim("Mask");
159159
PADDLE_ENFORCE_EQ(mask_dim.size(), 2);
160-
PADDLE_ENFORCE_EQ(mask_dim[1], 1);
160+
if (context->IsRuntime()) {
161+
PADDLE_ENFORCE_EQ(mask_dim[1], 1);
162+
}
161163

162164
context->SetOutputDim("OutTrue", context->GetInputDim("X"));
163165
context->SetOutputDim("OutFalse", context->GetInputDim("X"));

paddle/fluid/operators/sum_op.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,21 @@ class SumOp : public framework::OperatorWithKernel {
6565
if (framework::product(in_dim) == 0) {
6666
in_dim = x_dim;
6767
} else {
68-
PADDLE_ENFORCE_EQ(in_dim, x_dim, "Input tensors must have same shape");
68+
if (ctx->IsRuntime()) {
69+
PADDLE_ENFORCE_EQ(in_dim, x_dim,
70+
"Input tensors must have same shape");
71+
} else {
72+
PADDLE_ENFORCE_EQ(in_dim.size(), x_dim.size(),
73+
"Input tensors must have same shape size");
74+
// if in_dim or x_dim has -1, not check equal
75+
for (int i = 0; i < x_dim.size(); ++i) {
76+
if (x_dim[i] == -1 || in_dim[i] == -1) {
77+
continue;
78+
}
79+
PADDLE_ENFORCE_EQ(in_dim[i], x_dim[i],
80+
"Input tensors must have same shape if not -1");
81+
}
82+
}
6983
}
7084
}
7185
ctx->SetOutputDim("Out", in_dim);

paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel {
3434
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2.");
3535
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
3636
"Input(Label)'s rank should be 2.");
37-
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
38-
"The 1st dimension of Input(X) and Input(Label) should "
39-
"be equal.");
40-
PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
41-
"The 2nd dimension of "
42-
"Input(Label) should be 1.");
37+
if (ctx->IsRuntime()) {
38+
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
39+
"The 1st dimension of Input(X) and Input(Label) should "
40+
"be equal.");
41+
PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
42+
"The 2nd dimension of "
43+
"Input(Label) should be 1.");
44+
}
4345
ctx->SetOutputDim("Y", {x_dims[0], 1});
4446
ctx->ShareLoD("X", /*->*/ "Y");
4547
}
@@ -74,17 +76,20 @@ class TeacherStudentSigmoidLossGradientOp
7476
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
7577
PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
7678
PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
77-
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
78-
"The 1st dimension of Input(X) and Input(Label) should "
79-
"be equal.");
80-
PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
81-
"The 1st dimension of Input(X) and Input(Y@Grad) should "
82-
"be equal.");
83-
PADDLE_ENFORCE_EQ(dy_dims[1], 1,
84-
"The 2nd dimension of Input(Y@Grad) should be 1.");
85-
PADDLE_ENFORCE_EQ(label_dims[1], 1,
86-
"When Attr(soft_label) == false, the 2nd dimension of "
87-
"Input(Label) should be 1.");
79+
if (ctx->IsRuntime()) {
80+
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
81+
"The 1st dimension of Input(X) and Input(Label) should "
82+
"be equal.");
83+
PADDLE_ENFORCE_EQ(
84+
x_dims[0], dy_dims[0],
85+
"The 1st dimension of Input(X) and Input(Y@Grad) should "
86+
"be equal.");
87+
PADDLE_ENFORCE_EQ(dy_dims[1], 1,
88+
"The 2nd dimension of Input(Y@Grad) should be 1.");
89+
PADDLE_ENFORCE_EQ(label_dims[1], 1,
90+
"When Attr(soft_label) == false, the 2nd dimension of "
91+
"Input(Label) should be 1.");
92+
}
8893
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
8994
ctx->ShareLoD("X", framework::GradVarName("X"));
9095
}

0 commit comments

Comments
 (0)