@@ -130,7 +130,34 @@ def test_expand_dims_more_unknown_rank(self):
130
130
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
131
131
132
132
@check_opset_min_version (9 , "ConstantOfShape" )
133
- def test_eye_non_const (self ):
133
+ def test_eye_non_const1 (self ):
134
+ # tf.eye(num_rows), num_rows is not const here
135
+ tf .reset_default_graph ()
136
+ x_val = np .array (5 , dtype = np .int32 )
137
+ x = tf .placeholder (tf .int32 , shape = [], name = _TFINPUT )
138
+ y = tf .eye (x , dtype = tf .int32 )
139
+ _ = tf .identity (y , name = _TFOUTPUT )
140
+ y1 = tf .eye (x , dtype = tf .int64 )
141
+ _ = tf .identity (y1 , name = _TFOUTPUT1 )
142
+ y2 = tf .eye (x , dtype = tf .float32 )
143
+ _ = tf .identity (y2 , name = _TFOUTPUT2 )
144
+ self ._run_test_case ([_OUTPUT , _OUTPUT1 , _OUTPUT2 ], {_INPUT : x_val }, rtol = 0 )
145
+
146
+ # tf.eye(num_rows, num_columns), both num_rows and num_columns are not const here
147
+ tf .reset_default_graph ()
148
+ x_val = np .array ([5 , 10 ], dtype = np .int32 )
149
+ x = tf .placeholder (tf .int32 , shape = [2 ], name = _TFINPUT )
150
+ y = tf .eye (x [0 ], x [1 ], dtype = tf .int32 )
151
+ _ = tf .identity (y , name = _TFOUTPUT )
152
+ y1 = tf .eye (x [0 ], x [1 ], dtype = tf .int64 )
153
+ _ = tf .identity (y1 , name = _TFOUTPUT1 )
154
+ y2 = tf .eye (x [0 ], x [1 ], dtype = tf .float32 )
155
+ _ = tf .identity (y2 , name = _TFOUTPUT2 )
156
+ self ._run_test_case ([_OUTPUT , _OUTPUT1 , _OUTPUT2 ], {_INPUT : x_val }, rtol = 0 )
157
+
158
+ @check_tf_min_version ("1.11" , "eye has bug when version is below 1.11" )
159
+ @check_opset_min_version (9 , "ConstantOfShape" )
160
+ def test_eye_non_const2 (self ):
134
161
# tf.eye(num_rows), num_rows is not const here
135
162
for np_dtype , tf_dtype in zip ([np .int32 , np .int64 , np .float32 , np .float64 ],
136
163
[tf .int32 , tf .int64 , tf .float32 , tf .float64 ]):
0 commit comments