Skip to content

Commit 82e0c80

Browse files
authored
Merge pull request #308 from zhijxu-MS/tmp_branch_for_PR2
add tf.eye supported partially
2 parents a88c048 + 8cbe16a commit 82e0c80

File tree

4 files changed

+150
-1
lines changed

4 files changed

+150
-1
lines changed

tests/test_backend.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
_OUTPUT = "output:0"
4040
_TFOUTPUT1 = "output1"
4141
_OUTPUT1 = "output1:0"
42+
_TFOUTPUT2 = "output2"
43+
_OUTPUT2 = "output2:0"
4244

4345

4446
def make_xval(shape):
@@ -127,6 +129,63 @@ def test_expand_dims_more_unknown_rank(self):
127129
_ = tf.identity(op, name=_TFOUTPUT)
128130
self._run_test_case([_OUTPUT], {_INPUT: x_val})
129131

132+
@check_opset_min_version(9, "ConstantOfShape")
133+
def test_eye_non_const1(self):
134+
# tf.eye(num_rows), num_rows is not const here
135+
tf.reset_default_graph()
136+
x_val = np.array(5, dtype=np.int32)
137+
x = tf.placeholder(tf.int32, shape=[], name=_TFINPUT)
138+
y = tf.eye(x, dtype=tf.int32)
139+
_ = tf.identity(y, name=_TFOUTPUT)
140+
y1 = tf.eye(x, dtype=tf.int64)
141+
_ = tf.identity(y1, name=_TFOUTPUT1)
142+
y2 = tf.eye(x, dtype=tf.float32)
143+
_ = tf.identity(y2, name=_TFOUTPUT2)
144+
self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0)
145+
146+
# tf.eye(num_rows, num_columns), both num_rows and num_columns are not const here
147+
tf.reset_default_graph()
148+
x_val = np.array([5, 10], dtype=np.int32)
149+
x = tf.placeholder(tf.int32, shape=[2], name=_TFINPUT)
150+
y = tf.eye(x[0], x[1], dtype=tf.int32)
151+
_ = tf.identity(y, name=_TFOUTPUT)
152+
y1 = tf.eye(x[0], x[1], dtype=tf.int64)
153+
_ = tf.identity(y1, name=_TFOUTPUT1)
154+
y2 = tf.eye(x[0], x[1], dtype=tf.float32)
155+
_ = tf.identity(y2, name=_TFOUTPUT2)
156+
self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0)
157+
158+
@check_tf_min_version("1.11", "eye has bug when version is below 1.11")
159+
@check_opset_min_version(9, "ConstantOfShape")
160+
def test_eye_non_const2(self):
161+
# tf.eye(num_rows), num_rows is not const here
162+
for np_dtype, tf_dtype in zip([np.int32, np.int64, np.float32, np.float64],
163+
[tf.int32, tf.int64, tf.float32, tf.float64]):
164+
tf.reset_default_graph()
165+
x_val = np.array(5, dtype=np_dtype)
166+
x = tf.placeholder(tf_dtype, shape=[], name=_TFINPUT)
167+
y = tf.eye(x, dtype=tf.int32)
168+
_ = tf.identity(y, name=_TFOUTPUT)
169+
y1 = tf.eye(x, dtype=tf.int64)
170+
_ = tf.identity(y1, name=_TFOUTPUT1)
171+
y2 = tf.eye(x, dtype=tf.float32)
172+
_ = tf.identity(y2, name=_TFOUTPUT2)
173+
self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0)
174+
175+
# tf.eye(num_rows, num_columns), both num_rows and num_columns are not const here
176+
for np_dtype, tf_dtype in zip([np.int32, np.int64, np.float32, np.float64],
177+
[tf.int32, tf.int64, tf.float32, tf.float64]):
178+
tf.reset_default_graph()
179+
x_val = np.array([5, 10], dtype=np_dtype)
180+
x = tf.placeholder(tf_dtype, shape=[2], name=_TFINPUT)
181+
y = tf.eye(x[0], x[1], dtype=tf.int32)
182+
_ = tf.identity(y, name=_TFOUTPUT)
183+
y1 = tf.eye(x[0], x[1], dtype=tf.int64)
184+
_ = tf.identity(y1, name=_TFOUTPUT1)
185+
y2 = tf.eye(x[0], x[1], dtype=tf.float32)
186+
_ = tf.identity(y2, name=_TFOUTPUT2)
187+
self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0)
188+
130189
@check_opset_min_version(7, "trig")
131190
def test_trig_ops(self):
132191
for op in [tf.sin, tf.cos, tf.tan, tf.asin, tf.acos, tf.atan]:

