|
39 | 39 | _OUTPUT = "output:0"
|
40 | 40 | _TFOUTPUT1 = "output1"
|
41 | 41 | _OUTPUT1 = "output1:0"
|
| 42 | +_TFOUTPUT2 = "output2" |
| 43 | +_OUTPUT2 = "output2:0" |
42 | 44 |
|
43 | 45 |
|
44 | 46 | def make_xval(shape):
|
@@ -127,6 +129,63 @@ def test_expand_dims_more_unknown_rank(self):
|
127 | 129 | _ = tf.identity(op, name=_TFOUTPUT)
|
128 | 130 | self._run_test_case([_OUTPUT], {_INPUT: x_val})
|
129 | 131 |
|
| 132 | + @check_opset_min_version(9, "ConstantOfShape") |
| 133 | + def test_eye_non_const1(self): |
| 134 | + # tf.eye(num_rows), num_rows is not const here |
| 135 | + tf.reset_default_graph() |
| 136 | + x_val = np.array(5, dtype=np.int32) |
| 137 | + x = tf.placeholder(tf.int32, shape=[], name=_TFINPUT) |
| 138 | + y = tf.eye(x, dtype=tf.int32) |
| 139 | + _ = tf.identity(y, name=_TFOUTPUT) |
| 140 | + y1 = tf.eye(x, dtype=tf.int64) |
| 141 | + _ = tf.identity(y1, name=_TFOUTPUT1) |
| 142 | + y2 = tf.eye(x, dtype=tf.float32) |
| 143 | + _ = tf.identity(y2, name=_TFOUTPUT2) |
| 144 | + self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0) |
| 145 | + |
| 146 | + # tf.eye(num_rows, num_columns), both num_rows and num_columns are not const here |
| 147 | + tf.reset_default_graph() |
| 148 | + x_val = np.array([5, 10], dtype=np.int32) |
| 149 | + x = tf.placeholder(tf.int32, shape=[2], name=_TFINPUT) |
| 150 | + y = tf.eye(x[0], x[1], dtype=tf.int32) |
| 151 | + _ = tf.identity(y, name=_TFOUTPUT) |
| 152 | + y1 = tf.eye(x[0], x[1], dtype=tf.int64) |
| 153 | + _ = tf.identity(y1, name=_TFOUTPUT1) |
| 154 | + y2 = tf.eye(x[0], x[1], dtype=tf.float32) |
| 155 | + _ = tf.identity(y2, name=_TFOUTPUT2) |
| 156 | + self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0) |
| 157 | + |
| 158 | + @check_tf_min_version("1.11", "eye has bug when version is below 1.11") |
| 159 | + @check_opset_min_version(9, "ConstantOfShape") |
| 160 | + def test_eye_non_const2(self): |
| 161 | + # tf.eye(num_rows), num_rows is not const here |
| 162 | + for np_dtype, tf_dtype in zip([np.int32, np.int64, np.float32, np.float64], |
| 163 | + [tf.int32, tf.int64, tf.float32, tf.float64]): |
| 164 | + tf.reset_default_graph() |
| 165 | + x_val = np.array(5, dtype=np_dtype) |
| 166 | + x = tf.placeholder(tf_dtype, shape=[], name=_TFINPUT) |
| 167 | + y = tf.eye(x, dtype=tf.int32) |
| 168 | + _ = tf.identity(y, name=_TFOUTPUT) |
| 169 | + y1 = tf.eye(x, dtype=tf.int64) |
| 170 | + _ = tf.identity(y1, name=_TFOUTPUT1) |
| 171 | + y2 = tf.eye(x, dtype=tf.float32) |
| 172 | + _ = tf.identity(y2, name=_TFOUTPUT2) |
| 173 | + self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0) |
| 174 | + |
| 175 | + # tf.eye(num_rows, num_columns), both num_rows and num_columns are not const here |
| 176 | + for np_dtype, tf_dtype in zip([np.int32, np.int64, np.float32, np.float64], |
| 177 | + [tf.int32, tf.int64, tf.float32, tf.float64]): |
| 178 | + tf.reset_default_graph() |
| 179 | + x_val = np.array([5, 10], dtype=np_dtype) |
| 180 | + x = tf.placeholder(tf_dtype, shape=[2], name=_TFINPUT) |
| 181 | + y = tf.eye(x[0], x[1], dtype=tf.int32) |
| 182 | + _ = tf.identity(y, name=_TFOUTPUT) |
| 183 | + y1 = tf.eye(x[0], x[1], dtype=tf.int64) |
| 184 | + _ = tf.identity(y1, name=_TFOUTPUT1) |
| 185 | + y2 = tf.eye(x[0], x[1], dtype=tf.float32) |
| 186 | + _ = tf.identity(y2, name=_TFOUTPUT2) |
| 187 | + self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0) |
| 188 | + |
130 | 189 | @check_opset_min_version(7, "trig")
|
131 | 190 | def test_trig_ops(self):
|
132 | 191 | for op in [tf.sin, tf.cos, tf.tan, tf.asin, tf.acos, tf.atan]:
|
|
0 commit comments