
Commit 8b16927

add fp16 support to conv3d

1 parent fd1971c · commit 8b16927

2 files changed: +89 additions, −20 deletions
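In brief: the operator change registers the cuDNN conv3d forward kernel for plat::float16 alongside the existing float and double instantiations, and the test change reworks test_conv3d_op.py so the dtype is configurable per test class, adding a forward-only fp16 variant of each cuDNN test case.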

paddle/fluid/operators/conv_cudnn_op.cu.cc

Lines changed: 2 additions & 1 deletion
@@ -366,7 +366,8 @@ REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,

 REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvOpKernel<float>,
-                   paddle::operators::CUDNNConvOpKernel<double>);
+                   paddle::operators::CUDNNConvOpKernel<double>,
+                   paddle::operators::CUDNNConvOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvGradOpKernel<float>,
                    paddle::operators::CUDNNConvGradOpKernel<double>);
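Note that only the forward kernel gains a float16 instantiation; conv3d_grad is still registered for float and double only. This is why every fp16 test case below checks the forward output and returns early from the gradient checks.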

python/paddle/fluid/tests/unittests/test_conv3d_op.py

Lines changed: 87 additions & 19 deletions
@@ -70,30 +70,36 @@ def conv3d_forward_naive(input, filter, group, conv_param):

 class TestConv3dOp(OpTest):
     def setUp(self):
+        self.op_type = "conv3d"
         self.use_cudnn = False
+        self.dtype = np.float32
+        self.init_kernel_type()
         self.init_group()
-        self.init_op_type()
         self.init_dilation()
         self.init_test_case()

         conv3d_param = {
             'stride': self.stride,
             'pad': self.pad,
             'dilations': self.dilations,
-            'use_cudnn': self.use_cudnn,
             'data_format': 'AnyLayout'  # TODO(dzhwinter) : should be fix latter
         }
-        input = np.random.random(self.input_size).astype("float32")
-        filter = np.random.random(self.filter_size).astype("float32")
+
+        input = np.random.random(self.input_size).astype(self.dtype)
+        filter = np.random.random(self.filter_size).astype(self.dtype)
         output = conv3d_forward_naive(input, filter, self.groups,
-                                      conv3d_param).astype("float32")
+                                      conv3d_param).astype(self.dtype)

-        self.inputs = {'Input': input, 'Filter': filter}
+        self.inputs = {
+            'Input': OpTest.np_dtype_to_fluid_dtype(input),
+            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
+        }
         self.attrs = {
             'strides': self.stride,
             'paddings': self.pad,
             'groups': self.groups,
-            'dilations': self.dilations
+            'dilations': self.dilations,
+            'use_cudnn': self.use_cudnn
         }
         self.outputs = {'Output': output}

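Inputs are now routed through OpTest.np_dtype_to_fluid_dtype before being fed, and the random data plus the naive reference output follow self.dtype instead of a hard-coded "float32". As a rough sketch of what that helper presumably does in this framework (an assumption from context, not part of this commit), float16 arrays are re-viewed as uint16 bit patterns, since that is how fluid feeds half-precision data, while other dtypes pass through untouched:

import numpy as np

# Hedged sketch of OpTest.np_dtype_to_fluid_dtype (assumed behavior, not
# taken from this commit): reinterpret float16 buffers as uint16 in place,
# without converting any values; leave every other dtype alone.
def np_dtype_to_fluid_dtype(array):
    if array.dtype == np.float16:
        array.dtype = np.uint16  # same 2-byte buffer, new view
    return array

Also worth noting: 'use_cudnn' moves out of conv3d_param, which only feeds the naive reference implementation, and into self.attrs, where the operator actually reads it.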
@@ -108,6 +114,8 @@ def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
         if self.testcudnn():
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
@@ -120,6 +128,8 @@ def test_check_grad(self):
             set(['Input', 'Filter']), 'Output', max_relative_error=0.03)

     def test_check_grad_no_filter(self):
+        if self.dtype == np.float16:
+            return
         if self.testcudnn():
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
@@ -135,6 +145,8 @@ def test_check_grad_no_filter(self):
             no_grad_set=set(['Filter']))

     def test_check_grad_no_input(self):
+        if self.dtype == np.float16:
+            return
         if self.testcudnn():
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
@@ -163,8 +175,8 @@ def init_dilation(self):
     def init_group(self):
         self.groups = 1

-    def init_op_type(self):
-        self.op_type = "conv3d"
+    def init_kernel_type(self):
+        pass


 class TestCase1(TestConv3dOp):
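The init_op_type hook is renamed to init_kernel_type throughout: op_type is now fixed once in setUp, and subclasses override the hook only to choose kernel properties (use_cudnn and, for the fp16 cases, dtype), with the base class providing a no-op default.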
@@ -235,34 +247,90 @@ def init_group(self):
         self.groups = 3


+#----------------Conv3dCUDNN----------------
 class TestCUDNN(TestConv3dOp):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_cudnn = True
-        self.op_type = "conv3d"
+
+
+class TestFP16CUDNN(TestConv3dOp):
+    def init_kernel_type(self):
+        self.use_cudnn = True
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-2)


 class TestWithGroup1CUDNN(TestWithGroup1):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_cudnn = True
-        self.op_type = "conv3d"
+
+
+class TestFP16WithGroup1CUDNN(TestWithGroup1):
+    def init_kernel_type(self):
+        self.use_cudnn = True
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-2)


 class TestWithGroup2CUDNN(TestWithGroup2):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_cudnn = True
-        self.op_type = "conv3d"
+
+
+class TestFP16WithGroup2CUDNN(TestWithGroup2):
+    def init_kernel_type(self):
+        self.use_cudnn = True
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-2)


 class TestWith1x1CUDNN(TestWith1x1):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_cudnn = True
-        self.op_type = "conv3d"
+
+
+class TestFP16With1x1CUDNN(TestWith1x1):
+    def init_kernel_type(self):
+        self.use_cudnn = True
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-2)


 class TestWithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_cudnn = True
-        self.op_type = "conv3d"
+
+
+class TestFP16WithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1):
+    def init_kernel_type(self):
+        self.use_cudnn = True
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-2)


 # FIXME(typhoonzero): find a way to determine if
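Each new fp16 class follows the same template: enable cuDNN, set dtype to np.float16, and override test_check_output so it runs only on a CUDA build whose device reports float16 support, with the absolute tolerance loosened to 2e-2 to absorb half-precision rounding. Extending the pattern to another existing fixture would look like this hypothetical variant of TestCase1 (not part of the commit; it assumes the test module's usual imports, numpy as np and paddle.fluid.core as core):

class TestFP16Case1CUDNN(TestCase1):
    def init_kernel_type(self):
        self.use_cudnn = True
        self.dtype = np.float16

    def test_check_output(self):
        # Forward-only check: no float16 conv3d_grad kernel is registered.
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                self.check_output_with_place(place, atol=2e-2)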
