
Commit a93227a

refine code
1 parent e5bf9c5 commit a93227a

File tree

2 files changed, +42 -44 lines changed

paddle/operators/conv_op.h
paddle/operators/conv_transpose_op.h


paddle/operators/conv_op.h

Lines changed: 22 additions & 22 deletions
@@ -99,20 +99,20 @@ class GemmConvKernel : public framework::OpKernel<T> {
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
-                                       output_shape_vec.size() - 3);
-    col_shape_vec.assign(1, input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin() + 2,
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
     // o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
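
Note on the hunk above: the old code sized col_shape_vec, immediately discarded that sizing with assign(1, ...), and then grew the vector back with two range inserts; the new code sizes it once (1 + 2 * data_dim elements) and fills it by index. A minimal standalone sketch, using hypothetical 2-D shapes (8-channel input, 2 groups, 3x3 filter, 5x5 output) as stand-ins for the kernel's real tensors, showing the two constructions agree:

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical sample shapes, not taken from the kernel itself.
  std::vector<int64_t> filter_shape_vec = {16, 4, 3, 3};  // {o_c, i_c/g, k_h, k_w}
  std::vector<int64_t> output_shape_vec = {1, 16, 5, 5};  // {n, o_c, o_h, o_w}
  int64_t input_channels = 8;  // plays the role of input->dims()[1]
  int groups = 2;

  // Old construction: the initial sizing is discarded by assign(),
  // then two range inserts grow the vector back to full length.
  std::vector<int64_t> old_shape(filter_shape_vec.size() +
                                 output_shape_vec.size() - 3);
  old_shape.assign(1, input_channels / groups);
  old_shape.insert(old_shape.end(), filter_shape_vec.begin() + 2,
                   filter_shape_vec.end());
  old_shape.insert(old_shape.end(), output_shape_vec.begin() + 2,
                   output_shape_vec.end());

  // New construction: size once, fill by index.
  size_t data_dim = filter_shape_vec.size() - 2;
  std::vector<int64_t> new_shape(1 + 2 * data_dim);
  new_shape[0] = input_channels / groups;
  for (size_t j = 0; j < data_dim; ++j) {
    new_shape[j + 1] = filter_shape_vec[j + 2];
    new_shape[j + 1 + data_dim] = output_shape_vec[j + 2];
  }

  assert(old_shape == new_shape);  // both are {4, 3, 3, 5, 5}
  for (int64_t d : new_shape) std::printf("%lld ", (long long)d);
  std::printf("\n");
  return 0;
}

Both forms produce {i_c/g, k_h, k_w, o_h, o_w} = {4, 3, 3, 5, 5} for these values; the indexed loop also avoids the reallocation churn of the insert calls.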
@@ -155,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
           col.ShareDataWith(in_slice);
           col_matrix.ShareDataWith(col);
           col_matrix.Resize(col_matrix_shape);
-        } else if (filter_shape_vec.size() == 4) {
+        } else if (data_dim == 2U) {
           // im2col
           im2col(context.device_context(), in_slice, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 5) {
+        } else if (data_dim == 3U) {
           // vol2col
           vol2col(context.device_context(), in_slice, dilations, strides,
                   paddings, &col);
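
The rank checks and the data_dim checks are equivalent: with data_dim = filter_shape_vec.size() - 2, a rank-4 filter means data_dim == 2 (im2col path) and a rank-5 filter means data_dim == 3 (vol2col path); the new form simply names the quantity being tested. A tiny sketch of the dispatch, with hypothetical stubs in place of the real im2col/vol2col functors:

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stubs; the real code calls the math::im2col/vol2col functors.
static void im2col_stub() { std::printf("im2col path (2-D conv)\n"); }
static void vol2col_stub() { std::printf("vol2col path (3-D conv)\n"); }

int main() {
  std::vector<int64_t> filter_shape_vec = {16, 4, 3, 3};  // rank 4: 2-D filter
  size_t data_dim = filter_shape_vec.size() - 2;
  if (data_dim == 2U) {         // previously: filter_shape_vec.size() == 4
    im2col_stub();
  } else if (data_dim == 3U) {  // previously: filter_shape_vec.size() == 5
    vol2col_stub();
  }
  return 0;
}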
@@ -211,21 +211,21 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
-                                       output_shape_vec.size() - 3);
-    col_shape_vec.assign(1, input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin() + 2,
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (i_c/g * k_h * k_w, o_h * o_w)
     // or
     // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     framework::DDim input_shape = framework::slice_ddim(
         input->dims(), 1, static_cast<int>(input->dims().size()));
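
flatten_to_2d splits col_shape at the given axis, so passing data_dim + 1 keeps the leading channel-and-kernel dims as rows and the trailing output spatial dims as cols, i.e. the (i_c/g * k_h * k_w, o_h * o_w) shape from the comment. A standalone sketch, with a plain-vector stand-in for the framework::flatten_to_2d/DDim API:

#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Plain-vector stand-in for framework::flatten_to_2d: dims before
// `axis` multiply into rows, dims from `axis` onward into cols.
static void flatten_to_2d_demo(const std::vector<int64_t>& dims, size_t axis) {
  int64_t rows = std::accumulate(dims.begin(), dims.begin() + axis,
                                 int64_t{1}, std::multiplies<int64_t>());
  int64_t cols = std::accumulate(dims.begin() + axis, dims.end(),
                                 int64_t{1}, std::multiplies<int64_t>());
  std::printf("(%lld, %lld)\n", (long long)rows, (long long)cols);
}

int main() {
  size_t data_dim = 2;  // 2-D convolution
  // col_shape = {i_c/g, k_h, k_w, o_h, o_w} = {4, 3, 3, 5, 5}
  flatten_to_2d_demo({4, 3, 3, 5, 5}, data_dim + 1);  // prints (36, 25)
  return 0;
}

For col_shape {4, 3, 3, 5, 5} this prints (36, 25): 4 * 3 * 3 rows by 5 * 5 cols.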
@@ -286,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
                        out_grad_slice, false, T(1.0), &col_matrix,
                        T(0.0));
 
-        if (is_expand && filter_shape_vec.size() == 4) {
+        if (is_expand && data_dim == 2U) {
           col2im(context.device_context(), col, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &in_grad_slice);
-        } else if (is_expand && filter_shape_vec.size() == 5) {
+        } else if (is_expand && data_dim == 3U) {
           col2vol(context.device_context(), col, dilations, strides, paddings,
                   &in_grad_slice);
         }
@@ -320,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
           col.ShareDataWith(in_slice);
           col_matrix.ShareDataWith(col);
           col_matrix.Resize(col_matrix_shape);
-        } else if (filter_shape_vec.size() == 4) {
+        } else if (data_dim == 2U) {
           im2col(context.device_context(), in_slice, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 5) {
+        } else if (data_dim == 3U) {
           vol2col(context.device_context(), in_slice, dilations, strides,
                   paddings, &col);
         }

paddle/operators/conv_transpose_op.h

Lines changed: 20 additions & 22 deletions
@@ -76,19 +76,18 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
     // use col_shape in the im2col and col2im (or vol2col and col2vol)
     // calculation
     // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
-    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
-                                       input_shape_vec.size() - 3);
-    col_shape_vec.assign(1, output->dims()[1]);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin() + 2,
-                         input_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = output->dims()[1];
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
+    }
     DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
-    DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
+    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
 
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
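
For the transposed kernel the roles flip relative to the forward conv: the col buffer's leading dim comes from the output channels (output->dims()[1]) and the trailing spatial dims come from the input feature map, giving {c, k_h, k_w, h, w}. A minimal sketch with hypothetical shapes (8 output channels, 3x3 filter, 5x5 input) standing in for the kernel's real tensors:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical sample shapes, not taken from the kernel itself.
  std::vector<int64_t> filter_shape_vec = {16, 8, 3, 3};  // filter: rank 4, 2-D
  std::vector<int64_t> input_shape_vec = {1, 16, 5, 5};   // {n, i_c, h, w}
  int64_t output_channels = 8;  // plays the role of output->dims()[1]

  size_t data_dim = filter_shape_vec.size() - 2;
  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
  col_shape_vec[0] = output_channels;  // c, not i_c/g as in the forward conv
  for (size_t j = 0; j < data_dim; ++j) {
    col_shape_vec[j + 1] = filter_shape_vec[j + 2];            // k_h, k_w
    col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];  // h, w
  }
  // col_shape_vec is now {c, k_h, k_w, h, w} = {8, 3, 3, 5, 5}
  for (int64_t d : col_shape_vec) std::printf("%lld ", (long long)d);
  std::printf("\n");
  return 0;
}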
@@ -133,15 +132,15 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
                   input_batch, false, static_cast<T>(1.0),
                   &col_matrix, static_cast<T>(0.0));
 
-      if (filter_shape_vec.size() == 4) {
+      if (data_dim == 2U) {
         // col2im: col_matrix -> dy
         // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
         col2im(context.device_context(), col,
                std::vector<int>{dilations[0], dilations[1]}, strides,
                std::vector<int>{paddings[0], paddings[1], paddings[0],
                                 paddings[1]},
                &output_batch);
-      } else if (filter_shape_vec.size() == 5) {
+      } else if (data_dim == 3U) {
         // col2vol: col_matrix -> dy
         // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
         col2vol(context.device_context(), col, dilations, strides, paddings,
@@ -181,19 +180,18 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
     // use col_shape in the im2col and col2im (or vol2col and col2vol)
     // calculation
     // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
-    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
-                                       input_shape_vec.size() - 3);
-    col_shape_vec.assign(1, output_grad->dims()[1]);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin() + 2,
-                         input_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = output_grad->dims()[1];
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
+    }
     DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
-    DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
+    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
 
     // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
     DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
@@ -242,15 +240,15 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
       Tensor output_grad_batch =
           output_grad->Slice(i, i + 1).Resize(output_shape);
 
-      if (filter_shape_vec.size() == 4) {
+      if (data_dim == 2U) {
        // im2col: dy -> col matrix
        // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
        im2col(context.device_context(), output_grad_batch,
              std::vector<int>{dilations[0], dilations[1]}, strides,
              std::vector<int>{paddings[0], paddings[1], paddings[0],
                               paddings[1]},
              &col);
-      } else if (filter_shape_vec.size() == 5) {
+      } else if (data_dim == 3U) {
        // vol2col: dy -> col_matrix
        // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
        vol2col(context.device_context(), output_grad_batch, dilations,

0 commit comments