Skip to content

Commit 7beff76

Browse files
committed
merge #16907 cherry-pick
test=develop
1 parent 391649e commit 7beff76

File tree

7 files changed

+144
-46
lines changed

7 files changed

+144
-46
lines changed

paddle/fluid/operators/linear_chain_crf_op.cc

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/linear_chain_crf_op.h"
16+
1617
#include <memory>
1718

1819
namespace paddle {
@@ -152,21 +153,28 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
152153
auto transition_dims = ctx->GetInputDim("Transition");
153154
PADDLE_ENFORCE_EQ(transition_dims.size(), 2,
154155
"The Input(Transition) should be a 2-D tensor.");
155-
PADDLE_ENFORCE_EQ(
156-
transition_dims[0] - 2, transition_dims[1],
157-
"An invalid dimension for the Input(Transition), which should "
158-
"be a 2-D tensor with shape [(D + 2) x D].");
159-
PADDLE_ENFORCE_EQ(
160-
emission_dims[1], transition_dims[1],
156+
bool check = true;
157+
if ((!ctx->IsRuntime()) &&
158+
(transition_dims[0] <= 0 || transition_dims[1] <= 0)) {
159+
check = false;
160+
}
161+
if (check) {
162+
PADDLE_ENFORCE_EQ(
163+
transition_dims[0] - 2, transition_dims[1],
164+
"An invalid dimension for the Input(Transition), which should "
165+
"be a 2-D tensor with shape [(D + 2) x D].");
166+
}
167+
PADDLE_INFERSHAPE_ENFORCE_EQ(
168+
ctx, emission_dims[1], transition_dims[1],
161169
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
162170
"should be equal to the tag number.");
163171

164172
auto label_dims = ctx->GetInputDim("Label");
165173
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
166174
"The Input(Label) should be a 2-D tensor with the 2nd "
167175
"dimensions fixed to 1.");
168-
PADDLE_ENFORCE_EQ(
169-
emission_dims[0], label_dims[0],
176+
PADDLE_INFERSHAPE_ENFORCE_EQ(
177+
ctx, emission_dims[0], label_dims[0],
170178
"The height of Input(Emission) and the height of Input(Label) "
171179
"should be the same.");
172180

@@ -211,21 +219,28 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
211219
auto transition_exps_dims = ctx->GetInputDim("TransitionExps");
212220
PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2,
213221
"The Input(TransitionExps) should be a 2-D tensor.");
214-
PADDLE_ENFORCE_EQ(
215-
transition_exps_dims[0] - 2, transition_exps_dims[1],
216-
"An invalid dimension for the Input(TransitionExps), which should "
217-
"be a 2-D tensor with shape [(D + 2) x D].");
218-
PADDLE_ENFORCE_EQ(
219-
emission_exps_dims[1], transition_exps_dims[1],
222+
bool check = true;
223+
if ((!ctx->IsRuntime()) &&
224+
(transition_exps_dims[0] <= 0 || transition_exps_dims[1] <= 0)) {
225+
check = false;
226+
}
227+
if (check) {
228+
PADDLE_ENFORCE_EQ(
229+
transition_exps_dims[0] - 2, transition_exps_dims[1],
230+
"An invalid dimension for the Input(TransitionExps), which should "
231+
"be a 2-D tensor with shape [(D + 2) x D].");
232+
}
233+
PADDLE_INFERSHAPE_ENFORCE_EQ(
234+
ctx, emission_exps_dims[1], transition_exps_dims[1],
220235
"The 2nd dimension of the Input(EmissionExps) and the "
221236
"Input(TransitionExps) should be equal to the tag number.");
222237

223238
auto label_dims = ctx->GetInputDim("Label");
224239
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
225240
"The Input(Label) should be a 2-D tensor with the 2nd "
226241
"dimensions fixed to 1.");
227-
PADDLE_ENFORCE_EQ(
228-
emission_exps_dims[0], label_dims[0],
242+
PADDLE_INFERSHAPE_ENFORCE_EQ(
243+
ctx, emission_exps_dims[0], label_dims[0],
229244
"The height of Input(EmissionExps) and the height of Input(Label) "
230245
"should be the same.");
231246

paddle/fluid/operators/metrics/accuracy_op.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,11 @@ class AccuracyOp : public framework::OperatorWithKernel {
4141
// it's the output of topk.
4242

4343
PADDLE_ENFORCE_EQ(label_dim.size(), 2, "label's rank must be 2.");
44-
PADDLE_ENFORCE_EQ(label_dim[1], 1, "label's second dimension must be 1");
45-
PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0],
46-
"the inference tensor's num_rows must be"
47-
" the same as label.");
44+
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, label_dim[1], 1,
45+
"label's second dimension must be 1");
46+
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, inference_dim[0], label_dim[0],
47+
"the inference tensor's num_rows must be"
48+
" the same as label.");
4849

4950
ctx->SetOutputDim("Accuracy", {1});
5051
ctx->SetOutputDim("Correct", {1});

