Skip to content

Commit 79a99f4

Browse files
leimaoTomWildenhain-Microsoftguschmue
authored
Fix TF1-Keras Dilated Conv Export (#1744)
update patch skip cpu test for conv3d ncdhw skip CPU for Conv2D NCHW update Signed-off-by: Lei Mao <[email protected]> Remove NVIDIA license header Co-authored-by: TomWildenhain-Microsoft <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent bfb2033 commit 79a99f4

File tree

2 files changed

+85
-13
lines changed

2 files changed

+85
-13
lines changed

tests/test_backend.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,7 @@ def func(x):
727727
onnx_feed_dict={_INPUT: x_val_for_onnx})
728728

729729
@skip_tflite("TFlite adds ops that obscure pattern")
730-
@check_tf_min_version("2.0")
730+
@check_tf_min_version("1.15")
731731
def test_conv1d_dilations_rewriter(self):
732732
x_shape = [2, 32, 3]
733733
x_val = make_xval(x_shape)
@@ -740,7 +740,7 @@ def func(x):
740740
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
741741
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
742742

743-
@check_tf_min_version("2.0")
743+
@check_tf_min_version("1.15")
744744
def test_conv2d_dilations_rewriter(self):
745745
x_shape = [2, 32, 16, 3]
746746
x_val = make_xval(x_shape)
@@ -760,7 +760,39 @@ def func(x):
760760
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
761761
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
762762

763-
@check_tf_min_version("2.0")
763+
@check_tf_min_version("1.15")
764+
@skip_tf_cpu("only tf_gpu can run conv2d with NCHW format")
765+
def test_nchw_conv2d_dilations_rewriter(self):
766+
x_shape = [2, 3, 32, 16]
767+
x_val = make_xval(x_shape)
768+
for p in ['SAME', 'VALID']:
769+
def func(x):
770+
t = tf.keras.layers.Conv2D(
771+
filters=768,
772+
kernel_size=3,
773+
dilation_rate=3,
774+
padding=p,
775+
data_format='channels_first'
776+
)
777+
t.build(x_shape)
778+
y = t.call(x)
779+
return tf.identity(y, name=_TFOUTPUT)
780+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
781+
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
782+
def func(x):
783+
t = tf.keras.layers.DepthwiseConv2D(
784+
kernel_size=3,
785+
dilation_rate=3,
786+
padding=p,
787+
data_format='channels_first'
788+
)
789+
t.build(x_shape)
790+
y = t.call(x)
791+
return tf.identity(y, name=_TFOUTPUT)
792+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
793+
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
794+
795+
@check_tf_min_version("1.15")
764796
@skip_tflite("TFlite adds ops that obscure pattern")
765797
@allow_missing_shapes("Rewriting makes some shapes known")
766798
def test_conv2d_dilations_rewriter_unknown_shape(self):
@@ -776,7 +808,30 @@ def func():
776808
as_session=True, premade_placeholders=True,
777809
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
778810

779-
@check_tf_min_version("2.0")
811+
@check_tf_min_version("1.15")
812+
@skip_tflite("TFlite adds ops that obscure pattern")
813+
@skip_tf_cpu("only tf_gpu can run conv2d with NCHW format")
814+
@allow_missing_shapes("Rewriting makes some shapes known")
815+
def test_nchw_conv2d_dilations_rewriter_unknown_shape(self):
816+
x_shape = [2, 3, 32, 16]
817+
x_val = make_xval(x_shape)
818+
def func():
819+
x = tf_placeholder(tf.float32, [2, 3, None, None], name=_TFINPUT)
820+
t = tf.keras.layers.Conv2D(
821+
filters=768,
822+
kernel_size=3,
823+
dilation_rate=3,
824+
padding="VALID",
825+
data_format='channels_first'
826+
)
827+
t.build(x_shape)
828+
y = t.call(x)
829+
return tf.identity(y, name=_TFOUTPUT)
830+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2,
831+
as_session=True, premade_placeholders=True,
832+
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
833+
834+
@check_tf_min_version("1.15")
780835
def test_conv3d_dilations_rewriter(self):
781836
x_shape = [2, 32, 16, 8, 3]
782837
x_val = make_xval(x_shape)
@@ -789,6 +844,26 @@ def func(x):
789844
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
790845
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
791846

847+
@check_tf_min_version("1.15")
848+
@skip_tf_cpu("only tf_gpu can run conv3d with NCDHW format")
849+
def test_ncdhw_conv3d_dilations_rewriter(self):
850+
x_shape = [2, 3, 32, 16, 8]
851+
x_val = make_xval(x_shape)
852+
for p in ['SAME', 'VALID']:
853+
def func(x):
854+
t = tf.keras.layers.Conv3D(
855+
filters=768,
856+
kernel_size=3,
857+
dilation_rate=3,
858+
padding=p,
859+
data_format='channels_first'
860+
)
861+
t.build(x_shape)
862+
y = t.call(x)
863+
return tf.identity(y, name=_TFOUTPUT)
864+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
865+
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
866+
792867
@skip_tf2("Uses tf.layers")
793868
def test_conv1d_tf1_dilations_rewriter(self):
794869
x_shape = [2, 32, 3]

tf2onnx/rewriter/conv_dilations_rewriter.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414

1515
def rewrite_conv_dilations(g, ops):
16+
1617
pattern1 = \
1718
OpTypePattern("BatchToSpaceND", name="batch_to_space", inputs=[
1819
OpTypePattern("DepthwiseConv2dNative|Conv2D|Conv3D", name="conv", inputs=[
@@ -67,14 +68,7 @@ def rewrite_conv_dilations(g, ops):
6768
if block_shape1 != block_shape2:
6869
continue
6970
ndims = 2 if is_conv_1d else len(block_shape1)
70-
data_format = b"NHWC" if ndims == 2 else b"NDHWC"
71-
ones = [1] * (ndims + 2)
72-
if conv.get_attr_value("dilations", ones) != ones:
73-
continue
74-
if conv.get_attr_value("strides", ones) != ones:
75-
continue
76-
if conv.get_attr_value("data_format", data_format) != data_format:
77-
continue
71+
7872
if conv.get_attr_value("padding") != b"VALID":
7973
continue
8074

@@ -114,7 +108,10 @@ def rewrite_conv_dilations(g, ops):
114108
g.copy_shape(batch_to_space.output[0], conv.output[0])
115109
g.replace_all_inputs(batch_to_space.output[0], conv.output[0])
116110

117-
conv.set_attr("dilations", [1] + block_shape1 + [1])
111+
if conv.get_attr_value("data_format") in [b"NCHW", b"NCDHW"]:
112+
conv.set_attr("dilations", [1] + block_shape1)
113+
else:
114+
conv.set_attr("dilations", [1] + block_shape1 + [1])
118115
conv.set_attr("padding", pad_mode)
119116
if pad_mode == "EXPLICIT":
120117
conv.set_attr("explicit_paddings", base_pad_flat)

0 commit comments

Comments
 (0)