Skip to content

Commit 200a02e

Browse files
author
chengduo
authored
Merge pull request #5041 from chengduoZH/fix_im2col_interface
fix im2col interface
2 parents 25588a3 + 61dbf4b commit 200a02e

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

paddle/operators/conv2dtranspose_op.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
120120
math::matmul<Place, T>(context.device_context(), filter, true,
121121
input_batch, false, T(1.0), &col_matrix, T(0.0));
122122
col2im(context.device_context(), output_batch, col, strides[0],
123-
strides[1], 0, 0);
123+
strides[1], 0, 0, 0, 0);
124124
}
125125
}
126126
};
@@ -206,7 +206,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
206206

207207
// im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w)
208208
im2col(context.device_context(), output_grad_batch, col, strides[0],
209-
strides[1], paddings[0], paddings[1]);
209+
strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
210210

211211
// gemm: dx = filter * dy
212212
// (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h)
@@ -238,7 +238,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
238238

239239
// im2col: (c * h * w, k_h * k_w)
240240
im2col(context.device_context(), output_grad_batch, col, strides[0],
241-
strides[1], paddings[0], paddings[1]);
241+
strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
242242

243243
// gemm: d_filter = x * y_grad^T
244244
// (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h)

paddle/operators/math/im2col.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ class Im2ColFunctor {
7575
void operator()(const platform::DeviceContext& context,
7676
const framework::Tensor& im, framework::Tensor& col,
7777
int stride_height, int stride_width, int padding_up,
78-
int padding_down, int padding_left = 0,
79-
int padding_right = 0);
78+
int padding_down, int padding_left, int padding_right);
8079
};
8180

8281
template <ColFormat Format, typename Place, typename T>
@@ -85,7 +84,7 @@ class Col2ImFunctor {
8584
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
8685
const framework::Tensor& col, int stride_height,
8786
int stride_width, int padding_up, int padding_down,
88-
int padding_left = 0, int padding_right = 0);
87+
int padding_left, int padding_right);
8988
};
9089

9190
} // namespace math

0 commit comments

Comments
 (0)