Skip to content

Commit d3c9f8c

Browse files
authored
Merge pull request #343 from lucienwang1009/cond_test_bug
fix cond_test bug
2 parents 48ca297 + 4707029 commit d3c9f8c

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

tests/test_cond.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,22 +131,20 @@ def cond_graph2():
131131
output_names_with_port = ["output:0"]
132132
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
133133

134-
def test_while_loop_between_conds(self):
134+
def test_while_loop_in_cond(self):
135135
x_val = np.array([1, 2, 3], dtype=np.float32)
136136
y_val = np.array([4, 5, 6], dtype=np.float32)
137137
x = tf.placeholder(tf.float32, x_val.shape, name="input_1")
138138
y = tf.placeholder(tf.float32, y_val.shape, name="input_2")
139139

140140
def cond_graph():
141141
b = tf.constant(np.array([0], dtype=np.int32), dtype=tf.int32)
142-
z = tf.gather_nd(x, b)
143142
# while_loop
144143
c = lambda y: tf.reduce_any(tf.less(y, 10))
145144
b = lambda i: tf.add(y, 1)
146-
r = tf.while_loop(c, b, [y])
147-
return tf.cond(x[0] > y[0], lambda: z, lambda: r)
145+
return tf.while_loop(c, b, [y])
148146

149-
res = x[2] * tf.cond(x[0] < y[0], lambda: x, cond_graph, name="test_cond")
147+
res = tf.cond(x[0] < y[0], lambda: x, cond_graph, name="test_cond")
150148
_ = tf.identity(res, name="output")
151149

152150
feed_dict = {"input_1:0": x_val, "input_2:0": y_val}

0 commit comments

Comments
 (0)