Skip to content

Commit 4154157

Browse files
committed
refactor
1 parent e31d9a5 commit 4154157

File tree

2 files changed

+45
-27
lines changed

2 files changed

+45
-27
lines changed

tests/test_backend.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
_OUTPUT = "output:0"
4040
_TFOUTPUT1 = "output1"
4141
_OUTPUT1 = "output1:0"
42+
_TFOUTPUT2 = "output2"
43+
_OUTPUT2 = "output2:0"
4244

4345

4446
def make_xval(shape):
@@ -128,30 +130,34 @@ def test_expand_dims_more_unknown_rank(self):
128130
self._run_test_case([_OUTPUT], {_INPUT: x_val})
129131

130132
@check_opset_min_version(9, "ConstantOfShape")
131-
def test_eye(self):
132-
# tf.eye(tf.shape)
133+
def test_eye_non_const(self):
134+
# tf.eye(num_rows), num_rows is not const here
133135
for np_dtype, tf_dtype in zip([np.int32, np.int64, np.float32, np.float64],
134136
[tf.int32, tf.int64, tf.float32, tf.float64]):
135137
tf.reset_default_graph()
136-
x_val = np.array([[1.0, 2.0, -3.0, -4.0, 5.0]] * 2, dtype=np_dtype)
137-
x = tf.placeholder(tf_dtype, shape=[None] * 2, name=_TFINPUT)
138-
y_ = tf.eye(tf.shape(x)[0], dtype=tf.float32)
139-
_ = tf.identity(y_, name=_TFOUTPUT)
140-
y1_ = tf.eye(tf.shape(x)[1], dtype=tf.int32)
141-
_ = tf.identity(y1_, name=_TFOUTPUT1)
142-
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val}, rtol=0)
143-
144-
# tf.eye(tf.shape, tf.shape)
138+
x_val = np.array(5, dtype=np_dtype)
139+
x = tf.placeholder(tf_dtype, shape=[], name=_TFINPUT)
140+
y = tf.eye(x, dtype=tf.int32)
141+
_ = tf.identity(y, name=_TFOUTPUT)
142+
y1 = tf.eye(x, dtype=tf.int64)
143+
_ = tf.identity(y1, name=_TFOUTPUT1)
144+
y2 = tf.eye(x, dtype=tf.float32)
145+
_ = tf.identity(y2, name=_TFOUTPUT2)
146+
self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0)
147+
148+
# tf.eye(num_rows, num_columns), both num_rows and num_columns are not const here
145149
for np_dtype, tf_dtype in zip([np.int32, np.int64, np.float32, np.float64],
146150
[tf.int32, tf.int64, tf.float32, tf.float64]):
147151
tf.reset_default_graph()
148-
x_val = np.array([[1.0, 2.0, -3.0, -4.0, 5.0]] * 2, dtype=np_dtype)
149-
x = tf.placeholder(tf_dtype, shape=[None] * 2, name=_TFINPUT)
150-
y_ = tf.eye(tf.shape(x)[0], tf.shape(x)[1], dtype=tf.float32)
151-
_ = tf.identity(y_, name=_TFOUTPUT)
152-
y1_ = tf.eye(tf.shape(x)[0], tf.shape(x)[1], dtype=tf.int32)
153-
_ = tf.identity(y1_, name=_TFOUTPUT1)
154-
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val}, rtol=0)
152+
x_val = np.array([5, 10], dtype=np_dtype)
153+
x = tf.placeholder(tf_dtype, shape=[2], name=_TFINPUT)
154+
y = tf.eye(x[0], x[1], dtype=tf.int32)
155+
_ = tf.identity(y, name=_TFOUTPUT)
156+
y1 = tf.eye(x[0], x[1], dtype=tf.int64)
157+
_ = tf.identity(y1, name=_TFOUTPUT1)
158+
y2 = tf.eye(x[0], x[1], dtype=tf.float32)
159+
_ = tf.identity(y2, name=_TFOUTPUT2)
160+
self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0)
155161

