@@ -105,24 +105,17 @@ def _run_test_case(self, output_names_with_port, feed_dict, **kwargs):
105
105
kwargs ["constant_fold" ] = False
106
106
return self .run_test_case (feed_dict , [], output_names_with_port , ** kwargs )
107
107
108
- def _test_expand_dims (self , idx ):
108
+ def _test_expand_dims_known_rank (self , idx ):
109
109
tf .reset_default_graph ()
110
110
x_val = make_xval ([3 , 4 ])
111
111
x = tf .placeholder (tf .float32 , shape = x_val .shape , name = _TFINPUT )
112
112
op = tf .expand_dims (x , idx )
113
113
_ = tf .identity (op , name = _TFOUTPUT )
114
114
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
115
115
116
- def test_expand_dims (self ):
116
+ def test_expand_dims_known_rank (self ):
117
117
for i in [- 1 , 0 , 1 , - 2 ]:
118
- self ._test_expand_dims (i )
119
-
120
- def test_expand_dims_dynamic_inputs (self ):
121
- x_val = make_xval ([3 , 4 ])
122
- x = tf .placeholder (tf .float32 , shape = [None , None ], name = _TFINPUT )
123
- op = tf .expand_dims (x , 0 )
124
- _ = tf .identity (op , name = _TFOUTPUT )
125
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
118
+ self ._test_expand_dims_known_rank (i )
126
119
127
120
def test_expand_dims_one_unknown_rank (self ):
128
121
tf .reset_default_graph ()
@@ -132,14 +125,18 @@ def test_expand_dims_one_unknown_rank(self):
132
125
_ = tf .identity (op , name = _TFOUTPUT )
133
126
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
134
127
135
- def test_expand_dims_more_unknown_rank (self ):
128
+ def _test_expand_dims_more_unknown_rank (self , idx ):
136
129
tf .reset_default_graph ()
137
130
x_val = make_xval ([3 , 4 ])
138
131
x = tf .placeholder (tf .float32 , shape = [None , None ], name = _TFINPUT )
139
- op = tf .expand_dims (x , 0 )
132
+ op = tf .expand_dims (x , idx )
140
133
_ = tf .identity (op , name = _TFOUTPUT )
141
134
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
142
135
136
+ def test_expand_dims_more_unknown_rank (self ):
137
+ for i in [- 1 , 0 , 1 , - 2 ]:
138
+ self ._test_expand_dims_more_unknown_rank (i )
139
+
143
140
@check_opset_min_version (9 , "ConstantOfShape" )
144
141
def test_eye_non_const1 (self ):
145
142
# tf.eye(num_rows), num_rows is not const here
0 commit comments