Skip to content

Commit 141b008

Browse files
committed
add test for expand dims for more unkonwn rank with negative dims
1 parent fdbbd25 commit 141b008

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

tests/test_backend.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -105,24 +105,17 @@ def _run_test_case(self, output_names_with_port, feed_dict, **kwargs):
105105
kwargs["constant_fold"] = False
106106
return self.run_test_case(feed_dict, [], output_names_with_port, **kwargs)
107107

108-
def _test_expand_dims(self, idx):
108+
def _test_expand_dims_known_rank(self, idx):
109109
tf.reset_default_graph()
110110
x_val = make_xval([3, 4])
111111
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
112112
op = tf.expand_dims(x, idx)
113113
_ = tf.identity(op, name=_TFOUTPUT)
114114
self._run_test_case([_OUTPUT], {_INPUT: x_val})
115115

116-
def test_expand_dims(self):
116+
def test_expand_dims_known_rank(self):
117117
for i in [-1, 0, 1, -2]:
118-
self._test_expand_dims(i)
119-
120-
def test_expand_dims_dynamic_inputs(self):
121-
x_val = make_xval([3, 4])
122-
x = tf.placeholder(tf.float32, shape=[None, None], name=_TFINPUT)
123-
op = tf.expand_dims(x, 0)
124-
_ = tf.identity(op, name=_TFOUTPUT)
125-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
118+
self._test_expand_dims_known_rank(i)
126119

127120
def test_expand_dims_one_unknown_rank(self):
128121
tf.reset_default_graph()
@@ -132,14 +125,18 @@ def test_expand_dims_one_unknown_rank(self):
132125
_ = tf.identity(op, name=_TFOUTPUT)
133126
self._run_test_case([_OUTPUT], {_INPUT: x_val})
134127

135-
def test_expand_dims_more_unknown_rank(self):
128+
def _test_expand_dims_more_unknown_rank(self, idx):
136129
tf.reset_default_graph()
137130
x_val = make_xval([3, 4])
138131
x = tf.placeholder(tf.float32, shape=[None, None], name=_TFINPUT)
139-
op = tf.expand_dims(x, 0)
132+
op = tf.expand_dims(x, idx)
140133
_ = tf.identity(op, name=_TFOUTPUT)
141134
self._run_test_case([_OUTPUT], {_INPUT: x_val})
142135

136+
def test_expand_dims_more_unknown_rank(self):
137+
for i in [-1, 0, 1, -2]:
138+
self._test_expand_dims_more_unknown_rank(i)
139+
143140
@check_opset_min_version(9, "ConstantOfShape")
144141
def test_eye_non_const1(self):
145142
# tf.eye(num_rows), num_rows is not const here

0 commit comments

Comments
 (0)