@@ -225,29 +225,29 @@ def init_group(self):
225
225
#----------------Conv2dCUDNN----------------
226
226
227
227
228
- def create_test_cudnn_class (parent , cls_name ):
228
+ def create_test_cudnn_class (parent ):
229
229
@unittest .skipIf (not core .is_compiled_with_cuda (),
230
230
"core is not compiled with CUDA" )
231
231
class TestCUDNNCase (parent ):
232
232
def init_kernel_type (self ):
233
233
self .use_cudnn = True
234
234
235
- cls_name = "{0}" .format (cls_name )
235
+ cls_name = "{0}_{1} " .format (parent . __name__ , "CUDNN" )
236
236
TestCUDNNCase .__name__ = cls_name
237
237
globals ()[cls_name ] = TestCUDNNCase
238
238
239
239
240
- create_test_cudnn_class (TestConv2dOp , "TestPool2DCUDNNOp" )
241
- create_test_cudnn_class (TestWithPad , "TestPool2DCUDNNOpCase1" )
242
- create_test_cudnn_class (TestWithStride , "TestPool2DCUDNNOpCase2" )
243
- create_test_cudnn_class (TestWithGroup , "TestPool2DCUDNNOpCase3" )
244
- create_test_cudnn_class (TestWith1x1 , "TestPool2DCUDNNOpCase4" )
245
- create_test_cudnn_class (TestWithInput1x1Filter1x1 , "TestPool2DCUDNNOpCase4" )
240
+ create_test_cudnn_class (TestConv2dOp )
241
+ create_test_cudnn_class (TestWithPad )
242
+ create_test_cudnn_class (TestWithStride )
243
+ create_test_cudnn_class (TestWithGroup )
244
+ create_test_cudnn_class (TestWith1x1 )
245
+ create_test_cudnn_class (TestWithInput1x1Filter1x1 )
246
246
247
247
#----------------Conv2dCUDNN----------------
248
248
249
249
250
- def create_test_cudnn_fp16_class (parent , cls_name , grad_check = True ):
250
+ def create_test_cudnn_fp16_class (parent , grad_check = True ):
251
251
@unittest .skipIf (not core .is_compiled_with_cuda (),
252
252
"core is not compiled with CUDA" )
253
253
class TestConv2DCUDNNFp16 (parent ):
@@ -279,23 +279,17 @@ def test_check_grad_no_input(self):
279
279
max_relative_error = 0.02 ,
280
280
no_grad_set = set (['Input' ]))
281
281
282
- cls_name = "{0}" .format (cls_name )
282
+ cls_name = "{0}_{1} " .format (parent . __name__ , "CUDNNFp16" )
283
283
TestConv2DCUDNNFp16 .__name__ = cls_name
284
284
globals ()[cls_name ] = TestConv2DCUDNNFp16
285
285
286
286
287
- create_test_cudnn_fp16_class (
288
- TestConv2dOp , "TestPool2DCUDNNFp16Op" , grad_check = False )
289
- create_test_cudnn_fp16_class (
290
- TestWithPad , "TestPool2DCUDNNFp16OpCase1" , grad_check = False )
291
- create_test_cudnn_fp16_class (
292
- TestWithStride , "TestPool2DCUDNNFp16OpCase2" , grad_check = False )
293
- create_test_cudnn_fp16_class (
294
- TestWithGroup , "TestPool2DCUDNNFp16OpCase3" , grad_check = False )
295
- create_test_cudnn_fp16_class (
296
- TestWith1x1 , "TestPool2DCUDNNFp16OpCase4" , grad_check = False )
297
- create_test_cudnn_fp16_class (
298
- TestWithInput1x1Filter1x1 , "TestPool2DCUDNNFp16OpCase4" , grad_check = False )
287
+ create_test_cudnn_fp16_class (TestConv2dOp , grad_check = False )
288
+ create_test_cudnn_fp16_class (TestWithPad , grad_check = False )
289
+ create_test_cudnn_fp16_class (TestWithStride , grad_check = False )
290
+ create_test_cudnn_fp16_class (TestWithGroup , grad_check = False )
291
+ create_test_cudnn_fp16_class (TestWith1x1 , grad_check = False )
292
+ create_test_cudnn_fp16_class (TestWithInput1x1Filter1x1 , grad_check = False )
299
293
300
294
# -------TestDepthwiseConv
301
295
0 commit comments