
Commit e800c0d

Author: chengduo
Merge pull request #5791 from chengduoZH/fix_conv_op
remove vector::erase
2 parents d883547 + a93227a commit e800c0d

2 files changed: 51 additions & 67 deletions

paddle/operators/conv_op.h

Lines changed: 27 additions & 35 deletions
@@ -38,7 +38,7 @@ inline bool IsExpand(std::vector<int64_t>& filter_dim,
                      std::vector<int>& dilations) {
   bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
   for (size_t j = 0; j < strides.size(); ++j) {
-    filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
+    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
     strides_1 = strides_1 && (strides[j] == 1);
     padding_0 = padding_0 && (paddings[j] == 0);
     dilation_1 = dilation_1 && (dilations[j] == 1);
@@ -91,32 +91,28 @@ class GemmConvKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
     // o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
@@ -159,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
           col.ShareDataWith(in_slice);
           col_matrix.ShareDataWith(col);
           col_matrix.Resize(col_matrix_shape);
-        } else if (filter_shape_vec.size() == 2) {
+        } else if (data_dim == 2U) {
           // im2col
           im2col(context.device_context(), in_slice, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 3) {
+        } else if (data_dim == 3U) {
           // vol2col
           vol2col(context.device_context(), in_slice, dilations, strides,
                   paddings, &col);
@@ -206,34 +202,30 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(
         framework::vectorize(output_grad->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (i_c/g * k_h * k_w, o_h * o_w)
     // or
     // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     framework::DDim input_shape = framework::slice_ddim(
         input->dims(), 1, static_cast<int>(input->dims().size()));
@@ -294,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
                                  out_grad_slice, false, T(1.0), &col_matrix,
                                  T(0.0));
 
-          if (is_expand && filter_shape_vec.size() == 2) {
+          if (is_expand && data_dim == 2U) {
             col2im(context.device_context(), col, dilations, strides,
                    std::vector<int>{paddings[0], paddings[1], paddings[0],
                                     paddings[1]},
                    &in_grad_slice);
-          } else if (is_expand && filter_shape_vec.size() == 3) {
+          } else if (is_expand && data_dim == 3U) {
             col2vol(context.device_context(), col, dilations, strides, paddings,
                     &in_grad_slice);
           }
@@ -328,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
             col.ShareDataWith(in_slice);
             col_matrix.ShareDataWith(col);
             col_matrix.Resize(col_matrix_shape);
-          } else if (filter_shape_vec.size() == 2) {
+          } else if (data_dim == 2U) {
             im2col(context.device_context(), in_slice, dilations, strides,
                    std::vector<int>{paddings[0], paddings[1], paddings[0],
                                     paddings[1]},
                    &col);
-          } else if (filter_shape_vec.size() == 3) {
+          } else if (data_dim == 3U) {
             vol2col(context.device_context(), in_slice, dilations, strides,
                     paddings, &col);
           }
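
Both conv kernels make the same change: instead of erasing the leading {k_o, k_i} and {o_n, o_c} entries and concatenating what is left, the shape vectors are kept whole and only their spatial entries (index 2 onward) are copied into col_shape_vec. A standalone sketch of that indexing, using made-up conv2d shapes rather than the kernel's real tensors:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Made-up 2-D conv shapes: filter {k_o, k_i, k_h, k_w}, output {o_n, o_c, o_h, o_w}.
  std::vector<int64_t> filter_shape_vec = {64, 3, 3, 3};
  std::vector<int64_t> output_shape_vec = {8, 64, 32, 32};
  int64_t input_channels = 3;
  int64_t groups = 1;

  // data_dim is the number of spatial dimensions: 2 for conv2d, 3 for conv3d.
  std::size_t data_dim = filter_shape_vec.size() - 2;

  // col_shape_vec = {i_c/g, k_h, k_w, o_h, o_w}, built by index, no vector::erase.
  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
  col_shape_vec[0] = input_channels / groups;
  for (std::size_t j = 0; j < data_dim; ++j) {
    col_shape_vec[j + 1] = filter_shape_vec[j + 2];             // spatial filter sizes
    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];  // spatial output sizes
  }

  for (int64_t d : col_shape_vec) std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");  // prints: 3 3 3 32 32
  return 0;
}

For conv3d the same loop runs with data_dim == 3 and picks up {k_d, k_h, k_w} and {o_d, o_h, o_w}; keeping the full shape vectors is also why IsExpand now reads filter_dim[j + 2] instead of filter_dim[j].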

paddle/operators/conv_transpose_op.h

Lines changed: 24 additions & 32 deletions
@@ -68,30 +68,26 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // input_shape_vec: {h, w} or {d, h, w}
+    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
     std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
-    input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2);
-
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
 
     // use col_shape in the im2col and col2im (or vol2col and col2vol)
     // calculation
     // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(output->dims()[1]);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
-                         input_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = output->dims()[1];
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
+    }
     DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
-    DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
 
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
@@ -136,15 +132,15 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
                              input_batch, false, static_cast<T>(1.0),
                              &col_matrix, static_cast<T>(0.0));
 
-      if (filter_shape_vec.size() == 2) {
+      if (data_dim == 2U) {
         // col2im: col_matrix -> dy
         // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
         col2im(context.device_context(), col,
                std::vector<int>{dilations[0], dilations[1]}, strides,
                std::vector<int>{paddings[0], paddings[1], paddings[0],
                                 paddings[1]},
                &output_batch);
-      } else if (filter_shape_vec.size() == 3) {
+      } else if (data_dim == 3U) {
         // col2vol: col_matrix -> dy
         // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
         col2vol(context.device_context(), col, dilations, strides, paddings,
@@ -176,30 +172,26 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // input_shape_vec: {h, w} or {d, h, w}
+    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
     std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
-    input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2);
-
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
 
     // use col_shape in the im2col and col2im (or vol2col and col2vol)
     // calculation
     // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(output_grad->dims()[1]);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
-                         input_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = output_grad->dims()[1];
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
+    }
     DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
-    DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
 
     // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
     DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
@@ -248,15 +240,15 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
         Tensor output_grad_batch =
             output_grad->Slice(i, i + 1).Resize(output_shape);
 
-        if (filter_shape_vec.size() == 2) {
+        if (data_dim == 2U) {
           // im2col: dy -> col matrix
           // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
           im2col(context.device_context(), output_grad_batch,
                  std::vector<int>{dilations[0], dilations[1]}, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 3) {
+        } else if (data_dim == 3U) {
           // vol2col: dy -> col_matrix
           // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
           vol2col(context.device_context(), output_grad_batch, dilations,
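
The flatten_to_2d split index changes from filter_shape_vec.size() + 1 to data_dim + 1, which is the same number as before: the old, erased filter_shape_vec held only the spatial sizes, so its size was data_dim. A standalone sketch of the intended 2-D flattening; FlattenTo2D here is an illustrative stand-in for framework::flatten_to_2d, not Paddle's implementation, and the sizes are made up:

#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Collapse the first num_col_dims entries into the row count and the
// remaining entries into the column count of a 2-D matrix shape.
std::pair<int64_t, int64_t> FlattenTo2D(const std::vector<int64_t>& dims,
                                        std::size_t num_col_dims) {
  const int64_t rows =
      std::accumulate(dims.begin(), dims.begin() + num_col_dims,
                      static_cast<int64_t>(1), std::multiplies<int64_t>());
  const int64_t cols =
      std::accumulate(dims.begin() + num_col_dims, dims.end(),
                      static_cast<int64_t>(1), std::multiplies<int64_t>());
  return {rows, cols};
}

int main() {
  // col_shape_vec for a 2-D case, {c, k_h, k_w, h, w}, with made-up sizes.
  const std::vector<int64_t> col_shape_vec = {3, 3, 3, 32, 32};
  const std::size_t data_dim = 2;  // two spatial dimensions
  // data_dim + 1 splits off {c, k_h, k_w} as the GEMM row dimension.
  const auto shape = FlattenTo2D(col_shape_vec, data_dim + 1);
  // shape.first == 27 (c * k_h * k_w), shape.second == 1024 (h * w).
  (void)shape;  // silence unused-variable warnings in a standalone build
  return 0;
}

The first data_dim + 1 entries become the GEMM row dimension and the trailing spatial sizes become the columns, matching the (c * k_h * k_w, h * w) size noted in the comments above.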
