File tree Expand file tree Collapse file tree 5 files changed +31
-12
lines changed Expand file tree Collapse file tree 5 files changed +31
-12
lines changed Original file line number Diff line number Diff line change @@ -79,9 +79,13 @@ class AffineChannelOp : public framework::OperatorWithKernel {
79
79
: x_dims[x_dims.size () - 1 ]);
80
80
81
81
PADDLE_ENFORCE_EQ (scale_dims.size (), 1UL );
82
- PADDLE_ENFORCE_EQ (scale_dims[0 ], C);
83
82
PADDLE_ENFORCE_EQ (b_dims.size (), 1UL );
84
- PADDLE_ENFORCE_EQ (b_dims[0 ], C);
83
+ if (ctx->IsRuntime () || scale_dims[0 ] > 0 ) {
84
+ PADDLE_ENFORCE_EQ (scale_dims[0 ], C);
85
+ }
86
+ if (ctx->IsRuntime () || b_dims[0 ] > 0 ) {
87
+ PADDLE_ENFORCE_EQ (b_dims[0 ], C);
88
+ }
85
89
86
90
ctx->SetOutputDim (" Out" , ctx->GetInputDim (" X" ));
87
91
ctx->ShareLoD (" X" , " Out" );
Original file line number Diff line number Diff line change @@ -68,9 +68,14 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
68
68
69
69
std::vector<int64_t > output_shape ({in_dims[0 ], filter_dims[0 ]});
70
70
for (size_t i = 0 ; i < strides.size (); ++i) {
71
- output_shape.push_back (ConvOutputSize (in_dims[i + 2 ], filter_dims[i + 2 ],
72
- dilations[i], paddings[i],
73
- strides[i]));
71
+ if ((!ctx->IsRuntime ()) &&
72
+ (in_dims[i + 2 ] <= 0 || filter_dims[i + 2 ] <= 0 )) {
73
+ output_shape.push_back (-1 );
74
+ } else {
75
+ output_shape.push_back (ConvOutputSize (in_dims[i + 2 ], filter_dims[i + 2 ],
76
+ dilations[i], paddings[i],
77
+ strides[i]));
78
+ }
74
79
}
75
80
ctx->SetOutputDim (" Output" , framework::make_ddim (output_shape));
76
81
ctx->ShareLoD (" Input" , " Output" );
Original file line number Diff line number Diff line change @@ -51,8 +51,10 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
51
51
PADDLE_ENFORCE_EQ (label_dims.size (), 2 ,
52
52
" The rank of Input(Label) must be 2, "
53
53
" the shape is [N, 6]." );
54
- PADDLE_ENFORCE (label_dims[1 ] == 6 || label_dims[1 ] == 5 ,
55
- " The shape of Input(Label) is [N, 6] or [N, 5]." );
54
+ if (ctx->IsRuntime () || label_dims[1 ] > 0 ) {
55
+ PADDLE_ENFORCE (label_dims[1 ] == 6 || label_dims[1 ] == 5 ,
56
+ " The shape of Input(Label) is [N, 6] or [N, 5]." );
57
+ }
56
58
57
59
if (ctx->HasInput (" PosCount" )) {
58
60
PADDLE_ENFORCE (ctx->HasInput (" TruePos" ),
Original file line number Diff line number Diff line change @@ -41,9 +41,12 @@ class RowConvOp : public framework::OperatorWithKernel {
41
41
auto filter_dims = ctx->GetInputDim (" Filter" );
42
42
PADDLE_ENFORCE_EQ (x_dims.size (), 2 , " Input(X)'s rank should be 2." );
43
43
PADDLE_ENFORCE_EQ (filter_dims.size (), 2 , " Input(Y)'s rank should be 2." );
44
- PADDLE_ENFORCE_EQ (
45
- x_dims[1 ], filter_dims[1 ],
46
- " The 2nd dimension of Input(X) and Input(Filter) should be same." );
44
+ if (ctx->IsRuntime () || (x_dims[1 ] > 0 && filter_dims[1 ] > 0 )) {
45
+ PADDLE_ENFORCE_EQ (
46
+ x_dims[1 ], filter_dims[1 ],
47
+ " The 2nd dimension of Input(X) and Input(Filter) should be same." );
48
+ }
49
+
47
50
ctx->SetOutputDim (" Out" , x_dims);
48
51
ctx->ShareLoD (" X" , " Out" );
49
52
}
Original file line number Diff line number Diff line change @@ -99,10 +99,15 @@ class UnpoolOp : public framework::OperatorWithKernel {
99
99
PADDLE_ENFORCE (in_x_dims.size () == 4 ,
100
100
" Unpooling intput must be of 4-dimensional." );
101
101
PADDLE_ENFORCE_EQ (in_x_dims, in_y_dims);
102
+
102
103
std::vector<int64_t > output_shape ({in_x_dims[0 ], in_x_dims[1 ]});
103
104
for (size_t i = 0 ; i < ksize.size (); ++i) {
104
- output_shape.push_back (UnpoolOutputSize (in_x_dims[i + 2 ], ksize[i],
105
- paddings[i], strides[i]));
105
+ if (!ctx->IsRuntime () && in_x_dims[i + 2 ] <= 0 ) {
106
+ output_shape.push_back (-1 );
107
+ } else {
108
+ output_shape.push_back (UnpoolOutputSize (in_x_dims[i + 2 ], ksize[i],
109
+ paddings[i], strides[i]));
110
+ }
106
111
}
107
112
ctx->SetOutputDim (" Out" , framework::make_ddim (output_shape));
108
113
}
You can’t perform that action at this time.
0 commit comments