Skip to content

Commit 9845b22

Browse files
authored
Merge pull request #16877 from heavengate/fix_infer_shape_pick
[cherry-pick] infer shape: grid_sampler, kldiv_loss, spectral_norm, interpolate
2 parents 2ff867f + 534beb5 commit 9845b22

File tree

4 files changed

+26
-12
lines changed

4 files changed

+26
-12
lines changed

paddle/fluid/operators/grid_sampler_op.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/grid_sampler_op.h"
16+
#include <memory>
1617
#include "paddle/fluid/framework/op_registry.h"
1718
#ifdef PADDLE_WITH_CUDA
1819
#include "paddle/fluid/platform/cudnn_helper.h"
@@ -40,10 +41,12 @@ class GridSampleOp : public framework::OperatorWithKernel {
4041
"Input(X) of GridSampleOp should be 4-D Tensor.");
4142
PADDLE_ENFORCE(grid_dims.size() == 4,
4243
"Input(Grid) of GridSampleOp should be 4-D Tensor.");
43-
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
44-
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
45-
"Input(X) and Input(Grid) dims[0] should be equal.");
44+
if (ctx->IsRuntime() || grid_dims[3] > 0) {
45+
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
46+
}
4647
if (ctx->IsRuntime()) {
48+
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
49+
"Input(X) and Input(Grid) dims[0] should be equal.");
4750
PADDLE_ENFORCE_EQ(
4851
grid_dims[1], x_dims[2],
4952
"Input(X) dims[2] and Input(Grid) dims[1] should be equal.");

paddle/fluid/operators/interpolate_op.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
4040
int out_h = ctx->Attrs().Get<int>("out_h");
4141
int out_w = ctx->Attrs().Get<int>("out_w");
4242
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");
43+
PADDLE_ENFORCE_GT(out_h, 0, "out_h should be greater than 0.");
44+
PADDLE_ENFORCE_GT(out_w, 0, "out_w should be greater than 0.");
4345

4446
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
4547
auto out_size_dim = ctx->GetInputDim("OutSize");
@@ -49,6 +51,7 @@ class InterpolateOp : public framework::OperatorWithKernel {
4951
ctx->ShareLoD("X", "Out");
5052
return;
5153
}
54+
5255
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
5356
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
5457
}

paddle/fluid/operators/kldiv_loss_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ class KLDivLossOp : public framework::OperatorWithKernel {
3535
PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
3636
"Input(X) rank and Input(Target) rank should be same.");
3737
for (int i = 0; i < dim_x.size(); i++) {
38-
PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i],
39-
"Input(X) and Input(Target) should in same shape.");
38+
if (ctx->IsRuntime() || (dim_x[i] > 0 && dim_target[i] > 0)) {
39+
PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i],
40+
"Input(X) and Input(Target) should in same shape.");
41+
}
4042
}
4143

4244
auto reduction = ctx->Attrs().Get<std::string>("reduction");

paddle/fluid/operators/spectral_norm_op.cc

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,19 @@ class SpectralNormOp : public framework::OperatorWithKernel {
5656
}
5757
auto dim_u = ctx->GetInputDim("U");
5858
auto dim_v = ctx->GetInputDim("V");
59-
PADDLE_ENFORCE_EQ(dim_u[0], h,
60-
"Input(U) dims[0] should be equal to "
61-
"Input(Weight) dims[Attr(dim)]");
62-
PADDLE_ENFORCE_EQ(
63-
dim_v[0], w,
64-
"Input(V) dims[0] should be equal to "
65-
"the product of Input(Weight) dims except dims[Attr(dim)]");
59+
60+
if (ctx->IsRuntime() || (dim_u[0] > 0 && h > 0)) {
61+
PADDLE_ENFORCE_EQ(dim_u[0], h,
62+
"Input(U) dims[0] should be equal to "
63+
"Input(Weight) dims[Attr(dim)]");
64+
}
65+
66+
if (ctx->IsRuntime() || (dim_v[0] > 0 && w > 0)) {
67+
PADDLE_ENFORCE_EQ(
68+
dim_v[0], w,
69+
"Input(V) dims[0] should be equal to "
70+
"the product of Input(Weight) dims except dims[Attr(dim)]");
71+
}
6672

6773
ctx->SetOutputDim("Out", dim_weight);
6874
ctx->ShareLoD("Weight", /*->*/ "Out");

0 commit comments

Comments
 (0)