@@ -395,6 +395,77 @@ def test_conv2d_6(self):
395
395
kernel_val = np .arange (1 , 1 + np .prod (kernel_shape )).astype ("float32" ).reshape (kernel_shape )
396
396
self ._conv_test (x_val , kernel_val , strides = strides , padding = "VALID" , rtol = 1e-05 )
397
397
398
+ def test_conv3d_1 (self ):
399
+ strides = [1 , 1 , 1 , 1 , 1 ]
400
+ dilations = [1 , 1 , 1 , 1 , 1 ]
401
+ x_val = np .random .random_sample ([2 , 10 , 9 , 8 , 5 ]).astype (np .float32 )
402
+ w = np .random .random_sample ([2 , 3 , 4 , 5 , 6 ]).astype (np .float32 )
403
+ padding = "VALID"
404
+ def func (x ):
405
+ kernel = tf .constant (w , dtype = tf .float32 , name = 'k' )
406
+ conv = tf .nn .conv3d (x , kernel , strides = strides , padding = padding , data_format = "NDHWC" , dilations = dilations )
407
+ return tf .identity (conv , name = _TFOUTPUT )
408
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-05 )
409
+
410
+ def test_conv3d_2 (self ):
411
+ strides = [1 , 2 , 3 , 1 , 1 ]
412
+ dilations = [1 , 1 , 1 , 1 , 1 ]
413
+ x_val = np .random .random_sample ([2 , 10 , 9 , 8 , 5 ]).astype (np .float32 )
414
+ w = np .random .random_sample ([2 , 3 , 4 , 5 , 6 ]).astype (np .float32 )
415
+ padding = "VALID"
416
+ def func (x ):
417
+ kernel = tf .constant (w , dtype = tf .float32 , name = 'k' )
418
+ conv = tf .nn .conv3d (x , kernel , strides = strides , padding = padding , data_format = "NDHWC" , dilations = dilations )
419
+ return tf .identity (conv , name = _TFOUTPUT )
420
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-05 )
421
+
422
+ def test_conv3d_3 (self ):
423
+ strides = [1 , 2 , 3 , 1 , 1 ]
424
+ dilations = [1 , 1 , 1 , 1 , 1 ]
425
+ x_val = np .random .random_sample ([2 , 10 , 9 , 8 , 5 ]).astype (np .float32 )
426
+ w = np .random .random_sample ([2 , 3 , 4 , 5 , 6 ]).astype (np .float32 )
427
+ padding = "SAME"
428
+ def func (x ):
429
+ kernel = tf .constant (w , dtype = tf .float32 , name = 'k' )
430
+ conv = tf .nn .conv3d (x , kernel , strides = strides , padding = padding , data_format = "NDHWC" , dilations = dilations )
431
+ return tf .identity (conv , name = _TFOUTPUT )
432
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-05 )
433
+
434
+ def test_avgpool3d (self ):
435
+ strides = [1 , 1 , 1 , 1 , 1 ]
436
+ ksize = [1 , 2 , 2 , 3 , 1 ]
437
+ x_val = np .random .random_sample ([2 , 10 , 9 , 8 , 5 ]).astype (np .float32 )
438
+ padding = "VALID"
439
+
440
+ def func (x ):
441
+ mp = tf .nn .avg_pool3d (x , ksize , strides , padding = padding , data_format = "NDHWC" )
442
+ return tf .identity (mp , name = _TFOUTPUT )
443
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
444
+
445
+ def test_maxpool3d (self ):
446
+ strides = [1 , 1 , 1 , 1 , 1 ]
447
+ ksize = [1 , 2 , 2 , 3 , 1 ]
448
+ x_val = np .random .random_sample ([2 , 10 , 9 , 8 , 5 ]).astype (np .float32 )
449
+ padding = "VALID"
450
+
451
+ def func (x ):
452
+ mp = tf .nn .max_pool3d (x , ksize , strides , padding = padding , data_format = "NDHWC" )
453
+ return tf .identity (mp , name = _TFOUTPUT )
454
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
455
+
456
+ @check_tf_min_version ("1.14" , "tf.nn.avg_pool2d doesn't exist before tf 1.14" )
457
+ def test_avgpool2d (self ):
458
+ strides = [1 , 1 , 1 , 1 ]
459
+ ksize = [1 , 2 , 3 , 1 ]
460
+ x_val = make_xval ([2 , 10 , 12 , 3 ])
461
+ padding = "VALID"
462
+
463
+ def func (x ):
464
+ mp = tf .nn .avg_pool2d (x , ksize , strides , padding = padding , data_format = "NHWC" )
465
+ return tf .identity (mp , name = _TFOUTPUT )
466
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
467
+
468
+
398
469
@check_tf_min_version ("1.7" , "tf only support dilation is 1 for now" )
399
470
def test_conv2d_7 (self ):
400
471
x_shape = [1 , 35 , 35 , 288 ] # out: [1, 17, 17, 384]
0 commit comments