Commit e5bf9c5

remove vector::erase

1 parent e930f49 commit e5bf9c5

2 files changed: 43 additions & 57 deletions

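The change is mechanical throughout. Previously the kernels called vector::erase to strip the leading two entries (batch and channel, or output and input channel counts) from each shape vector and then indexed the remainder from 0; after this commit the full vectors from framework::vectorize are kept and every consumer offsets its reads by 2. A minimal standalone sketch of the two equivalent patterns, with a plain std::vector standing in for the elided Paddle tensors:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // Full filter dims for a 2-D conv, as framework::vectorize would return
  // them: {k_o, k_i, k_h, k_w}.
  std::vector<int64_t> filter_dims = {64, 3, 3, 3};

  // Old pattern: erase the leading {k_o, k_i}, then read spatial dims at j.
  std::vector<int64_t> spatial_only(filter_dims);
  spatial_only.erase(spatial_only.begin(), spatial_only.begin() + 2);

  // New pattern: keep the vector intact and read spatial dims at j + 2.
  for (size_t j = 0; j < spatial_only.size(); ++j) {
    assert(spatial_only[j] == filter_dims[j + 2]);
  }
  return 0;
}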

paddle/operators/conv_op.h

Lines changed: 23 additions & 31 deletions
@@ -38,7 +38,7 @@ inline bool IsExpand(std::vector<int64_t>& filter_dim,
                          std::vector<int>& dilations) {
   bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
   for (size_t j = 0; j < strides.size(); ++j) {
-    filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
+    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
     strides_1 = strides_1 && (strides[j] == 1);
     padding_0 = padding_0 && (paddings[j] == 0);
     dilation_1 = dilation_1 && (dilations[j] == 1);
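Since filter_dim now arrives with its leading {k_o, k_i} pair intact, while strides, paddings, and dilations remain spatial-only, only the filter read moves to j + 2. Below is a self-contained copy of the updated predicate; the return statement sits outside the hunk, so the negated conjunction is an assumption based on the function's name:

#include <cstddef>
#include <cstdint>
#include <vector>

// IsExpand under the new convention: filter_dim is the full shape
// {k_o, k_i, spatial...}; the other three vectors hold one entry per
// spatial dimension, so only filter_dim needs the + 2 offset.
inline bool IsExpand(const std::vector<int64_t>& filter_dim,
                     const std::vector<int>& strides,
                     const std::vector<int>& paddings,
                     const std::vector<int>& dilations) {
  bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
  for (size_t j = 0; j < strides.size(); ++j) {
    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
    strides_1 = strides_1 && (strides[j] == 1);
    padding_0 = padding_0 && (paddings[j] == 0);
    dilation_1 = dilation_1 && (dilations[j] == 1);
  }
  return !(filter_1 && strides_1 && padding_0 && dilation_1);  // assumed
}

int main() {
  // A 1x1 filter, unit strides, zero paddings, unit dilations: no expansion.
  return IsExpand({8, 3, 1, 1}, {1, 1}, {0, 0}, {1, 1}) ? 1 : 0;
}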
@@ -91,32 +91,28 @@ class GemmConvKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
+                                       output_shape_vec.size() - 3);
+    col_shape_vec.assign(1, input->dims()[1] / groups);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
                          filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
+    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin() + 2,
                          output_shape_vec.end());
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
     // o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
 
     bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
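The rewritten col_shape_vec construction sizes the vector up front to 1 + spatial kernel dims + spatial output dims, which is filter_shape_vec.size() + output_shape_vec.size() - 3 once the two leading entries of each vector are skipped and one channel slot is added. Note that the subsequent assign(1, ...) resets the size back to 1, so the sized constructor effectively serves only as an allocation hint. A compilable sketch of the 2-D case with concrete (made-up) numbers:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // 2-D conv: filter {k_o, k_i, k_h, k_w}, output {o_n, o_c, o_h, o_w}.
  std::vector<int64_t> filter_shape_vec = {64, 3, 3, 3};
  std::vector<int64_t> output_shape_vec = {8, 64, 32, 32};
  const int64_t input_channels = 3;
  const int64_t groups = 1;

  // Target size: 1 + 2 + 2 = 5, i.e. 4 + 4 - 3.
  std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
                                     output_shape_vec.size() - 3);
  col_shape_vec.assign(1, input_channels / groups);  // size is 1 again here
  col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
                       filter_shape_vec.end());
  col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin() + 2,
                       output_shape_vec.end());

  // Prints "3 3 3 32 32", i.e. {i_c/g, k_h, k_w, o_h, o_w}.
  for (int64_t d : col_shape_vec) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}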
@@ -159,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
         col.ShareDataWith(in_slice);
         col_matrix.ShareDataWith(col);
         col_matrix.Resize(col_matrix_shape);
-      } else if (filter_shape_vec.size() == 2) {
+      } else if (filter_shape_vec.size() == 4) {
         // im2col
         im2col(context.device_context(), in_slice, dilations, strides,
                std::vector<int>{paddings[0], paddings[1], paddings[0],
                                 paddings[1]},
                &col);
-      } else if (filter_shape_vec.size() == 3) {
+      } else if (filter_shape_vec.size() == 5) {
         // vol2col
         vol2col(context.device_context(), in_slice, dilations, strides,
                 paddings, &col);
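The branch conditions shift by the same two entries: with the leading dims kept, a 2-D filter is a 4-entry shape vector and a 3-D filter a 5-entry one, so size() == 4 now selects the im2col path and size() == 5 the vol2col path. A sketch of just that dispatch; the ColOp tag and helper are illustrative, not Paddle API:

#include <cstddef>
#include <stdexcept>

enum class ColOp { kIm2Col, kVol2Col };

// Full filter rank -> column op: rank 4 is {k_o, k_i, k_h, k_w} (2-D conv),
// rank 5 is {k_o, k_i, k_d, k_h, k_w} (3-D conv).
inline ColOp DispatchByFilterRank(size_t filter_shape_vec_size) {
  if (filter_shape_vec_size == 4) return ColOp::kIm2Col;
  if (filter_shape_vec_size == 5) return ColOp::kVol2Col;
  throw std::invalid_argument("unsupported filter rank");
}

int main() {
  return DispatchByFilterRank(4) == ColOp::kIm2Col &&
                 DispatchByFilterRank(5) == ColOp::kVol2Col
             ? 0
             : 1;
}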
@@ -206,25 +202,21 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(
         framework::vectorize(output_grad->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
+                                       output_shape_vec.size() - 3);
+    col_shape_vec.assign(1, input->dims()[1] / groups);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
                          filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
+    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin() + 2,
                          output_shape_vec.end());
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
@@ -233,7 +225,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     // or
     // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
 
     framework::DDim input_shape = framework::slice_ddim(
         input->dims(), 1, static_cast<int>(input->dims().size()));
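The flatten axis follows the same arithmetic. col_shape holds one channel slot, then the spatial kernel dims, then the spatial output dims, and the GEMM row/column split falls right after the kernel block, at 1 + spatial rank. With the full filter vector, spatial rank is filter_shape_vec.size() - 2, hence size() - 2 + 1 in place of the old size() + 1. For a 2-D conv that is 4 - 2 + 1 = 3, splitting {i_c/g, k_h, k_w, o_h, o_w} into (i_c/g * k_h * k_w, o_h * o_w). A stand-in for framework::flatten_to_2d over a plain vector; the real helper works on DDim, and the fold semantics below are inferred from the comments in the diff:

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

// Assumed flatten_to_2d semantics: dims [0, axis) multiply into the row
// count, dims [axis, n) into the column count.
std::pair<int64_t, int64_t> FlattenTo2D(const std::vector<int64_t>& shape,
                                        size_t axis) {
  int64_t rows = std::accumulate(shape.begin(), shape.begin() + axis,
                                 int64_t{1}, std::multiplies<int64_t>());
  int64_t cols = std::accumulate(shape.begin() + axis, shape.end(),
                                 int64_t{1}, std::multiplies<int64_t>());
  return {rows, cols};
}

int main() {
  // {i_c/g, k_h, k_w, o_h, o_w} for a 3x3 filter over a 32x32 output.
  std::vector<int64_t> col_shape = {3, 3, 3, 32, 32};
  const size_t filter_rank = 4;  // full {k_o, k_i, k_h, k_w}
  auto [rows, cols] = FlattenTo2D(col_shape, filter_rank - 2 + 1);
  std::cout << rows << " x " << cols << '\n';  // 27 x 1024
  return 0;
}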
@@ -294,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
                                   out_grad_slice, false, T(1.0), &col_matrix,
                                   T(0.0));
 
-          if (is_expand && filter_shape_vec.size() == 2) {
+          if (is_expand && filter_shape_vec.size() == 4) {
             col2im(context.device_context(), col, dilations, strides,
                    std::vector<int>{paddings[0], paddings[1], paddings[0],
                                     paddings[1]},
                    &in_grad_slice);
-          } else if (is_expand && filter_shape_vec.size() == 3) {
+          } else if (is_expand && filter_shape_vec.size() == 5) {
            col2vol(context.device_context(), col, dilations, strides, paddings,
                    &in_grad_slice);
          }
@@ -328,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
           col.ShareDataWith(in_slice);
           col_matrix.ShareDataWith(col);
           col_matrix.Resize(col_matrix_shape);
-        } else if (filter_shape_vec.size() == 2) {
+        } else if (filter_shape_vec.size() == 4) {
           im2col(context.device_context(), in_slice, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 3) {
+        } else if (filter_shape_vec.size() == 5) {
           vol2col(context.device_context(), in_slice, dilations, strides,
                   paddings, &col);
         }

paddle/operators/conv_transpose_op.h

Lines changed: 20 additions & 26 deletions
@@ -68,30 +68,27 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // input_shape_vec: {h, w} or {d, h, w}
+    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
     std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
-    input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2);
-
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
 
     // use col_shape in the im2col and col2im (or vol2col and col2vol)
     // calculation
     // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(output->dims()[1]);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
+                                       input_shape_vec.size() - 3);
+    col_shape_vec.assign(1, output->dims()[1]);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
                          filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
+    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin() + 2,
                          input_shape_vec.end());
     DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
     DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
 
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
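conv_transpose_op.h receives the mirror-image edit. The bookkeeping is the same offset-by-2 scheme, but here the column buffer leads with the output channel count and appends the input's spatial dims, since the GEMM scatters from the smaller input space out to the output. A sketch of the 2-D case, with shapes invented for illustration:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // 2-D conv-transpose: filter {k_o, k_c, k_h, k_w}, input {n, c, h, w}.
  std::vector<int64_t> filter_shape_vec = {3, 16, 5, 5};
  std::vector<int64_t> input_shape_vec = {4, 3, 14, 14};
  const int64_t output_channels = 16;  // stands in for output->dims()[1]

  std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
                                     input_shape_vec.size() - 3);
  col_shape_vec.assign(1, output_channels);
  col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
                       filter_shape_vec.end());
  col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin() + 2,
                       input_shape_vec.end());

  // Prints "16 5 5 14 14", i.e. {c, k_h, k_w, h, w}.
  for (int64_t d : col_shape_vec) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}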
@@ -136,15 +133,15 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
                               input_batch, false, static_cast<T>(1.0),
                               &col_matrix, static_cast<T>(0.0));
 
-      if (filter_shape_vec.size() == 2) {
+      if (filter_shape_vec.size() == 4) {
         // col2im: col_matrix -> dy
         // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
         col2im(context.device_context(), col,
                std::vector<int>{dilations[0], dilations[1]}, strides,
                std::vector<int>{paddings[0], paddings[1], paddings[0],
                                 paddings[1]},
                &output_batch);
-      } else if (filter_shape_vec.size() == 3) {
+      } else if (filter_shape_vec.size() == 5) {
         // col2vol: col_matrix -> dy
         // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
         col2vol(context.device_context(), col, dilations, strides, paddings,
@@ -176,30 +173,27 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // input_shape_vec: {h, w} or {d, h, w}
+    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
     std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
-    input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2);
-
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
 
     // use col_shape in the im2col and col2im (or vol2col and col2vol)
     // calculation
     // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(output_grad->dims()[1]);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
+                                       input_shape_vec.size() - 3);
+    col_shape_vec.assign(1, output_grad->dims()[1]);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
                          filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
+    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin() + 2,
                          input_shape_vec.end());
     DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
     DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
 
     // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
     DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
@@ -248,15 +242,15 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
         Tensor output_grad_batch =
             output_grad->Slice(i, i + 1).Resize(output_shape);
 
-        if (filter_shape_vec.size() == 2) {
+        if (filter_shape_vec.size() == 4) {
           // im2col: dy -> col matrix
           // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
           im2col(context.device_context(), output_grad_batch,
                  std::vector<int>{dilations[0], dilations[1]}, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 3) {
+        } else if (filter_shape_vec.size() == 5) {
          // vol2col: dy -> col_matrix
          // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
          vol2col(context.device_context(), output_grad_batch, dilations,
