Skip to content

Commit 70a967d

Browse files
committed
infer shape compatable -1. test=develop
1 parent cf5af3b commit 70a967d

File tree

4 files changed

+30
-14
lines changed

4 files changed

+30
-14
lines changed

paddle/fluid/operators/grid_sampler_op.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/grid_sampler_op.h"
16+
#include <memory>
1617
#include "paddle/fluid/framework/op_registry.h"
1718
#ifdef PADDLE_WITH_CUDA
1819
#include "paddle/fluid/platform/cudnn_helper.h"
@@ -40,10 +41,12 @@ class GridSampleOp : public framework::OperatorWithKernel {
4041
"Input(X) of GridSampleOp should be 4-D Tensor.");
4142
PADDLE_ENFORCE(grid_dims.size() == 4,
4243
"Input(Grid) of GridSampleOp should be 4-D Tensor.");
43-
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
44-
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
45-
"Input(X) and Input(Grid) dims[0] should be equal.");
44+
if (ctx->IsRuntime() || grid_dims[3] > 0) {
45+
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
46+
}
4647
if (ctx->IsRuntime()) {
48+
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
49+
"Input(X) and Input(Grid) dims[0] should be equal.");
4750
PADDLE_ENFORCE_EQ(
4851
grid_dims[1], x_dims[2],
4952
"Input(X) dims[2] and Input(Grid) dims[1] should be equal.");

paddle/fluid/operators/interpolate_op.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,13 @@ class InterpolateOp : public framework::OperatorWithKernel {
4949
ctx->ShareLoD("X", "Out");
5050
return;
5151
}
52-
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
53-
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
52+
53+
if (ctx->IsRuntime() || (out_h > 0 && out_w > 0)) {
54+
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
55+
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
56+
} else {
57+
ctx->SetOutputDim("Out", dim_x);
58+
}
5459
}
5560

5661
protected:

paddle/fluid/operators/kldiv_loss_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ class KLDivLossOp : public framework::OperatorWithKernel {
3535
PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
3636
"Input(X) rank and Input(Target) rank should be same.");
3737
for (int i = 0; i < dim_x.size(); i++) {
38-
PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i],
39-
"Input(X) and Input(Target) should in same shape.");
38+
if (ctx->IsRuntime() || (dim_x[i] > 0 && dim_target[i] > 0)) {
39+
PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i],
40+
"Input(X) and Input(Target) should in same shape.");
41+
}
4042
}
4143

4244
auto reduction = ctx->Attrs().Get<std::string>("reduction");

paddle/fluid/operators/spectral_norm_op.cc

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,19 @@ class SpectralNormOp : public framework::OperatorWithKernel {
5656
}
5757
auto dim_u = ctx->GetInputDim("U");
5858
auto dim_v = ctx->GetInputDim("V");
59-
PADDLE_ENFORCE_EQ(dim_u[0], h,
60-
"Input(U) dims[0] should be equal to "
61-
"Input(Weight) dims[Attr(dim)]");
62-
PADDLE_ENFORCE_EQ(
63-
dim_v[0], w,
64-
"Input(V) dims[0] should be equal to "
65-
"the product of Input(Weight) dims except dims[Attr(dim)]");
59+
60+
if (ctx->IsRuntime() || (dim_u[0] > 0 && h > 0)) {
61+
PADDLE_ENFORCE_EQ(dim_u[0], h,
62+
"Input(U) dims[0] should be equal to "
63+
"Input(Weight) dims[Attr(dim)]");
64+
}
65+
66+
if (ctx->IsRuntime() || (dim_v[0] > 0 && w > 0)) {
67+
PADDLE_ENFORCE_EQ(
68+
dim_v[0], w,
69+
"Input(V) dims[0] should be equal to "
70+
"the product of Input(Weight) dims except dims[Attr(dim)]");
71+
}
6672

6773
ctx->SetOutputDim("Out", dim_weight);
6874
ctx->ShareLoD("Weight", /*->*/ "Out");

0 commit comments

Comments
 (0)