Skip to content

Commit f8c764c

Browse files
committed
fix batch_norm and cos_sim infer shape, test=develop
1 parent a5ef6bf commit f8c764c

File tree

2 files changed

+35
-15
lines changed

2 files changed

+35
-15
lines changed

paddle/fluid/operators/batch_norm_op.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,21 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
6565
(data_layout == DataLayout::kNCHW ? x_dims[1]
6666
: x_dims[x_dims.size() - 1]);
6767

68-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
69-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C);
70-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
71-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], C);
68+
auto scale_dim = ctx->GetInputDim("Scale");
69+
auto bias_dim = ctx->GetInputDim("Bias");
7270

71+
bool check = true;
72+
if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 ||
73+
framework::product(bias_dim) <= 0)) {
74+
check = false;
75+
}
76+
77+
if (check) {
78+
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
79+
PADDLE_ENFORCE_EQ(scale_dim[0], C);
80+
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
81+
PADDLE_ENFORCE_EQ(scale_dim[0], C);
82+
}
7383
ctx->SetOutputDim("Y", x_dims);
7484
ctx->SetOutputDim("MeanOut", {C});
7585
ctx->SetOutputDim("VarianceOut", {C});

paddle/fluid/operators/cos_sim_op.cc

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,27 @@ class CosSimOp : public framework::OperatorWithKernel {
4040
auto x_dims = ctx->GetInputDim("X");
4141
auto y_dims = ctx->GetInputDim("Y");
4242

43-
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
44-
"Ranks of Input(X) and Input(Y) must be equal.");
45-
PADDLE_ENFORCE_GE(x_dims.size(), 2,
46-
"Rank of Input(X) must not be less than 2.");
47-
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()),
48-
framework::slice_ddim(y_dims, 1, y_dims.size()),
49-
"All dimensions except the 1st of Input(X) and Input(Y) "
50-
"must be equal.");
51-
PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
52-
"The 1st dimension of Input(Y) must be equal to Input(X) or"
53-
" just 1 (which will be broadcasted to match Input(X)).");
43+
bool check = true;
44+
if ((!ctx->IsRuntime()) &&
45+
(framework::product(x_dims) <= 0 || framework::product(y_dims) <= 0)) {
46+
check = false;
47+
}
48+
49+
if (check) {
50+
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
51+
"Ranks of Input(X) and Input(Y) must be equal.");
52+
PADDLE_ENFORCE_GE(x_dims.size(), 2,
53+
"Rank of Input(X) must not be less than 2.");
54+
PADDLE_ENFORCE_EQ(
55+
framework::slice_ddim(x_dims, 1, x_dims.size()),
56+
framework::slice_ddim(y_dims, 1, y_dims.size()),
57+
"All dimensions except the 1st of Input(X) and Input(Y) "
58+
"must be equal.");
59+
PADDLE_ENFORCE(
60+
x_dims[0] == y_dims[0] || y_dims[0] == 1,
61+
"The 1st dimension of Input(Y) must be equal to Input(X) or"
62+
" just 1 (which will be broadcasted to match Input(X)).");
63+
}
5464

5565
// resize tensor
5666
ctx->SetOutputDim("Out", {x_dims[0], 1});

0 commit comments

Comments
 (0)