Skip to content

Commit e31d9a5

Browse files
committed
support MatrixDiag and MatrixSetDiag when they are used in tf.eye
1 parent a88c048 commit e31d9a5

File tree

4 files changed

+105
-1
lines changed

4 files changed

+105
-1
lines changed

tests/test_backend.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,32 @@ def test_expand_dims_more_unknown_rank(self):
127127
_ = tf.identity(op, name=_TFOUTPUT)
128128
self._run_test_case([_OUTPUT], {_INPUT: x_val})
129129

130+
@check_opset_min_version(9, "ConstantOfShape")
131+
def test_eye(self):
132+
# tf.eye(tf.shape)
133+
for np_dtype, tf_dtype in zip([np.int32, np.int64, np.float32, np.float64],
134+
[tf.int32, tf.int64, tf.float32, tf.float64]):
135+
tf.reset_default_graph()
136+
x_val = np.array([[1.0, 2.0, -3.0, -4.0, 5.0]] * 2, dtype=np_dtype)
137+
x = tf.placeholder(tf_dtype, shape=[None] * 2, name=_TFINPUT)
138+
y_ = tf.eye(tf.shape(x)[0], dtype=tf.float32)
139+
_ = tf.identity(y_, name=_TFOUTPUT)
140+
y1_ = tf.eye(tf.shape(x)[1], dtype=tf.int32)
141+
_ = tf.identity(y1_, name=_TFOUTPUT1)
142+
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val}, rtol=0)
143+
144+
# tf.eye(tf.shape, tf.shape)
145+
for np_dtype, tf_dtype in zip([np.int32, np.int64, np.float32, np.float64],
146+
[tf.int32, tf.int64, tf.float32, tf.float64]):
147+
tf.reset_default_graph()
148+
x_val = np.array([[1.0, 2.0, -3.0, -4.0, 5.0]] * 2, dtype=np_dtype)
149+
x = tf.placeholder(tf_dtype, shape=[None] * 2, name=_TFINPUT)
150+
y_ = tf.eye(tf.shape(x)[0], tf.shape(x)[1], dtype=tf.float32)
151+
_ = tf.identity(y_, name=_TFOUTPUT)
152+
y1_ = tf.eye(tf.shape(x)[0], tf.shape(x)[1], dtype=tf.int32)
153+
_ = tf.identity(y1_, name=_TFOUTPUT1)
154+
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val}, rtol=0)
155+
130156
@check_opset_min_version(7, "trig")
131157
def test_trig_ops(self):
132158
for op in [tf.sin, tf.cos, tf.tan, tf.asin, tf.acos, tf.atan]:

tf2onnx/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tf2onnx.rewriter.cond_rewriter import rewrite_cond
1010
from tf2onnx.rewriter.random_uniform import rewrite_random_uniform, rewrite_random_uniform_fold_const
1111
from tf2onnx.rewriter.leakyrelu_rewriter import rewrite_leakyrelu
12+
from tf2onnx.rewriter.eye_rewriter import rewrite_eye
1213
from tf2onnx.rewriter.thresholded_relu_rewriter import rewrite_thresholded_relu
1314
from tf2onnx.rewriter.rnn import rewrite_single_direction_lstm, rewrite_bi_direction_lstm, \
1415
rewrite_single_direction_gru, rewrite_bi_direction_gru, \
@@ -20,6 +21,7 @@
2021
"rewrite_random_uniform_fold_const",
2122
"rewrite_leakyrelu",
2223
"rewrite_thresholded_relu",
24+
"rewrite_eye",
2325
"rewrite_single_direction_lstm",
2426
"rewrite_bi_direction_lstm",
2527
"rewrite_single_direction_gru",

tf2onnx/rewriter/eye_rewriter.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter.eye_rewriter - supports tf.eye
6+
"""
7+
8+
from onnx import onnx_pb
9+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
10+
11+
# pylint: disable=invalid-name,unused-argument,missing-docstring, unused-variable
12+
13+
14+
def rewrite_eye(g, ops):
15+
# tf.eye is implemented by a sub_graph which contains op "MatrixDiag" or "MatrixSetDiag" while
16+
# these two ops are un-supported directly in onnx
17+
# but onnx op EyeLike can be used to map the sub_graph
18+
# "rewrite_eye" supports tf.eye(tf.shape(x)[i]) and tf.eye(tf.shape(x)[i], tf.shape(x)[j]).
19+
20+
# ConstantOfShape in opset 9 is used, so if opset less than 9 then do nothing
21+
if g.opset < 9:
22+
return g.get_nodes()
23+
24+
pattern1 = \
25+
OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[
26+
OpTypePattern("Fill", inputs=[
27+
OpTypePattern("Const"),
28+
OpTypePattern("ConcatV2", inputs=[
29+
"*",
30+
"*",
31+
OpTypePattern("Pack", inputs=[
32+
OpTypePattern("Minimum", name="min_node")
33+
])
34+
])
35+
])
36+
])
37+
pattern2 = \
38+
OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[
39+
OpTypePattern("Fill"),
40+
OpTypePattern("Fill", inputs=[
41+
OpTypePattern("Const"),
42+
OpTypePattern("ConcatV2", inputs=[
43+
"*",
44+
"*",
45+
OpTypePattern("Pack", inputs=[
46+
OpTypePattern("Minimum", name="min_node")
47+
])
48+
])
49+
])
50+
])
51+
52+
for pattern in [pattern1, pattern2]:
53+
matcher = GraphMatcher(pattern, allow_reorder=True)
54+
match_results = list(matcher.match_ops(ops))
55+
for match_result in match_results:
56+
old_output = match_result.get_op("output_eye_matrix")
57+
output_dtypes = [g.get_dtype(old_output.output[0])]
58+
output_shapes = [g.get_shape(old_output.output[0])]
59+
g.remove_node(old_output.name)
60+
61+
min_node = match_result.get_op("min_node")
62+
num_rows = min_node.inputs[0]
63+
num_columns = min_node.inputs[1]
64+
65+
# onnx op "EyeLike" need a 2D tensor, so generate it
66+
num_rows = g.make_node("Unsqueeze", num_rows.output, attr={"axes": [0]})
67+
num_columns = g.make_node("Unsqueeze", num_columns.output, attr={"axes": [0]})
68+
matrix_shape = g.make_node("Concat", [num_rows.output[0], num_columns.output[0]], attr={"axis": 0})
69+
# cast nodes added for "ConstantOfShape" in ONNX only accepts int64 data.
70+
matrix_shape_int64 = g.make_node("Cast", matrix_shape.output, attr={"to": onnx_pb.TensorProto.INT64})
71+
zero_matrix = g.make_node("ConstantOfShape", matrix_shape_int64.output)
72+
73+
new_output = g.make_node("EyeLike", zero_matrix.output, attr={"dtype": output_dtypes[0]},
74+
name=old_output.name, shapes=output_shapes, dtypes=output_dtypes)
75+
76+
return g.get_nodes()

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ def compat_handler(ctx, node, **kwargs):
771771
# bi-directional re-writer should be placed after single directional re-writer
772772
rewriters = [rewrite_transpose, rewrite_flatten,
773773
rewrite_random_uniform, rewrite_random_uniform_fold_const,
774-
rewrite_random_normal, rewrite_dropout,
774+
rewrite_random_normal, rewrite_dropout, rewrite_eye,
775775
rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
776776
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
777777
rewrite_single_direction_gru, rewrite_bi_direction_gru,

0 commit comments

Comments
 (0)