Skip to content

Commit d0a8120

Browse files
author
Ian Schweer
committed
Rearrange test for code coverage
1 parent 0bf00ab commit d0a8120

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -308,11 +308,14 @@ def test_pytorch_ifelse():
308308
p1_vals = np.r_[1, 2, 3]
309309
p2_vals = np.r_[-1, -2, -3]
310310

311-
for test_value, cond in [(0.2, 0.5), (0.5, 0.4)]:
312-
a = scalar("a")
313-
x = ifelse(
314-
a < cond, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals])
315-
)
316-
x_fg = FunctionGraph([a], x)
317-
318-
compare_pytorch_and_py(x_fg, np.array([test_value], dtype=config.floatX))
311+
a = scalar("a")
312+
x = ifelse(a < 0.5, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
313+
x_fg = FunctionGraph([a], x)
314+
315+
compare_pytorch_and_py(x_fg, np.array([0.2], dtype=config.floatX))
316+
317+
a = scalar("a")
318+
x = ifelse(a < 0.4, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
319+
x_fg = FunctionGraph([a], x)
320+
321+
compare_pytorch_and_py(x_fg, np.array([0.5], dtype=config.floatX))

0 commit comments

Comments
 (0)