|
39 | 39 | _OUTPUT = "output:0"
|
40 | 40 | _TFOUTPUT1 = "output1"
|
41 | 41 | _OUTPUT1 = "output1:0"
|
| 42 | +_TFOUTPUT2 = "output2" |
| 43 | +_OUTPUT2 = "output2:0" |
42 | 44 |
|
43 | 45 |
|
44 | 46 | def make_xval(shape):
|
@@ -128,30 +130,34 @@ def test_expand_dims_more_unknown_rank(self):
|
128 | 130 | self._run_test_case([_OUTPUT], {_INPUT: x_val})
|
129 | 131 |
|
130 | 132 | @check_opset_min_version(9, "ConstantOfShape")
|
131 |
| - def test_eye(self): |
132 |
| - # tf.eye(tf.shape) |
| 133 | + def test_eye_non_const(self): |
| 134 | + # tf.eye(num_rows), num_rows is not const here |
133 | 135 | for np_dtype, tf_dtype in zip([np.int32, np.int64, np.float32, np.float64],
|
134 | 136 | [tf.int32, tf.int64, tf.float32, tf.float64]):
|
135 | 137 | tf.reset_default_graph()
|
136 |
| - x_val = np.array([[1.0, 2.0, -3.0, -4.0, 5.0]] * 2, dtype=np_dtype) |
137 |
| - x = tf.placeholder(tf_dtype, shape=[None] * 2, name=_TFINPUT) |
138 |
| - y_ = tf.eye(tf.shape(x)[0], dtype=tf.float32) |
139 |
| - _ = tf.identity(y_, name=_TFOUTPUT) |
140 |
| - y1_ = tf.eye(tf.shape(x)[1], dtype=tf.int32) |
141 |
| - _ = tf.identity(y1_, name=_TFOUTPUT1) |
142 |
| - self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val}, rtol=0) |
143 |
| - |
144 |
| - # tf.eye(tf.shape, tf.shape) |
| 138 | + x_val = np.array(5, dtype=np_dtype) |
| 139 | + x = tf.placeholder(tf_dtype, shape=[], name=_TFINPUT) |
| 140 | + y = tf.eye(x, dtype=tf.int32) |
| 141 | + _ = tf.identity(y, name=_TFOUTPUT) |
| 142 | + y1 = tf.eye(x, dtype=tf.int64) |
| 143 | + _ = tf.identity(y1, name=_TFOUTPUT1) |
| 144 | + y2 = tf.eye(x, dtype=tf.float32) |
| 145 | + _ = tf.identity(y2, name=_TFOUTPUT2) |
| 146 | + self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0) |
| 147 | + |
| 148 | + # tf.eye(num_rows, num_columns), both num_rows and num_columns are not const here |
145 | 149 | for np_dtype, tf_dtype in zip([np.int32, np.int64, np.float32, np.float64],
|
146 | 150 | [tf.int32, tf.int64, tf.float32, tf.float64]):
|
147 | 151 | tf.reset_default_graph()
|
148 |
| - x_val = np.array([[1.0, 2.0, -3.0, -4.0, 5.0]] * 2, dtype=np_dtype) |
149 |
| - x = tf.placeholder(tf_dtype, shape=[None] * 2, name=_TFINPUT) |
150 |
| - y_ = tf.eye(tf.shape(x)[0], tf.shape(x)[1], dtype=tf.float32) |
151 |
| - _ = tf.identity(y_, name=_TFOUTPUT) |
152 |
| - y1_ = tf.eye(tf.shape(x)[0], tf.shape(x)[1], dtype=tf.int32) |
153 |
| - _ = tf.identity(y1_, name=_TFOUTPUT1) |
154 |
| - self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val}, rtol=0) |
| 152 | + x_val = np.array([5, 10], dtype=np_dtype) |
| 153 | + x = tf.placeholder(tf_dtype, shape=[2], name=_TFINPUT) |
| 154 | + y = tf.eye(x[0], x[1], dtype=tf.int32) |
| 155 | + _ = tf.identity(y, name=_TFOUTPUT) |
| 156 | + y1 = tf.eye(x[0], x[1], dtype=tf.int64) |
| 157 | + _ = tf.identity(y1, name=_TFOUTPUT1) |
| 158 | + y2 = tf.eye(x[0], x[1], dtype=tf.float32) |
| 159 | + _ = tf.identity(y2, name=_TFOUTPUT2) |
| 160 | + self._run_test_case([_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val}, rtol=0) |
155 | 161 |
|
156 | 162 | @check_opset_min_version(7, "trig")
|
157 | 163 | def test_trig_ops(self):
|
|
0 commit comments