@@ -29,6 +29,8 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
29
29
30
30
auto in_dims = ctx->GetInputDim (" Input" );
31
31
auto filter_dims = ctx->GetInputDim (" Filter" );
32
+ std::vector<int > output_size =
33
+ ctx->Attrs ().Get <std::vector<int >>(" output_size" );
32
34
std::vector<int > strides = ctx->Attrs ().Get <std::vector<int >>(" strides" );
33
35
std::vector<int > paddings = ctx->Attrs ().Get <std::vector<int >>(" paddings" );
34
36
std::vector<int > dilations = ctx->Attrs ().Get <std::vector<int >>(" dilations" );
@@ -42,6 +44,10 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
42
44
PADDLE_ENFORCE (in_dims.size () - strides.size () == 2U ,
43
45
" ConvTransposeOp input dimension and strides dimension should "
44
46
" be consistent." );
47
+ if (output_size.size ())
48
+ PADDLE_ENFORCE_EQ (output_size.size (), strides.size (),
49
+ " ConvTransposeOp output_size dimension and strides "
50
+ " dimension should be the same." );
45
51
PADDLE_ENFORCE_EQ (paddings.size (), strides.size (),
46
52
" ConvTransposeOp paddings dimension and strides "
47
53
" dimension should be the same." );
@@ -55,8 +61,17 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
55
61
std::vector<int64_t > output_shape ({in_dims[0 ], filter_dims[1 ] * groups});
56
62
for (size_t i = 0 ; i < strides.size (); ++i) {
57
63
auto filter_extent = dilations[i] * (filter_dims[i + 2 ] - 1 ) + 1 ;
58
- output_shape.push_back ((in_dims[i + 2 ] - 1 ) * strides[i] - 2 * paddings[i] +
59
- filter_extent);
64
+ auto infer_shape =
65
+ (in_dims[i + 2 ] - 1 ) * strides[i] - 2 * paddings[i] + filter_extent;
66
+ if (output_size.size ()) {
67
+ PADDLE_ENFORCE ((output_size[i] >= infer_shape &&
68
+ output_size[i] < infer_shape + strides[i]),
69
+ " ConvTransposeOp output_size should be "
70
+ " in appropriate range." );
71
+ output_shape.push_back (output_size[i]);
72
+ } else {
73
+ output_shape.push_back (infer_shape);
74
+ }
60
75
}
61
76
ctx->SetOutputDim (" Output" , framework::make_ddim (output_shape));
62
77
}
@@ -103,6 +118,10 @@ void Conv2DTransposeOpMaker::Make() {
103
118
AddOutput (" Output" ,
104
119
" (Tensor) The output tensor of convolution transpose operator. "
105
120
" The format of output tensor is also NCHW." );
121
+ AddAttr<std::vector<int >>(" output_size" ,
122
+ " (vector<int> default: []), the "
123
+ " size of the output tensor" )
124
+ .SetDefault ({});
106
125
AddAttr<int >(" groups" ,
107
126
" (int default:1), the groups number of the convolution "
108
127
" transpose operator. " )
@@ -192,7 +211,10 @@ void Conv3DTransposeOpMaker::Make() {
192
211
" Where N is batch size, C is "
193
212
" the number of channels, D is the depth of the feature, H is the "
194
213
" height of the feature, and W is the width of the feature." );
195
-
214
+ AddAttr<std::vector<int >>(" output_size" ,
215
+ " (vector<int> default: []), the "
216
+ " size of the output tensor" )
217
+ .SetDefault ({});
196
218
AddAttr<std::vector<int >>(
197
219
" dilations" ,
198
220
" (vector<int> default:{1, 1, 1}), the "
@@ -247,7 +269,7 @@ Parameters(strides, paddings) are three elements. These three elements represent
247
269
depth, height and width, respectively.
248
270
The input(X) size and output(Out) size may be different.
249
271
250
- Example:
272
+ Example:
251
273
Input:
252
274
Input shape: $(N, C_{in}, D_{in}, H_{in}, W_{in})$
253
275
Filter shape: $(C_{in}, C_{out}, D_f, H_f, W_f)$
0 commit comments