paddle/fluid/operators/metrics/auc_op.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ class AucOp : public framework::OperatorWithKernel {
2828
PADDLE_ENFORCE(ctx->HasInput("Label"),
2929
"Input of Label should not be null.");
3030
auto predict_width = ctx->GetInputDim("Predict")[1];
31-
PADDLE_ENFORCE_EQ(predict_width, 2, "Only support binary classification");
31+
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, predict_width, 2,
32+
"Only support binary classification");
3233
auto predict_height = ctx->GetInputDim("Predict")[0];
3334
auto label_height = ctx->GetInputDim("Label")[0];
3435

35-
PADDLE_ENFORCE_EQ(predict_height, label_height,
36-
"Out and Label should have same height.");
36+
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, predict_height, label_height,
37+
"Out and Label should have same height.");
3738

3839
int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1;
3940
int slide_steps = ctx->Attrs().Get<int>("slide_steps");

paddle/fluid/operators/sample_logits_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
14-
1514
#include "paddle/fluid/operators/sample_logits_op.h"
1615
#include <memory>
1716
#include "paddle/fluid/operators/math/sample_prob.h"
@@ -141,7 +140,10 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
141140
"The labels should be a 2-D tensor.");
142141

143142
const int num_samples = ctx->Attrs().Get<int>("num_samples");
144-
const int num_sampled_classes = labels_dims[1] + num_samples;
143+
int num_sampled_classes = labels_dims[1] + num_samples;
144+
if ((!ctx->IsRuntime()) && labels_dims[1] <= 0) {
145+
num_sampled_classes = -1;
146+
}
145147
ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes});
146148
ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes});
147149
ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes});

paddle/fluid/operators/smooth_l1_loss_op.cc

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/smooth_l1_loss_op.h"
16+
#include <memory>
1617

1718
namespace paddle {
1819
namespace operators {
@@ -27,15 +28,39 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
2728

2829
auto x_dims = ctx->GetInputDim("X");
2930
auto y_dims = ctx->GetInputDim("Y");
30-
PADDLE_ENFORCE_EQ(x_dims, y_dims);
31+
bool check = true;
32+
if ((!ctx->IsRuntime()) &&
33+
(framework::product(x_dims) <= 0 || framework::product(y_dims) <= 0)) {
34+
check = false;
35+
}
36+
if (check) {
37+
PADDLE_ENFORCE_EQ(x_dims, y_dims);
38+
}
3139
PADDLE_ENFORCE_GE(x_dims.size(), 2,
3240
"The tensor rank of Input(X) should not be less than 2.");
3341
if (ctx->HasInput("InsideWeight")) {
3442
PADDLE_ENFORCE(ctx->HasInput("OutsideWeight"),
3543
"If weights are provided, must specify both "
3644
"inside and outside weights.");
37-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("InsideWeight"), x_dims);
38-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("OutsideWeight"), x_dims);
45+
auto dims = ctx->GetInputDim("InsideWeight");
46+
bool check = true;
47+
if ((!ctx->IsRuntime()) &&
48+
(framework::product(dims) <= 0 || framework::product(x_dims) <= 0)) {
49+
check = false;
50+
}
51+
if (check) {
52+
PADDLE_ENFORCE_EQ(dims, x_dims);
53+
}
54+
55+
dims = ctx->GetInputDim("OutsideWeight");
56+
check = true;
57+
if ((!ctx->IsRuntime()) &&
58+
(framework::product(dims) <= 0 || framework::product(x_dims) <= 0)) {
59+
check = false;
60+
}
61+
if (check) {
62+
PADDLE_ENFORCE_EQ(dims, x_dims);
63+
}
3964
}
4065

4166
ctx->SetOutputDim("Diff", x_dims);
@@ -110,11 +135,11 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
110135

111136
PADDLE_ENFORCE_GE(out_dims.size(), 2,
112137
"The tensor rank of Input(Out@Grad) should be 2.");
113-
PADDLE_ENFORCE_EQ(out_dims[0], in_dims[0],
114-
"The 1st dimension of Input(Out@Grad) must be "
115-
"same as input.");
116-
PADDLE_ENFORCE_EQ(out_dims[1], 1,
117-
"The 2nd dimension of Input(Out@Grad) must be 1.");
138+
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, out_dims[0], in_dims[0],
139+
"The 1st dimension of Input(Out@Grad) must be "
140+
"same as input.");
141+
PADDLE_INFERSHAPE_ENFORCE_EQ(
142+
ctx, out_dims[1], 1, "The 2nd dimension of Input(Out@Grad) must be 1.");
118143

119144
auto x_grad_name = framework::GradVarName("X");
120145
auto y_grad_name = framework::GradVarName("Y");

paddle/fluid/operators/squared_l2_distance_op.cc

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,26 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
4141

