Skip to content

Commit bf05899

Browse files
committed
fix eye rewriter
1 parent 7d622fe commit bf05899

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

tests/test_backend.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ def test_expand_dims_more_unknown_rank(self):
202202
self._test_expand_dims_more_unknown_rank(i)
203203

204204
@check_opset_min_version(9, "ConstantOfShape")
205-
@check_opset_after_tf_version("2.2", 12, "MatrixDiag")
206205
def test_eye_non_const1(self):
207206
# tf.eye(num_rows), num_rows is not const here
208207
x_val = np.array(5, dtype=np.int32)
@@ -224,7 +223,6 @@ def func(x):
224223

225224
@check_tf_min_version("1.11", "eye has bug when version is below 1.11")
226225
@check_opset_min_version(9, "ConstantOfShape")
227-
@check_opset_after_tf_version("2.2", 12, "MatrixDiag")
228226
def test_eye_non_const2(self):
229227
# tf.eye(num_rows), num_rows is not const here
230228
for np_dtype in [np.int32, np.int64, np.float32, np.float64]:

tf2onnx/rewriter/eye_rewriter.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,38 @@ def rewrite_eye(g, ops):
7979
OpTypePattern("Const", name="fill_value"),
8080
]),
8181
])
82+
pattern5 = \
83+
OpTypePattern("MatrixDiagV3", name="output_eye_matrix", inputs=[
84+
OpTypePattern("Fill", inputs=[
85+
OpTypePattern("ConcatV2", inputs=[
86+
"*",
87+
OpTypePattern("ExpandDims", inputs=[
88+
OpTypePattern("Minimum|Cast", name="min_or_cast"),
89+
"*"
90+
]),
91+
"*",
92+
]),
93+
OpTypePattern("Const", name="fill_value"),
94+
]),
95+
"*", "*", "*", "*",
96+
])
97+
pattern6 = \
98+
OpTypePattern("MatrixSetDiagV3", name="output_eye_matrix", inputs=[
99+
OpTypePattern("Fill"),
100+
OpTypePattern("Fill", inputs=[
101+
OpTypePattern("ConcatV2", inputs=[
102+
"*",
103+
OpTypePattern("ExpandDims", inputs=[
104+
OpTypePattern("Minimum|Cast", name="min_or_cast"),
105+
"*"
106+
]),
107+
"*",
108+
]),
109+
OpTypePattern("Const", name="fill_value"),
110+
]), "*"
111+
])
82112

83-
for pattern in [pattern1, pattern2, pattern3, pattern4]:
113+
for pattern in [pattern1, pattern2, pattern3, pattern4, pattern5, pattern6]:
84114
matcher = GraphMatcher(pattern, allow_reorder=True)
85115
match_results = list(matcher.match_ops(ops))
86116
for match_result in match_results:

0 commit comments

Comments
 (0)