156162
@check_opset_min_version(7, "trig")
157163
def test_trig_ops(self):

tf2onnx/rewriter/eye_rewriter.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212

1313

1414
def rewrite_eye(g, ops):
15+
# schema of eye is eye(num_rows, num_columns=None), if num_columns not specified then it's equal to num_rows
1516
# tf.eye is implemented by a sub_graph which contains op "MatrixDiag" or "MatrixSetDiag" while
1617
# these two ops are un-supported directly in onnx
1718
# but onnx op EyeLike can be used to map the sub_graph
18-
# "rewrite_eye" supports tf.eye(tf.shape(x)[i]) and tf.eye(tf.shape(x)[i], tf.shape(x)[j]).
19+
# "rewrite_eye" supports tf.eye(non_const) and tf.eye(non_const1, non_const2).
20+
# tf.eye(const) and tf.eye(const1, const2) are not supported in this rewriter
1921

2022
# ConstantOfShape in opset 9 is used, so if opset less than 9 then do nothing
2123
if g.opset < 9:
@@ -24,12 +26,12 @@ def rewrite_eye(g, ops):
2426
pattern1 = \
2527
OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[
2628
OpTypePattern("Fill", inputs=[
27-
OpTypePattern("Const"),
29+
OpTypePattern("Const", name="fill_value"),
2830
OpTypePattern("ConcatV2", inputs=[
2931
"*",
3032
"*",
3133
OpTypePattern("Pack", inputs=[
32-
OpTypePattern("Minimum", name="min_node")
34+
OpTypePattern("Minimum|Cast", name="min_or_cast")
3335
])
3436
])
3537
])
@@ -38,12 +40,12 @@ def rewrite_eye(g, ops):
3840
OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[
3941
OpTypePattern("Fill"),
4042
OpTypePattern("Fill", inputs=[
41-
OpTypePattern("Const"),
43+
OpTypePattern("Const", name="fill_value"),
4244
OpTypePattern("ConcatV2", inputs=[
4345
"*",
4446
"*",
4547
OpTypePattern("Pack", inputs=[
46-
OpTypePattern("Minimum", name="min_node")
48+
OpTypePattern("Minimum|Cast", name="min_or_cast")
4749
])
4850
])
4951
])
@@ -53,15 +55,25 @@ def rewrite_eye(g, ops):
5355
matcher = GraphMatcher(pattern, allow_reorder=True)
5456
match_results = list(matcher.match_ops(ops))
5557
for match_result in match_results:
58+
if match_result.get_op("fill_value").get_tensor_value() != 1:
59+
continue
60+
61+
min_or_cast = match_result.get_op("min_or_cast")
62+
if min_or_cast.type == "Minimum":
63+
min_node = min_or_cast
64+
elif min_or_cast.type == "Cast" and min_or_cast.inputs[0].type == "Minimum":
65+
min_node = min_or_cast.inputs[0]
66+
else:
67+
continue
68+
69+
num_rows = min_node.inputs[0]
70+
num_columns = min_node.inputs[1]
71+
5672
old_output = match_result.get_op("output_eye_matrix")
5773
output_dtypes = [g.get_dtype(old_output.output[0])]
5874
output_shapes = [g.get_shape(old_output.output[0])]
5975
g.remove_node(old_output.name)
6076

61-
min_node = match_result.get_op("min_node")
62-
num_rows = min_node.inputs[0]
63-
num_columns = min_node.inputs[1]
64-
6577
# onnx op "EyeLike" need a 2D tensor, so generate it
6678
num_rows = g.make_node("Unsqueeze", num_rows.output, attr={"axes": [0]})
6779
num_columns = g.make_node("Unsqueeze", num_columns.output, attr={"axes": [0]})

0 commit comments

Comments
 (0)