Skip to content

Commit e0bbde8

Browse files
Ch0ronomatoIan Schweer
authored andcommitted
Add while loop test
1 parent 8eff3fe commit e0bbde8

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,3 +404,16 @@ def test_ScalarLoop():
404404
np.testing.assert_allclose(fn(5, 0, 1), 5)
405405
np.testing.assert_allclose(fn(5, 0, 2), 10)
406406
np.testing.assert_allclose(fn(4, 3, -1), -1)
407+
408+
409+
def test_ScalarLoop_while():
410+
n_steps = int64("n_steps")
411+
x0 = float64("x0")
412+
x = x0 + 1
413+
until = x >= 10
414+
415+
op = ScalarLoop(init=[x0], update=[x], until=until)
416+
fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode)
417+
np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
418+
np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
419+
np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])

0 commit comments

Comments
 (0)