@@ -131,22 +131,20 @@ def cond_graph2():
131
131
output_names_with_port = ["output:0" ]
132
132
self .run_test_case (feed_dict , input_names_with_port , output_names_with_port )
133
133
134
- def test_while_loop_between_conds (self ):
134
+ def test_while_loop_in_cond (self ):
135
135
x_val = np .array ([1 , 2 , 3 ], dtype = np .float32 )
136
136
y_val = np .array ([4 , 5 , 6 ], dtype = np .float32 )
137
137
x = tf .placeholder (tf .float32 , x_val .shape , name = "input_1" )
138
138
y = tf .placeholder (tf .float32 , y_val .shape , name = "input_2" )
139
139
140
140
def cond_graph ():
141
141
b = tf .constant (np .array ([0 ], dtype = np .int32 ), dtype = tf .int32 )
142
- z = tf .gather_nd (x , b )
143
142
# while_loop
144
143
c = lambda y : tf .reduce_any (tf .less (y , 10 ))
145
144
b = lambda i : tf .add (y , 1 )
146
- r = tf .while_loop (c , b , [y ])
147
- return tf .cond (x [0 ] > y [0 ], lambda : z , lambda : r )
145
+ return tf .while_loop (c , b , [y ])
148
146
149
- res = x [ 2 ] * tf .cond (x [0 ] < y [0 ], lambda : x , cond_graph , name = "test_cond" )
147
+ res = tf .cond (x [0 ] < y [0 ], lambda : x , cond_graph , name = "test_cond" )
150
148
_ = tf .identity (res , name = "output" )
151
149
152
150
feed_dict = {"input_1:0" : x_val , "input_2:0" : y_val }
0 commit comments