Skip to content

Commit 8cbe16a

Browse files
committed
tf.eye has a bug when version is below 1.11.
1 parent 4154157 commit 8cbe16a

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

tests/test_backend.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,34 @@ def test_expand_dims_more_unknown_rank(self):
130130
self._run_test_case([_OUTPUT], {_INPUT: x_val})
131131

132132
@check_opset_min_version(9, "ConstantOfShape")
133-
def test_eye_non_const(self):
133+
def test_eye_non_const1(self):
134+
# tf.eye(num_rows), num_rows is not const here
135+
tf.reset_default_graph()
136+
x_val = np.array(5, dtype=np.int32)
137+
x = tf.placeholder(tf.int32, shape=[], name=_TFINPUT)
138+
y = tf.eye(x, dtype=tf.int32)
139+
_ = tf.identity(y, name=_TFOUTPUT)
140+
y1 = tf.eye(x, dtype=tf.int64)
141+
_ = tf.identity(y1, name=_TFOUTPUT1)
142+
y2 = tf.eye(x, dtype=tf.float32)
143+
_ = tf.identity(y2, name=_TFOUTPUT2)
144+
self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0)
145+
146+
# tf.eye(num_rows, num_columns), both num_rows and num_columns are not const here
147+
tf.reset_default_graph()
148+
x_val = np.array([5, 10], dtype=np.int32)
149+
x = tf.placeholder(tf.int32, shape=[2], name=_TFINPUT)
150+
y = tf.eye(x[0], x[1], dtype=tf.int32)
151+
_ = tf.identity(y, name=_TFOUTPUT)
152+
y1 = tf.eye(x[0], x[1], dtype=tf.int64)
153+
_ = tf.identity(y1, name=_TFOUTPUT1)
154+
y2 = tf.eye(x[0], x[1], dtype=tf.float32)
155+
_ = tf.identity(y2, name=_TFOUTPUT2)
156+
self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0)
157+
158+
@check_tf_min_version("1.11", "eye has bug when version is below 1.11")
159+
@check_opset_min_version(9, "ConstantOfShape")
160+
def test_eye_non_const2(self):
134161
# tf.eye(num_rows), num_rows is not const here
135162
for np_dtype, tf_dtype in zip([np.int32, np.int64, np.float32, np.float64],
136163
[tf.int32, tf.int64, tf.float32, tf.float64]):

0 commit comments

Comments
 (0)