@@ -392,6 +392,77 @@ def test_conv2d_6(self):
392
392
kernel_val = np .arange (1 , 1 + np .prod (kernel_shape )).astype ("float32" ).reshape (kernel_shape )
393
393
self ._conv_test (x_val , kernel_val , strides = strides , padding = "VALID" , rtol = 1e-05 )
394
394
395
+ def test_conv3d_1 (self ):
396
+ strides = [1 , 1 , 1 , 1 , 1 ]
397
+ dilations = [1 , 1 , 1 , 1 , 1 ]
398
+ x_val = np .random .random_sample ([2 , 10 , 9 , 8 , 5 ]).astype (np .float32 )
399
+ w = np .random .random_sample ([2 , 3 , 4 , 5 , 6 ]).astype (np .float32 )
400
+ padding = "VALID"
401
+ def func (x ):
402
+ kernel = tf .constant (w , dtype = tf .float32 , name = 'k' )
403
+ conv = tf .nn .conv3d (x , kernel , strides = strides , padding = padding , data_format = "NDHWC" , dilations = dilations )
404
+ return tf .identity (conv , name = _TFOUTPUT )
405
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-05 )
406
+
407
+ def test_conv3d_2 (self ):
408
+ strides = [1 , 2 , 3 , 1 , 1 ]
409
+ dilations = [1 , 1 , 1 , 1 , 1 ]
410
+ x_val = np .random .random_sample ([2 , 10 , 9 , 8 , 5 ]).astype (np .float32 )
411
+ w = np .random .random_sample ([2 , 3 , 4 , 5 , 6 ]).astype (np .float32 )
412
+ padding = "VALID"
413
+ def func (x ):
414
+ kernel = tf .constant (w , dtype = tf .float32 , name = 'k' )
415
+ conv = tf .nn .conv3d (x , kernel , strides = strides , padding = padding , data_format = "NDHWC" , dilations = dilations )
416
+ return tf .identity (conv , name = _TFOUTPUT )
417
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-05 )
418
+
419
+ def test_conv3d_3 (self ):
420
+ strides = [1 , 2 , 3 , 1 , 1 ]
421
+ dilations = [1 , 1 , 1 , 1 , 1 ]
422
+ x_val = np .random .random_sample ([2 , 10 , 9 , 8 , 5 ]).astype (np .float32 )
423
+ w = np .random .random_sample ([2 , 3 , 4 , 5 , 6 ]).astype (np .float32 )
424
+ padding = "SAME"
425
+ def func (x ):
426
+ kernel = tf .constant (w , dtype = tf .float32 , name = 'k' )
427
+ conv = tf .nn .conv3d (x , kernel , strides = strides , padding = padding , data_format = "NDHWC" , dilations = dilations )
428
+ return tf .identity (conv , name = _TFOUTPUT )
429
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-05 )
430
+
431
+ def test_avgpool3d (self ):
432
+ strides = [1 , 1 , 1 , 1 , 1 ]
433
+ ksize = [1 , 2 , 2 , 3 , 1 ]
434
+ x_val = np .random .random_sample ([2 , 10 , 9 , 8 , 5 ]).astype (np .float32 )
435
+ padding = "VALID"
436
+
437
+ def func (x ):
438
+ mp = tf .nn .avg_pool3d (x , ksize , strides , padding = padding , data_format = "NDHWC" )
439
+ return tf .identity (mp , name = _TFOUTPUT )
440
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
441
+
442
+ def test_maxpool3d (self ):
443
+ strides = [1 , 1 , 1 , 1 , 1 ]
444
+ ksize = [1 , 2 , 2 , 3 , 1 ]
445
+ x_val = np .random .random_sample ([2 , 10 , 9 , 8 , 5 ]).astype (np .float32 )
446
+ padding = "VALID"
447
+
448
+ def func (x ):
449
+ mp = tf .nn .max_pool3d (x , ksize , strides , padding = padding , data_format = "NDHWC" )
450
+ return tf .identity (mp , name = _TFOUTPUT )
451
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
452
+
453
+ @check_tf_min_version ("1.14" , "tf.nn.avg_pool2d doesn't exist before tf 1.14" )
454
+ def test_avgpool2d (self ):
455
+ strides = [1 , 1 , 1 , 1 ]
456
+ ksize = [1 , 2 , 3 , 1 ]
457
+ x_val = make_xval ([2 , 10 , 12 , 3 ])
458
+ padding = "VALID"
459
+
460
+ def func (x ):
461
+ mp = tf .nn .avg_pool2d (x , ksize , strides , padding = padding , data_format = "NHWC" )
462
+ return tf .identity (mp , name = _TFOUTPUT )
463
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
464
+
465
+
395
466
@check_tf_min_version ("1.7" , "tf only support dilation is 1 for now" )
396
467
def test_conv2d_7 (self ):
397
468
x_shape = [1 , 35 , 35 , 288 ] # out: [1, 17, 17, 384]
0 commit comments