@@ -27,11 +27,24 @@ using Tensor = framework::Tensor;
27
27
28
28
// Base convolution operator definations for other conv
29
29
// like operators to reuse the implementation.
30
- inline int OutputSize (int input_size, int filter_size, int padding,
31
- int stride) {
32
- int output_size = (input_size - filter_size + 2 * padding) / stride + 1 ;
30
+ inline int OutputSize (int input_size, int filter_size, int dilation,
31
+ int padding, int stride) {
32
+ const int dkernel = dilation * (filter_size - 1 ) + 1 ;
33
+ const int output_size = (input_size + 2 * padding - dkernel) / stride + 1 ;
33
34
return output_size;
34
35
}
36
+ inline bool IsExpand (std::vector<int64_t >& filter_dim,
37
+ std::vector<int >& strides, std::vector<int >& paddings,
38
+ std::vector<int >& dilations) {
39
+ bool filter_1 = true , strides_1 = true , padding_0 = true , dilation_1 = true ;
40
+ for (size_t j = 0 ; j < strides.size (); ++j) {
41
+ filter_1 = filter_1 && (static_cast <int >(filter_dim[j]) == 1 );
42
+ strides_1 = strides_1 && (strides[j] == 1 );
43
+ padding_0 = padding_0 && (paddings[j] == 0 );
44
+ dilation_1 = dilation_1 && (dilations[j] == 1 );
45
+ }
46
+ return !(filter_1 && strides_1 && padding_0 && dilation_1);
47
+ }
35
48
36
49
// Define Op classes in .h file so that other conv
37
50
// operator implementations can reuse the code.
@@ -50,14 +63,12 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
50
63
class ConvOp : public framework ::OperatorWithKernel {
51
64
public:
52
65
using framework::OperatorWithKernel::OperatorWithKernel;
53
-
54
66
void InferShape (framework::InferShapeContext* ctx) const override ;
55
67
};
56
68
57
69
class ConvOpGrad : public framework ::OperatorWithKernel {
58
70
public:
59
71
using framework::OperatorWithKernel::OperatorWithKernel;
60
-
61
72
void InferShape (framework::InferShapeContext* ctx) const override ;
62
73
};
63
74
@@ -73,9 +84,10 @@ class GemmConvKernel : public framework::OpKernel<T> {
73
84
Tensor* output = context.Output <Tensor>(" Output" );
74
85
output->mutable_data <T>(context.GetPlace ());
75
86
87
+ int groups = context.Attr <int >(" groups" );
76
88
std::vector<int > strides = context.Attr <std::vector<int >>(" strides" );
77
89
std::vector<int > paddings = context.Attr <std::vector<int >>(" paddings" );
78
- int groups = context.Attr <int >( " groups " );
90
+ std::vector< int > dilations = context.Attr <std::vector< int >>( " dilations " );
79
91
80
92
const int batch_size = static_cast <int >(input->dims ()[0 ]);
81
93
@@ -106,14 +118,17 @@ class GemmConvKernel : public framework::OpKernel<T> {
106
118
framework::DDim col_matrix_shape =
107
119
framework::flatten_to_2d (col_shape, filter_shape_vec.size () + 1 );
108
120
121
+ bool is_expand = IsExpand (filter_shape_vec, strides, paddings, dilations);
109
122
Tensor col;
110
- col.mutable_data <T>(col_shape, context.GetPlace ());
111
123
// col_matrix shares the same piece of data with col,
112
124
// but will be reshaped into a two-dimensional matrix shape
113
125
// to call the matrix multiplication interface.
114
126
Tensor col_matrix;
115
- col_matrix.ShareDataWith (col);
116
- col_matrix.Resize (col_matrix_shape);
127
+ if (is_expand) {
128
+ col.mutable_data <T>(col_shape, context.GetPlace ());
129
+ col_matrix.ShareDataWith (col);
130
+ col_matrix.Resize (col_matrix_shape);
131
+ }
117
132
118
133
framework::DDim input_shape = framework::slice_ddim (
119
134
input->dims (), 1 , static_cast <int >(input->dims ().size ()));
@@ -130,24 +145,30 @@ class GemmConvKernel : public framework::OpKernel<T> {
130
145
int in_step = static_cast <int >(input->dims ()[1 ]) / groups;
131
146
int out_step = static_cast <int >(output->dims ()[1 ]) / groups;
132
147
148
+ math::Vol2ColFunctor<Place, T> vol2col;
149
+ math::Im2ColFunctor<math::ColFormat::kCFO , Place, T> im2col;
150
+
133
151
for (int i = 0 ; i < batch_size; i++) {
134
152
Tensor in_batch = input->Slice (i, i + 1 ).Resize (input_shape);
135
153
Tensor out_batch = output->Slice (i, i + 1 ).Resize (output_matrix_shape);
154
+
136
155
for (int g = 0 ; g < groups; g++) {
137
156
Tensor in_slice = in_batch.Slice (g * in_step, (g + 1 ) * in_step);
138
157
139
- if (filter_shape_vec.size () == 2 ) {
158
+ if (!is_expand) {
159
+ col.ShareDataWith (in_slice);
160
+ col_matrix.ShareDataWith (col);
161
+ col_matrix.Resize (col_matrix_shape);
162
+ } else if (filter_shape_vec.size () == 2 ) {
140
163
// im2col
141
- math::Im2ColFunctor<math::ColFormat:: kCFO , Place, T> im2col;
142
- im2col (context. device_context (), in_slice, col, strides [0 ],
143
- strides[ 1 ], paddings[ 0 ], paddings[ 0 ], paddings[1 ],
144
- paddings[ 1 ] );
164
+ im2col (context. device_context (), in_slice, dilations, strides,
165
+ std::vector< int >{paddings[ 0 ], paddings[ 1 ], paddings [0 ],
166
+ paddings[1 ]} ,
167
+ &col );
145
168
} else if (filter_shape_vec.size () == 3 ) {
146
169
// vol2col
147
- math::Vol2ColFunctor<Place, T> vol2col;
148
- vol2col (context.device_context (), in_slice, col, strides[0 ],
149
- strides[1 ], strides[2 ], paddings[0 ], paddings[1 ],
150
- paddings[2 ]);
170
+ vol2col (context.device_context (), in_slice, dilations, strides,
171
+ paddings, &col);
151
172
}
152
173
153
174
// gemm
@@ -178,9 +199,10 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
178
199
179
200
if (!input_grad && !filter_grad) return ;
180
201
202
+ int groups = context.Attr <int >(" groups" );
181
203
std::vector<int > strides = context.Attr <std::vector<int >>(" strides" );
182
204
std::vector<int > paddings = context.Attr <std::vector<int >>(" paddings" );
183
- int groups = context.Attr <int >( " groups " );
205
+ std::vector< int > dilations = context.Attr <std::vector< int >>( " dilations " );
184
206
185
207
const int batch_size = static_cast <int >(input->dims ()[0 ]);
186
208
@@ -230,21 +252,27 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
230
252
int in_step = static_cast <int >(input->dims ()[1 ]) / groups;
231
253
int out_step = static_cast <int >(output_grad->dims ()[1 ]) / groups;
232
254
255
+ bool is_expand = IsExpand (filter_shape_vec, strides, paddings, dilations);
233
256
Tensor col;
234
257
// col_matrix shares the same piece of data with col,
235
258
// but will be reshaped into a two-dimensional matrix shape
236
259
// to call the matrix multiplication interface.
237
260
Tensor col_matrix;
238
- col.mutable_data <T>(col_shape, context.GetPlace ());
239
- col_matrix.ShareDataWith (col);
240
- col_matrix.Resize (col_matrix_shape);
261
+ if (is_expand) {
262
+ col.mutable_data <T>(col_shape, context.GetPlace ());
263
+ col_matrix.ShareDataWith (col);
264
+ col_matrix.Resize (col_matrix_shape);
265
+ }
241
266
242
267
math::SetConstant<Place, T> set_zero;
243
268
244
269
if (input_grad) {
245
270
input_grad->mutable_data <T>(context.GetPlace ());
246
271
set_zero (context.device_context (), input_grad, static_cast <T>(0 ));
247
272
273
+ math::Col2VolFunctor<Place, T> col2vol;
274
+ math::Col2ImFunctor<math::ColFormat::kCFO , Place, T> col2im;
275
+
248
276
for (int i = 0 ; i < batch_size; i++) {
249
277
Tensor out_grad_batch =
250
278
output_grad->Slice (i, i + 1 ).Resize (output_matrix_shape);
@@ -254,24 +282,26 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
254
282
Tensor out_grad_slice =
255
283
out_grad_batch.Slice (g * out_step, (g + 1 ) * out_step);
256
284
Tensor filter_slice = filter.Slice (g * out_step, (g + 1 ) * out_step);
257
- math::matmul<Place, T>(context.device_context (), filter_slice, true ,
258
- out_grad_slice, false , T (1.0 ), &col_matrix,
259
- T (0.0 ));
260
- // col2im
285
+
261
286
Tensor in_grad_slice =
262
287
in_grad_batch.Slice (g * in_step, (g + 1 ) * in_step);
263
288
264
- if (filter_shape_vec.size () == 2 ) {
265
- math::Col2ImFunctor<math::ColFormat::kCFO , Place, T> col2im;
266
- col2im (context.device_context (), in_grad_slice, col, strides[0 ],
267
- strides[1 ], paddings[0 ], paddings[0 ], paddings[1 ],
268
- paddings[1 ]);
289
+ if (!is_expand) {
290
+ col_matrix.ShareDataWith (in_grad_slice);
291
+ col_matrix.Resize (col_matrix_shape);
292
+ }
293
+ math::matmul<Place, T>(context.device_context (), filter_slice, true ,
294
+ out_grad_slice, false , T (1.0 ), &col_matrix,
295
+ T (0.0 ));
269
296
270
- } else if (filter_shape_vec.size () == 3 ) {
271
- math::Col2VolFunctor<Place, T> col2vol;
272
- col2vol (context.device_context (), in_grad_slice, col, strides[0 ],
273
- strides[1 ], strides[2 ], paddings[0 ], paddings[1 ],
274
- paddings[2 ]);
297
+ if (is_expand && filter_shape_vec.size () == 2 ) {
298
+ col2im (context.device_context (), col, dilations, strides,
299
+ std::vector<int >{paddings[0 ], paddings[1 ], paddings[0 ],
300
+ paddings[1 ]},
301
+ &in_grad_slice);
302
+ } else if (is_expand && filter_shape_vec.size () == 3 ) {
303
+ col2vol (context.device_context (), col, dilations, strides, paddings,
304
+ &in_grad_slice);
275
305
}
276
306
}
277
307
}
@@ -282,7 +312,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
282
312
Tensor filter_grad_ = *filter_grad;
283
313
filter_grad_.Resize (filter_matrix_shape);
284
314
set_zero (context.device_context (), filter_grad, static_cast <T>(0 ));
285
-
315
+ math::Im2ColFunctor<math::ColFormat::kCFO , Place, T> im2col;
316
+ math::Vol2ColFunctor<Place, T> vol2col;
286
317
for (int i = 0 ; i < batch_size; i++) {
287
318
Tensor out_grad_batch =
288
319
output_grad->Slice (i, i + 1 ).Resize (output_matrix_shape);
@@ -293,16 +324,18 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
293
324
out_grad_batch.Slice (g * out_step, (g + 1 ) * out_step);
294
325
Tensor in_slice = in_batch.Slice (g * in_step, (g + 1 ) * in_step);
295
326
296
- if (filter_shape_vec.size () == 2 ) {
297
- math::Im2ColFunctor<math::ColFormat::kCFO , Place, T> im2col;
298
- im2col (context.device_context (), in_slice, col, strides[0 ],
299
- strides[1 ], paddings[0 ], paddings[0 ], paddings[1 ],
300
- paddings[1 ]);
327
+ if (!is_expand) {
328
+ col.ShareDataWith (in_slice);
329
+ col_matrix.ShareDataWith (col);
330
+ col_matrix.Resize (col_matrix_shape);
331
+ } else if (filter_shape_vec.size () == 2 ) {
332
+ im2col (context.device_context (), in_slice, dilations, strides,
333
+ std::vector<int >{paddings[0 ], paddings[1 ], paddings[0 ],
334
+ paddings[1 ]},
335
+ &col);
301
336
} else if (filter_shape_vec.size () == 3 ) {
302
- math::Vol2ColFunctor<Place, T> vol2col;
303
- vol2col (context.device_context (), in_slice, col, strides[0 ],
304
- strides[1 ], strides[2 ], paddings[0 ], paddings[1 ],
305
- paddings[2 ]);
337
+ vol2col (context.device_context (), in_slice, dilations, strides,
338
+ paddings, &col);
306
339
}
307
340
308
341
// gemm
0 commit comments