@@ -36,7 +36,11 @@ class PadOp : public framework::OperatorWithKernel {
36
36
" of input tensor." );
37
37
std::vector<int64_t > out_dims (x_dim.size ());
38
38
for (int i = 0 ; i < x_dim.size (); ++i) {
39
- out_dims[i] = x_dim[i] + paddings[i * 2 ] + paddings[i * 2 + 1 ];
39
+ if ((!ctx->IsRuntime ()) && (x_dim[i] == -1 )) {
40
+ out_dims[i] = -1 ;
41
+ } else {
42
+ out_dims[i] = x_dim[i] + paddings[i * 2 ] + paddings[i * 2 + 1 ];
43
+ }
40
44
}
41
45
ctx->SetOutputDim (" Out" , framework::make_ddim (out_dims));
42
46
if (out_dims[0 ] == x_dim[0 ]) {
@@ -100,18 +104,14 @@ class PadOpGrad : public framework::OperatorWithKernel {
100
104
using framework::OperatorWithKernel::OperatorWithKernel;
101
105
102
106
void InferShape (framework::InferShapeContext* ctx) const override {
103
- auto dout_dims = ctx->GetInputDim (framework::GradVarName (" Out" ));
104
- auto & paddings = ctx->Attrs ().Get <std::vector<int >>(" paddings" );
105
- for (int i = 0 ; i < dout_dims.size (); ++i) {
106
- dout_dims[i] -= (paddings[i * 2 ] + paddings[i * 2 + 1 ]);
107
- }
108
-
109
107
auto x_grad_name = framework::GradVarName (" X" );
110
108
if (ctx->HasOutput (x_grad_name)) {
111
109
auto dout_dims = ctx->GetInputDim (framework::GradVarName (" Out" ));
112
110
auto & paddings = ctx->Attrs ().Get <std::vector<int >>(" paddings" );
113
111
for (int i = 0 ; i < dout_dims.size (); ++i) {
114
- dout_dims[i] -= (paddings[i * 2 ] + paddings[i * 2 + 1 ]);
112
+ if (ctx->IsRuntime () || (dout_dims[i] != -1 )) {
113
+ dout_dims[i] -= (paddings[i * 2 ] + paddings[i * 2 + 1 ]);
114
+ }
115
115
}
116
116
ctx->SetOutputDim (x_grad_name, dout_dims);
117
117
}
0 commit comments