@@ -727,7 +727,7 @@ def func(x):
727
727
onnx_feed_dict = {_INPUT : x_val_for_onnx })
728
728
729
729
@skip_tflite ("TFlite adds ops that obscure pattern" )
730
- @check_tf_min_version ("2.0 " )
730
+ @check_tf_min_version ("1.15 " )
731
731
def test_conv1d_dilations_rewriter (self ):
732
732
x_shape = [2 , 32 , 3 ]
733
733
x_val = make_xval (x_shape )
@@ -740,7 +740,7 @@ def func(x):
740
740
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 , as_session = True ,
741
741
graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
742
742
743
- @check_tf_min_version ("2.0 " )
743
+ @check_tf_min_version ("1.15 " )
744
744
def test_conv2d_dilations_rewriter (self ):
745
745
x_shape = [2 , 32 , 16 , 3 ]
746
746
x_val = make_xval (x_shape )
@@ -760,7 +760,39 @@ def func(x):
760
760
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 , as_session = True ,
761
761
graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
762
762
763
- @check_tf_min_version ("2.0" )
763
+ @check_tf_min_version ("1.15" )
764
+ @skip_tf_cpu ("only tf_gpu can run conv2d with NCHW format" )
765
+ def test_nchw_conv2d_dilations_rewriter (self ):
766
+ x_shape = [2 , 3 , 32 , 16 ]
767
+ x_val = make_xval (x_shape )
768
+ for p in ['SAME' , 'VALID' ]:
769
+ def func (x ):
770
+ t = tf .keras .layers .Conv2D (
771
+ filters = 768 ,
772
+ kernel_size = 3 ,
773
+ dilation_rate = 3 ,
774
+ padding = p ,
775
+ data_format = 'channels_first'
776
+ )
777
+ t .build (x_shape )
778
+ y = t .call (x )
779
+ return tf .identity (y , name = _TFOUTPUT )
780
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 , as_session = True ,
781
+ graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
782
+ def func (x ):
783
+ t = tf .keras .layers .DepthwiseConv2D (
784
+ kernel_size = 3 ,
785
+ dilation_rate = 3 ,
786
+ padding = p ,
787
+ data_format = 'channels_first'
788
+ )
789
+ t .build (x_shape )
790
+ y = t .call (x )
791
+ return tf .identity (y , name = _TFOUTPUT )
792
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 , as_session = True ,
793
+ graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
794
+
795
+ @check_tf_min_version ("1.15" )
764
796
@skip_tflite ("TFlite adds ops that obscure pattern" )
765
797
@allow_missing_shapes ("Rewriting makes some shapes known" )
766
798
def test_conv2d_dilations_rewriter_unknown_shape (self ):
@@ -776,7 +808,30 @@ def func():
776
808
as_session = True , premade_placeholders = True ,
777
809
graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
778
810
779
- @check_tf_min_version ("2.0" )
811
+ @check_tf_min_version ("1.15" )
812
+ @skip_tflite ("TFlite adds ops that obscure pattern" )
813
+ @skip_tf_cpu ("only tf_gpu can run conv2d with NCHW format" )
814
+ @allow_missing_shapes ("Rewriting makes some shapes known" )
815
+ def test_nchw_conv2d_dilations_rewriter_unknown_shape (self ):
816
+ x_shape = [2 , 3 , 32 , 16 ]
817
+ x_val = make_xval (x_shape )
818
+ def func ():
819
+ x = tf_placeholder (tf .float32 , [2 , 3 , None , None ], name = _TFINPUT )
820
+ t = tf .keras .layers .Conv2D (
821
+ filters = 768 ,
822
+ kernel_size = 3 ,
823
+ dilation_rate = 3 ,
824
+ padding = "VALID" ,
825
+ data_format = 'channels_first'
826
+ )
827
+ t .build (x_shape )
828
+ y = t .call (x )
829
+ return tf .identity (y , name = _TFOUTPUT )
830
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 ,
831
+ as_session = True , premade_placeholders = True ,
832
+ graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
833
+
834
+ @check_tf_min_version ("1.15" )
780
835
def test_conv3d_dilations_rewriter (self ):
781
836
x_shape = [2 , 32 , 16 , 8 , 3 ]
782
837
x_val = make_xval (x_shape )
@@ -789,6 +844,26 @@ def func(x):
789
844
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 , as_session = True ,
790
845
graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
791
846
847
+ @check_tf_min_version ("1.15" )
848
+ @skip_tf_cpu ("only tf_gpu can run conv3d with NCDHW format" )
849
+ def test_ncdhw_conv3d_dilations_rewriter (self ):
850
+ x_shape = [2 , 3 , 32 , 16 , 8 ]
851
+ x_val = make_xval (x_shape )
852
+ for p in ['SAME' , 'VALID' ]:
853
+ def func (x ):
854
+ t = tf .keras .layers .Conv3D (
855
+ filters = 768 ,
856
+ kernel_size = 3 ,
857
+ dilation_rate = 3 ,
858
+ padding = p ,
859
+ data_format = 'channels_first'
860
+ )
861
+ t .build (x_shape )
862
+ y = t .call (x )
863
+ return tf .identity (y , name = _TFOUTPUT )
864
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 , as_session = True ,
865
+ graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
866
+
792
867
@skip_tf2 ("Uses tf.layers" )
793
868
def test_conv1d_tf1_dilations_rewriter (self ):
794
869
x_shape = [2 , 32 , 3 ]
0 commit comments