4242
int rank = framework::arity(x_dims);
4343
PADDLE_ENFORCE_GE(rank, 2, "Tensor rank should be at least equal to 2.");
44-
PADDLE_ENFORCE_EQ(product(x_dims) / x_dims[0], product(y_dims) / y_dims[0],
45-
"Product of dimensions expcet the first dimension of "
46-
"input and target must be equal.");
47-
PADDLE_ENFORCE(y_dims[0] == 1 || y_dims[0] == x_dims[0],
48-
"First dimension of target must be equal to input "
49-
"or to 1.");
50-
44+
bool check = true;
45+
if ((!ctx->IsRuntime()) &&
46+
(framework::product(x_dims) <= 0 || framework::product(y_dims) <= 0)) {
47+
check = false;
48+
}
49+
if (check) {
50+
PADDLE_ENFORCE_EQ(product(x_dims) / x_dims[0],
51+
product(y_dims) / y_dims[0],
52+
"Product of dimensions expcet the first dimension of "
53+
"input and target must be equal.");
54+
}
55+
check = true;
56+
if ((!ctx->IsRuntime()) && (y_dims[0] <= 0 || x_dims[0] <= 0)) {
57+
check = false;
58+
}
59+
if (check) {
60+
PADDLE_ENFORCE(y_dims[0] == 1 || y_dims[0] == x_dims[0],
61+
"First dimension of target must be equal to input "
62+
"or to 1.");
63+
}
5164
ctx->SetOutputDim("sub_result", {x_dims[0], product(x_dims) / x_dims[0]});
5265
ctx->SetOutputDim("Out", {x_dims[0], 1});
5366
ctx->ShareLoD("X", /*->*/ "Out");
@@ -91,12 +104,12 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
91104
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
92105
auto x_dims = ctx->GetInputDim("X");
93106
auto y_dims = ctx->GetInputDim("Y");
94-
PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0],
95-
"First dimension of output gradient and "
96-
"input value must be equal.");
97-
PADDLE_ENFORCE_EQ(out_dims[1], 1,
98-
"Second dimension of output gradient "
99-
"must be 1.");
107+
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, out_dims[0], x_dims[0],
108+
"First dimension of output gradient and "
109+
"input value must be equal.");
110+
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, out_dims[1], 1,
111+
"Second dimension of output gradient "
112+
"must be 1.");
100113
auto x_grad_name = framework::GradVarName("X");
101114
auto y_grad_name = framework::GradVarName("Y");
102115
if (ctx->HasOutput(x_grad_name)) ctx->SetOutputDim(x_grad_name, x_dims);

paddle/fluid/platform/enforce.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,5 +356,46 @@ using CommonType2 = typename std::add_lvalue_reference<
356356
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
357357
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
358358

359+
#define __PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL1, __VAL2, __CMP, \
360+
__INV_CMP, ...) \
361+
do { \
362+
auto __val1 = (__VAL1); \
363+
auto __val2 = (__VAL2); \
364+
if (!__CTX->IsRuntime()) { \
365+
if (__val1 == -1 || __val2 == -1) { \
366+
break; \
367+
} \
368+
} \
369+
using __TYPE1__ = decltype(__val1); \
370+
using __TYPE2__ = decltype(__val2); \
371+
using __COMMON_TYPE1__ = \
372+
::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \
373+
using __COMMON_TYPE2__ = \
374+
::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \
375+
bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \
376+
static_cast<__COMMON_TYPE2__>(__val2)); \
377+
if (UNLIKELY(!__is_not_error)) { \
378+
PADDLE_THROW("Enforce failed. Expected %s " #__CMP \
379+
" %s, but received %s:%s " #__INV_CMP " %s:%s.\n%s", \
380+
#__VAL1, #__VAL2, #__VAL1, \
381+
::paddle::string::to_string(__val1), #__VAL2, \
382+
::paddle::string::to_string(__val2), \
383+
::paddle::string::Sprintf(__VA_ARGS__)); \
384+
} \
385+
} while (0)
386+
387+
#define PADDLE_INFERSHAPE_ENFORCE_EQ(__CTX, __VAL0, __VAL1, ...) \
388+
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, ==, !=, __VA_ARGS__)
389+
#define PADDLE_INFERSHAPE_ENFORCE_NE(__CTX, __VAL0, __VAL1, ...) \
390+
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, !=, ==, __VA_ARGS__)
391+
#define PADDLE_INFERSHAPE_ENFORCE_GT(__CTX, __VAL0, __VAL1, ...) \
392+
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, >, <=, __VA_ARGS__)
393+
#define PADDLE_INFERSHAPE_ENFORCE_GE(__CTX, __VAL0, __VAL1, ...) \
394+
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, >=, <, __VA_ARGS__)
395+
#define PADDLE_INFERSHAPE_ENFORCE_LT(__CTX, __VAL0, __VAL1, ...) \
396+
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, <, >=, __VA_ARGS__)
397+
#define PADDLE_INFERSHAPE_ENFORCE_LE(__CTX, __VAL0, __VAL1, ...) \
398+
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, <=, >, __VA_ARGS__)
399+
359400
} // namespace platform
360401
} // namespace paddle

0 commit comments

Comments
 (0)