21
21
22
22
def conv3dtranspose_forward_naive (input_ , filter_ , attrs ):
23
23
in_n , in_c , in_d , in_h , in_w = input_ .shape
24
- f_c , out_c , f_d , f_h , f_w = filter_ .shape
24
+ f_c , f_out_c , f_d , f_h , f_w = filter_ .shape
25
+ groups = attrs ['groups' ]
25
26
assert in_c == f_c
27
+ out_c = f_out_c * groups
28
+ sub_in_c = in_c / groups
26
29
27
30
stride , pad , dilations = attrs ['strides' ], attrs ['paddings' ], attrs [
28
31
'dilations' ]
@@ -39,18 +42,23 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs):
39
42
for d in range (in_d ):
40
43
for i in range (in_h ):
41
44
for j in range (in_w ):
42
- input_masked = input_ [n , :, d , i , j ] # (c)
43
- input_masked = np .reshape (input_masked , (in_c , 1 , 1 , 1 ))
44
- input_masked = np .tile (input_masked , (1 , f_d , f_h , f_w ))
45
-
46
- for k in range (out_c ):
47
- tmp_out = np .sum (input_masked * filter_ [:, k , :, :, :],
48
- axis = 0 )
49
- d1 , d2 = d * stride [0 ], d * stride [0 ] + d_bolck_d
50
- i1 , i2 = i * stride [1 ], i * stride [1 ] + d_bolck_h
51
- j1 , j2 = j * stride [2 ], j * stride [2 ] + d_bolck_w
52
- out [n , k , d1 :d2 :dilations [0 ], i1 :i2 :dilations [1 ], j1 :j2 :
53
- dilations [2 ]] += tmp_out
45
+ for g in range (groups ):
46
+ input_masked = input_ [n , g * sub_in_c :(g + 1
47
+ ) * sub_in_c , d ,
48
+ i , j ] # (c)
49
+ input_masked = np .reshape (input_masked ,
50
+ (sub_in_c , 1 , 1 , 1 ))
51
+ input_masked = np .tile (input_masked , (1 , f_d , f_h , f_w ))
52
+
53
+ for k in range (f_out_c ):
54
+ tmp_out = np .sum (input_masked * filter_ [
55
+ g * sub_in_c :(g + 1 ) * sub_in_c , k , :, :, :],
56
+ axis = 0 )
57
+ d1 , d2 = d * stride [0 ], d * stride [0 ] + d_bolck_d
58
+ i1 , i2 = i * stride [1 ], i * stride [1 ] + d_bolck_h
59
+ j1 , j2 = j * stride [2 ], j * stride [2 ] + d_bolck_w
60
+ out [n , g * f_out_c + k , d1 :d2 :dilations [0 ], i1 :i2 :
61
+ dilations [1 ], j1 :j2 :dilations [2 ]] += tmp_out
54
62
55
63
out = out [:, :, pad [0 ]:out_d - pad [0 ], pad [1 ]:out_h - pad [1 ], pad [2 ]:out_w -
56
64
pad [2 ]]
@@ -72,6 +80,7 @@ def setUp(self):
72
80
'strides' : self .stride ,
73
81
'paddings' : self .pad ,
74
82
'dilations' : self .dilations ,
83
+ 'groups' : self .groups ,
75
84
'use_cudnn' : self .use_cudnn ,
76
85
'data_format' : 'AnyLayout' # TODO(dzhwinter) : should be fix latter
77
86
}
@@ -134,6 +143,7 @@ def init_test_case(self):
134
143
self .pad = [0 , 0 , 0 ]
135
144
self .stride = [1 , 1 , 1 ]
136
145
self .dilations = [1 , 1 , 1 ]
146
+ self .groups = 1
137
147
self .input_size = [2 , 3 , 5 , 5 , 5 ] # NCDHW
138
148
f_c = self .input_size [1 ]
139
149
self .filter_size = [f_c , 6 , 3 , 3 , 3 ]
@@ -147,16 +157,29 @@ def init_test_case(self):
147
157
self .pad = [1 , 1 , 1 ]
148
158
self .stride = [1 , 1 , 1 ]
149
159
self .dilations = [1 , 1 , 1 ]
160
+ self .groups = 1
150
161
self .input_size = [2 , 3 , 5 , 5 , 5 ] # NCDHW
151
162
f_c = self .input_size [1 ]
152
163
self .filter_size = [f_c , 6 , 3 , 3 , 3 ]
153
164
154
165
166
+ class TestWithGroups (TestConv3dTransposeOp ):
167
+ def init_test_case (self ):
168
+ self .pad = [1 , 1 , 1 ]
169
+ self .stride = [1 , 1 , 1 ]
170
+ self .dilations = [1 , 1 , 1 ]
171
+ self .groups = 2
172
+ self .input_size = [2 , 4 , 5 , 5 , 5 ] # NCHW
173
+ f_c = self .input_size [1 ]
174
+ self .filter_size = [f_c , 3 , 3 , 3 , 3 ]
175
+
176
+
155
177
class TestWithStride (TestConv3dTransposeOp ):
156
178
def init_test_case (self ):
157
179
self .pad = [1 , 1 , 1 ]
158
180
self .stride = [2 , 2 , 2 ]
159
181
self .dilations = [1 , 1 , 1 ]
182
+ self .groups = 1
160
183
self .input_size = [2 , 3 , 5 , 5 , 5 ] # NCDHW
161
184
f_c = self .input_size [1 ]
162
185
self .filter_size = [f_c , 6 , 3 , 3 , 3 ]
@@ -167,6 +190,7 @@ def init_test_case(self):
167
190
self .pad = [1 , 1 , 1 ]
168
191
self .stride = [1 , 1 , 1 ]
169
192
self .dilations = [2 , 2 , 2 ]
193
+ self .groups = 1
170
194
self .input_size = [2 , 3 , 5 , 5 , 5 ] # NCDHW
171
195
f_c = self .input_size [1 ]
172
196
self .filter_size = [f_c , 6 , 3 , 3 , 3 ]
@@ -184,6 +208,7 @@ def init_test_case(self):
184
208
self .pad = [1 , 1 , 1 ]
185
209
self .stride = [1 , 1 , 1 ]
186
210
self .dilations = [1 , 1 , 1 ]
211
+ self .groups = 1
187
212
self .input_size = [2 , 3 , 5 , 5 , 5 ] # NCDHW
188
213
f_c = self .input_size [1 ]
189
214
self .filter_size = [f_c , 6 , 3 , 3 , 3 ]
@@ -198,6 +223,7 @@ def init_test_case(self):
198
223
self .pad = [1 , 1 , 1 ]
199
224
self .stride = [2 , 2 , 2 ]
200
225
self .dilations = [1 , 1 , 1 ]
226
+ self .groups = 1
201
227
self .input_size = [2 , 3 , 5 , 5 , 5 ] # NCDHW
202
228
f_c = self .input_size [1 ]
203
229
self .filter_size = [f_c , 6 , 3 , 3 , 3 ]
@@ -207,6 +233,21 @@ def init_op_type(self):
207
233
self .op_type = "conv3d_transpose"
208
234
209
235
236
+ class TestCUDNNWithGroups (TestWithGroups ):
237
+ def init_test_case (self ):
238
+ self .pad = [1 , 1 , 1 ]
239
+ self .stride = [1 , 1 , 1 ]
240
+ self .dilations = [1 , 1 , 1 ]
241
+ self .groups = 2
242
+ self .input_size = [2 , 4 , 5 , 5 , 5 ] # NCHW
243
+ f_c = self .input_size [1 ]
244
+ self .filter_size = [f_c , 3 , 3 , 3 , 3 ]
245
+
246
+ def init_op_type (self ):
247
+ self .use_cudnn = True
248
+ self .op_type = "conv3d_transpose"
249
+
250
+
210
251
# Please Don't remove the following code.
211
252
# Currently, CI use cudnn V5.0 which not support dilation conv.
212
253
# class TestCUDNNWithDilation(TestWithDilation):
0 commit comments