Skip to content

Commit 60c0527

Browse files
author
wayuanho
committed
fix cond_test bug
1 parent 48ca297 commit 60c0527

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

tests/test_cond.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def cond_graph2():
131131
output_names_with_port = ["output:0"]
132132
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
133133

134-
def test_while_loop_between_conds(self):
134+
def test_while_loop_in_cond(self):
135135
x_val = np.array([1, 2, 3], dtype=np.float32)
136136
y_val = np.array([4, 5, 6], dtype=np.float32)
137137
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
@@ -143,10 +143,9 @@ def cond_graph():
143143
# while_loop
144144
c = lambda y: tf.reduce_any(tf.less(y, 10))
145145
b = lambda i: tf.add(y, 1)
146-
r = tf.while_loop(c, b, [y])
147-
return tf.cond(x[0] > y[0], lambda: z, lambda: r)
146+
return tf.while_loop(c, b, [y])
148147

149-
res = x[2] * tf.cond(x[0] < y[0], lambda: x, cond_graph, name="test_cond")
148+
res = tf.cond(x[0] < y[0], lambda: x, cond_graph, name="test_cond")
150149
_ = tf.identity(res, name="output")
151150

152151
feed_dict = {"input_1:0": x_val, "input_2:0": y_val}

0 commit comments

Comments
 (0)