Commit 4bafbf4

Author: Yibing Liu (committed)
Enable groups for conv3d transpose op
1 parent adbf97b commit 4bafbf4

2 files changed: +59 additions, −14 deletions


paddle/fluid/operators/conv_transpose_op.cc (5 additions, 1 deletion)

@@ -50,7 +50,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
                     "dimension should be the same.");
   PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
                     "In ConvTransposeOp, The number of input channels should "
-                    "be equal to the number of filter' channels.");
+                    "be equal to the number of filter's channels.");
 
   std::vector<int64_t> output_shape({in_dims[0], filter_dims[1] * groups});
   for (size_t i = 0; i < strides.size(); ++i) {
@@ -208,6 +208,10 @@ void Conv3DTransposeOpMaker::Make() {
                             "(vector<int> default:{0, 0, 0}), paddings(d_pad, "
                             "h_pad, w_pad) of convolution transpose operator.")
       .SetDefault({0, 0, 0});
+  AddAttr<int>("groups",
+               "(int default:1), the groups number of the convolution3d "
+               "transpose operator. ")
+      .SetDefault(1);
   AddAttr<bool>(
       "use_cudnn",
       "(bool, default false) Only used in cudnn kernel, need install cudnn")

python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py (54 additions, 13 deletions)

@@ -21,8 +21,11 @@
 
 def conv3dtranspose_forward_naive(input_, filter_, attrs):
     in_n, in_c, in_d, in_h, in_w = input_.shape
-    f_c, out_c, f_d, f_h, f_w = filter_.shape
+    f_c, f_out_c, f_d, f_h, f_w = filter_.shape
+    groups = attrs['groups']
     assert in_c == f_c
+    out_c = f_out_c * groups
+    sub_in_c = in_c / groups
 
     stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
         'dilations']
@@ -39,18 +42,23 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs):
         for d in range(in_d):
             for i in range(in_h):
                 for j in range(in_w):
-                    input_masked = input_[n, :, d, i, j]  # (c)
-                    input_masked = np.reshape(input_masked, (in_c, 1, 1, 1))
-                    input_masked = np.tile(input_masked, (1, f_d, f_h, f_w))
-
-                    for k in range(out_c):
-                        tmp_out = np.sum(input_masked * filter_[:, k, :, :, :],
-                                         axis=0)
-                        d1, d2 = d * stride[0], d * stride[0] + d_bolck_d
-                        i1, i2 = i * stride[1], i * stride[1] + d_bolck_h
-                        j1, j2 = j * stride[2], j * stride[2] + d_bolck_w
-                        out[n, k, d1:d2:dilations[0], i1:i2:dilations[1], j1:j2:
-                            dilations[2]] += tmp_out
+                    for g in range(groups):
+                        input_masked = input_[n, g * sub_in_c:(g + 1
+                                                               ) * sub_in_c, d,
+                                              i, j]  # (c)
+                        input_masked = np.reshape(input_masked,
+                                                  (sub_in_c, 1, 1, 1))
+                        input_masked = np.tile(input_masked, (1, f_d, f_h, f_w))
+
+                        for k in range(f_out_c):
+                            tmp_out = np.sum(input_masked * filter_[
+                                g * sub_in_c:(g + 1) * sub_in_c, k, :, :, :],
+                                             axis=0)
+                            d1, d2 = d * stride[0], d * stride[0] + d_bolck_d
+                            i1, i2 = i * stride[1], i * stride[1] + d_bolck_h
+                            j1, j2 = j * stride[2], j * stride[2] + d_bolck_w
+                            out[n, g * f_out_c + k, d1:d2:dilations[0], i1:i2:
+                                dilations[1], j1:j2:dilations[2]] += tmp_out
 
     out = out[:, :, pad[0]:out_d - pad[0], pad[1]:out_h - pad[1], pad[2]:out_w -
               pad[2]]
@@ -72,6 +80,7 @@ def setUp(self):
             'strides': self.stride,
             'paddings': self.pad,
             'dilations': self.dilations,
+            'groups': self.groups,
             'use_cudnn': self.use_cudnn,
             'data_format': 'AnyLayout'  # TODO(dzhwinter) : should be fix latter
         }
@@ -134,6 +143,7 @@ def init_test_case(self):
         self.pad = [0, 0, 0]
         self.stride = [1, 1, 1]
         self.dilations = [1, 1, 1]
+        self.groups = 1
         self.input_size = [2, 3, 5, 5, 5]  # NCDHW
         f_c = self.input_size[1]
         self.filter_size = [f_c, 6, 3, 3, 3]
@@ -147,16 +157,29 @@ def init_test_case(self):
         self.pad = [1, 1, 1]
         self.stride = [1, 1, 1]
         self.dilations = [1, 1, 1]
+        self.groups = 1
         self.input_size = [2, 3, 5, 5, 5]  # NCDHW
         f_c = self.input_size[1]
         self.filter_size = [f_c, 6, 3, 3, 3]
 
 
+class TestWithGroups(TestConv3dTransposeOp):
+    def init_test_case(self):
+        self.pad = [1, 1, 1]
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 2
+        self.input_size = [2, 4, 5, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 3, 3, 3, 3]
+
+
 class TestWithStride(TestConv3dTransposeOp):
     def init_test_case(self):
         self.pad = [1, 1, 1]
         self.stride = [2, 2, 2]
         self.dilations = [1, 1, 1]
+        self.groups = 1
         self.input_size = [2, 3, 5, 5, 5]  # NCDHW
         f_c = self.input_size[1]
         self.filter_size = [f_c, 6, 3, 3, 3]
@@ -167,6 +190,7 @@ def init_test_case(self):
         self.pad = [1, 1, 1]
         self.stride = [1, 1, 1]
         self.dilations = [2, 2, 2]
+        self.groups = 1
         self.input_size = [2, 3, 5, 5, 5]  # NCDHW
         f_c = self.input_size[1]
         self.filter_size = [f_c, 6, 3, 3, 3]
@@ -184,6 +208,7 @@ def init_test_case(self):
         self.pad = [1, 1, 1]
         self.stride = [1, 1, 1]
         self.dilations = [1, 1, 1]
+        self.groups = 1
         self.input_size = [2, 3, 5, 5, 5]  # NCDHW
         f_c = self.input_size[1]
         self.filter_size = [f_c, 6, 3, 3, 3]
@@ -198,6 +223,7 @@ def init_test_case(self):
         self.pad = [1, 1, 1]
         self.stride = [2, 2, 2]
         self.dilations = [1, 1, 1]
+        self.groups = 1
         self.input_size = [2, 3, 5, 5, 5]  # NCDHW
         f_c = self.input_size[1]
         self.filter_size = [f_c, 6, 3, 3, 3]
@@ -207,6 +233,21 @@ def init_op_type(self):
         self.op_type = "conv3d_transpose"
 
 
+class TestCUDNNWithGroups(TestWithGroups):
+    def init_test_case(self):
+        self.pad = [1, 1, 1]
+        self.stride = [1, 1, 1]
+        self.dilations = [1, 1, 1]
+        self.groups = 2
+        self.input_size = [2, 4, 5, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 3, 3, 3, 3]
+
+    def init_op_type(self):
+        self.use_cudnn = True
+        self.op_type = "conv3d_transpose"
+
+
 # Please Don't remove the following code.
 # Currently, CI use cudnn V5.0 which not support dilation conv.
 # class TestCUDNNWithDilation(TestWithDilation):
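
A quick way to sanity-check the grouped reference above (a hypothetical snippet, not part of this commit): with groups enabled, the result should equal running the ungrouped reference on each input-channel block separately and concatenating the per-group outputs along the channel axis. The sketch assumes it runs from the same directory as the test file under the Python 2 interpreter this suite targets (the reference relies on integer `/` division for sub_in_c).

import numpy as np
from test_conv3d_transpose_op import conv3dtranspose_forward_naive

# Hypothetical equivalence check: grouped transposed convolution versus
# per-group ungrouped transposed convolutions concatenated on the channel axis.
attrs = {'strides': [1, 1, 1], 'paddings': [1, 1, 1],
         'dilations': [1, 1, 1], 'groups': 2}
x = np.random.random((2, 4, 5, 5, 5)).astype("float32")  # NCDHW input, C = 4
w = np.random.random((4, 3, 3, 3, 3)).astype("float32")  # [in_c, out_c/groups, kD, kH, kW]

grouped = conv3dtranspose_forward_naive(x, w, attrs)

ungrouped = dict(attrs, groups=1)
parts = [
    conv3dtranspose_forward_naive(x[:, 2 * g:2 * (g + 1)],
                                  w[2 * g:2 * (g + 1)], ungrouped)
    for g in range(2)
]
np.testing.assert_allclose(grouped, np.concatenate(parts, axis=1), rtol=1e-5)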