tf2onnx/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tf2onnx.rewriter.cond_rewriter import rewrite_cond
1010
from tf2onnx.rewriter.random_uniform import rewrite_random_uniform, rewrite_random_uniform_fold_const
1111
from tf2onnx.rewriter.leakyrelu_rewriter import rewrite_leakyrelu
12+
from tf2onnx.rewriter.eye_rewriter import rewrite_eye
1213
from tf2onnx.rewriter.thresholded_relu_rewriter import rewrite_thresholded_relu
1314
from tf2onnx.rewriter.rnn import rewrite_single_direction_lstm, rewrite_bi_direction_lstm, \
1415
rewrite_single_direction_gru, rewrite_bi_direction_gru, \
@@ -20,6 +21,7 @@
2021
"rewrite_random_uniform_fold_const",
2122
"rewrite_leakyrelu",
2223
"rewrite_thresholded_relu",
24+
"rewrite_eye",
2325
"rewrite_single_direction_lstm",
2426
"rewrite_bi_direction_lstm",
2527
"rewrite_single_direction_gru",

tf2onnx/rewriter/eye_rewriter.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter.eye_rewriter - supports tf.eye
6+
"""
7+
8+
from onnx import onnx_pb
9+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
10+
11+
# pylint: disable=invalid-name,unused-argument,missing-docstring, unused-variable
12+
13+
14+
def rewrite_eye(g, ops):
15+
# schema of eye is eye(num_rows, num_columns=None), if num_columns not specified then it's equal to num_rows
16+
# tf.eye is implemented by a sub_graph which contains op "MatrixDiag" or "MatrixSetDiag" while
17+
# these two ops are un-supported directly in onnx
18+
# but onnx op EyeLike can be used to map the sub_graph
19+
# "rewrite_eye" supports tf.eye(non_const) and tf.eye(non_const1, non_const2).
20+
# tf.eye(const) and tf.eye(const1, const2) are not supported in this rewriter
21+
22+
# ConstantOfShape in opset 9 is used, so if opset less than 9 then do nothing
23+
if g.opset < 9:
24+
return g.get_nodes()
25+
26+
pattern1 = \
27+
OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[
28+
OpTypePattern("Fill", inputs=[
29+
OpTypePattern("Const", name="fill_value"),
30+
OpTypePattern("ConcatV2", inputs=[
31+
"*",
32+
"*",
33+
OpTypePattern("Pack", inputs=[
34+
OpTypePattern("Minimum|Cast", name="min_or_cast")
35+
])
36+
])
37+
])
38+
])
39+
pattern2 = \
40+
OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[
41+
OpTypePattern("Fill"),
42+
OpTypePattern("Fill", inputs=[
43+
OpTypePattern("Const", name="fill_value"),
44+
OpTypePattern("ConcatV2", inputs=[
45+
"*",
46+
"*",
47+
OpTypePattern("Pack", inputs=[
48+
OpTypePattern("Minimum|Cast", name="min_or_cast")
49+
])
50+
])
51+
])
52+
])
53+
54+
for pattern in [pattern1, pattern2]:
55+
matcher = GraphMatcher(pattern, allow_reorder=True)
56+
match_results = list(matcher.match_ops(ops))
57+
for match_result in match_results:
58+
if match_result.get_op("fill_value").get_tensor_value() != 1:
59+
continue
60+
61+
min_or_cast = match_result.get_op("min_or_cast")
62+
if min_or_cast.type == "Minimum":
63+
min_node = min_or_cast
64+
elif min_or_cast.type == "Cast" and min_or_cast.inputs[0].type == "Minimum":
65+
min_node = min_or_cast.inputs[0]
66+
else:
67+
continue
68+
69+
num_rows = min_node.inputs[0]
70+
num_columns = min_node.inputs[1]
71+
72+
old_output = match_result.get_op("output_eye_matrix")
73+
output_dtypes = [g.get_dtype(old_output.output[0])]
74+
output_shapes = [g.get_shape(old_output.output[0])]
75+
g.remove_node(old_output.name)
76+
77+
# onnx op "EyeLike" need a 2D tensor, so generate it
78+
num_rows = g.make_node("Unsqueeze", num_rows.output, attr={"axes": [0]})
79+
num_columns = g.make_node("Unsqueeze", num_columns.output, attr={"axes": [0]})
80+
matrix_shape = g.make_node("Concat", [num_rows.output[0], num_columns.output[0]], attr={"axis": 0})
81+
# cast nodes added for "ConstantOfShape" in ONNX only accepts int64 data.
82+
matrix_shape_int64 = g.make_node("Cast", matrix_shape.output, attr={"to": onnx_pb.TensorProto.INT64})
83+
zero_matrix = g.make_node("ConstantOfShape", matrix_shape_int64.output)
84+
85+
new_output = g.make_node("EyeLike", zero_matrix.output, attr={"dtype": output_dtypes[0]},
86+
name=old_output.name, shapes=output_shapes, dtypes=output_dtypes)
87+
88+
return g.get_nodes()

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ def compat_handler(ctx, node, **kwargs):
771771
# bi-directional re-writer should be placed after single directional re-writer
772772
rewriters = [rewrite_transpose, rewrite_flatten,
773773
rewrite_random_uniform, rewrite_random_uniform_fold_const,
774-
rewrite_random_normal, rewrite_dropout,
774+
rewrite_random_normal, rewrite_dropout, rewrite_eye,
775775
rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
776776
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
777777
rewrite_single_direction_gru, rewrite_bi_direction_gru,

0 commit comments

Comments
 (0)