@@ -79,8 +79,38 @@ def rewrite_eye(g, ops):
79
79
OpTypePattern ("Const" , name = "fill_value" ),
80
80
]),
81
81
])
82
+ pattern5 = \
83
+ OpTypePattern ("MatrixDiagV3" , name = "output_eye_matrix" , inputs = [
84
+ OpTypePattern ("Fill" , inputs = [
85
+ OpTypePattern ("ConcatV2" , inputs = [
86
+ "*" ,
87
+ OpTypePattern ("ExpandDims" , inputs = [
88
+ OpTypePattern ("Minimum|Cast" , name = "min_or_cast" ),
89
+ "*"
90
+ ]),
91
+ "*" ,
92
+ ]),
93
+ OpTypePattern ("Const" , name = "fill_value" ),
94
+ ]),
95
+ "*" , "*" , "*" , "*" ,
96
+ ])
97
+ pattern6 = \
98
+ OpTypePattern ("MatrixSetDiagV3" , name = "output_eye_matrix" , inputs = [
99
+ OpTypePattern ("Fill" ),
100
+ OpTypePattern ("Fill" , inputs = [
101
+ OpTypePattern ("ConcatV2" , inputs = [
102
+ "*" ,
103
+ OpTypePattern ("ExpandDims" , inputs = [
104
+ OpTypePattern ("Minimum|Cast" , name = "min_or_cast" ),
105
+ "*"
106
+ ]),
107
+ "*" ,
108
+ ]),
109
+ OpTypePattern ("Const" , name = "fill_value" ),
110
+ ]), "*"
111
+ ])
82
112
83
- for pattern in [pattern1 , pattern2 , pattern3 , pattern4 ]:
113
+ for pattern in [pattern1 , pattern2 , pattern3 , pattern4 , pattern5 , pattern6 ]:
84
114
matcher = GraphMatcher (pattern , allow_reorder = True )
85
115
match_results = list (matcher .match_ops (ops ))
86
116
for match_result in match_results :
0 commit comments