Skip to content

Commit f8b2680

Browse files
author
chengduo
authored
fix test_conv2d (#14330)
test=develop
1 parent c5b6573 commit f8b2680

File tree

1 file changed

+16
-22
lines changed

1 file changed

+16
-22
lines changed

python/paddle/fluid/tests/unittests/test_conv2d_op.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -225,29 +225,29 @@ def init_group(self):
225225
#----------------Conv2dCUDNN----------------
226226

227227

228-
def create_test_cudnn_class(parent, cls_name):
228+
def create_test_cudnn_class(parent):
229229
@unittest.skipIf(not core.is_compiled_with_cuda(),
230230
"core is not compiled with CUDA")
231231
class TestCUDNNCase(parent):
232232
def init_kernel_type(self):
233233
self.use_cudnn = True
234234

235-
cls_name = "{0}".format(cls_name)
235+
cls_name = "{0}_{1}".format(parent.__name__, "CUDNN")
236236
TestCUDNNCase.__name__ = cls_name
237237
globals()[cls_name] = TestCUDNNCase
238238

239239

240-
create_test_cudnn_class(TestConv2dOp, "TestPool2DCUDNNOp")
241-
create_test_cudnn_class(TestWithPad, "TestPool2DCUDNNOpCase1")
242-
create_test_cudnn_class(TestWithStride, "TestPool2DCUDNNOpCase2")
243-
create_test_cudnn_class(TestWithGroup, "TestPool2DCUDNNOpCase3")
244-
create_test_cudnn_class(TestWith1x1, "TestPool2DCUDNNOpCase4")
245-
create_test_cudnn_class(TestWithInput1x1Filter1x1, "TestPool2DCUDNNOpCase4")
240+
create_test_cudnn_class(TestConv2dOp)
241+
create_test_cudnn_class(TestWithPad)
242+
create_test_cudnn_class(TestWithStride)
243+
create_test_cudnn_class(TestWithGroup)
244+
create_test_cudnn_class(TestWith1x1)
245+
create_test_cudnn_class(TestWithInput1x1Filter1x1)
246246

247247
#----------------Conv2dCUDNN----------------
248248

249249

250-
def create_test_cudnn_fp16_class(parent, cls_name, grad_check=True):
250+
def create_test_cudnn_fp16_class(parent, grad_check=True):
251251
@unittest.skipIf(not core.is_compiled_with_cuda(),
252252
"core is not compiled with CUDA")
253253
class TestConv2DCUDNNFp16(parent):
@@ -279,23 +279,17 @@ def test_check_grad_no_input(self):
279279
max_relative_error=0.02,
280280
no_grad_set=set(['Input']))
281281

282-
cls_name = "{0}".format(cls_name)
282+
cls_name = "{0}_{1}".format(parent.__name__, "CUDNNFp16")
283283
TestConv2DCUDNNFp16.__name__ = cls_name
284284
globals()[cls_name] = TestConv2DCUDNNFp16
285285

286286

287-
create_test_cudnn_fp16_class(
288-
TestConv2dOp, "TestPool2DCUDNNFp16Op", grad_check=False)
289-
create_test_cudnn_fp16_class(
290-
TestWithPad, "TestPool2DCUDNNFp16OpCase1", grad_check=False)
291-
create_test_cudnn_fp16_class(
292-
TestWithStride, "TestPool2DCUDNNFp16OpCase2", grad_check=False)
293-
create_test_cudnn_fp16_class(
294-
TestWithGroup, "TestPool2DCUDNNFp16OpCase3", grad_check=False)
295-
create_test_cudnn_fp16_class(
296-
TestWith1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False)
297-
create_test_cudnn_fp16_class(
298-
TestWithInput1x1Filter1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False)
287+
create_test_cudnn_fp16_class(TestConv2dOp, grad_check=False)
288+
create_test_cudnn_fp16_class(TestWithPad, grad_check=False)
289+
create_test_cudnn_fp16_class(TestWithStride, grad_check=False)
290+
create_test_cudnn_fp16_class(TestWithGroup, grad_check=False)
291+
create_test_cudnn_fp16_class(TestWith1x1, grad_check=False)
292+
create_test_cudnn_fp16_class(TestWithInput1x1Filter1x1, grad_check=False)
299293

300294
# -------TestDepthwiseConv
301295

0 commit comments

Comments
 (0)