File tree Expand file tree Collapse file tree 5 files changed +12
-23
lines changed Expand file tree Collapse file tree 5 files changed +12
-23
lines changed Original file line number Diff line number Diff line change @@ -80,14 +80,11 @@ class AffineChannelOp : public framework::OperatorWithKernel {
80
80
81
81
PADDLE_ENFORCE_EQ (scale_dims.size (), 1UL );
82
82
PADDLE_ENFORCE_EQ (b_dims.size (), 1UL );
83
- if (ctx->IsRuntime ()) {
83
+ if (ctx->IsRuntime () || scale_dims[ 0 ] > 0 ) {
84
84
PADDLE_ENFORCE_EQ (scale_dims[0 ], C);
85
+ }
86
+ if (ctx->IsRuntime () || b_dims[0 ] > 0 ) {
85
87
PADDLE_ENFORCE_EQ (b_dims[0 ], C);
86
- } else {
87
- if (scale_dims[0 ] > 0 && b_dims[0 ] > 0 ) {
88
- PADDLE_ENFORCE_EQ (scale_dims[0 ], C);
89
- PADDLE_ENFORCE_EQ (b_dims[0 ], C);
90
- }
91
88
}
92
89
93
90
ctx->SetOutputDim (" Out" , ctx->GetInputDim (" X" ));
Original file line number Diff line number Diff line change @@ -69,7 +69,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
69
69
std::vector<int64_t > output_shape ({in_dims[0 ], filter_dims[0 ]});
70
70
for (size_t i = 0 ; i < strides.size (); ++i) {
71
71
if ((!ctx->IsRuntime ()) &&
72
- (in_dims[i + 2 ] == - 1 || filter_dims[i + 2 ] == - 1 )) {
72
+ (in_dims[i + 2 ] <= 0 || filter_dims[i + 2 ] <= 0 )) {
73
73
output_shape.push_back (-1 );
74
74
} else {
75
75
output_shape.push_back (ConvOutputSize (in_dims[i + 2 ], filter_dims[i + 2 ],
Original file line number Diff line number Diff line change @@ -50,14 +50,12 @@ class ROIPoolOp : public framework::OperatorWithKernel {
50
50
int pooled_width = ctx->Attrs ().Get <int >(" pooled_width" );
51
51
float spatial_scale = ctx->Attrs ().Get <float >(" spatial_scale" );
52
52
53
- if (ctx->IsRuntime ()) {
54
- PADDLE_ENFORCE_GT (pooled_height, 0 ,
55
- " The pooled output height must greater than 0" );
56
- PADDLE_ENFORCE_GT (pooled_width, 0 ,
57
- " The pooled output width must greater than 0" );
58
- PADDLE_ENFORCE_GT (spatial_scale, 0 .0f ,
59
- " The spatial scale must greater than 0" );
60
- }
53
+ PADDLE_ENFORCE_GT (pooled_height, 0 ,
54
+ " The pooled output height must greater than 0" );
55
+ PADDLE_ENFORCE_GT (pooled_width, 0 ,
56
+ " The pooled output width must greater than 0" );
57
+ PADDLE_ENFORCE_GT (spatial_scale, 0 .0f ,
58
+ " The spatial scale must greater than 0" );
61
59
62
60
auto out_dims = input_dims;
63
61
out_dims[0 ] = rois_dims[0 ];
Original file line number Diff line number Diff line change @@ -41,16 +41,10 @@ class RowConvOp : public framework::OperatorWithKernel {
41
41
auto filter_dims = ctx->GetInputDim (" Filter" );
42
42
PADDLE_ENFORCE_EQ (x_dims.size (), 2 , " Input(X)'s rank should be 2." );
43
43
PADDLE_ENFORCE_EQ (filter_dims.size (), 2 , " Input(Y)'s rank should be 2." );
44
- if (ctx->IsRuntime ()) {
44
+ if (ctx->IsRuntime () || (x_dims[ 1 ] > 0 && filter_dims[ 1 ] > 0 ) ) {
45
45
PADDLE_ENFORCE_EQ (
46
46
x_dims[1 ], filter_dims[1 ],
47
47
" The 2nd dimension of Input(X) and Input(Filter) should be same." );
48
- } else {
49
- if (x_dims[1 ] > 0 && filter_dims[1 ] > 0 ) {
50
- PADDLE_ENFORCE_EQ (
51
- x_dims[1 ], filter_dims[1 ],
52
- " The 2nd dimension of Input(X) and Input(Filter) should be same." );
53
- }
54
48
}
55
49
56
50
ctx->SetOutputDim (" Out" , x_dims);
Original file line number Diff line number Diff line change @@ -102,7 +102,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
102
102
103
103
std::vector<int64_t > output_shape ({in_x_dims[0 ], in_x_dims[1 ]});
104
104
for (size_t i = 0 ; i < ksize.size (); ++i) {
105
- if (!ctx->IsRuntime () && in_x_dims[i + 2 ] == - 1 ) {
105
+ if (!ctx->IsRuntime () && in_x_dims[i + 2 ] <= 0 ) {
106
106
output_shape.push_back (-1 );
107
107
} else {
108
108
output_shape.push_back (UnpoolOutputSize (in_x_dims[i + 2 ], ksize[i],
You can’t perform that action at this time.
0 commit comments