Skip to content

Commit b4c3773

Browse files
Ch0ronomatoIan Schweer
authored andcommitted
Add while loop test
1 parent 03f46a6 commit b4c3773

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,16 @@ def test_ScalarLoop():
391391
np.testing.assert_allclose(fn(5, 0, 1), 5)
392392
np.testing.assert_allclose(fn(5, 0, 2), 10)
393393
np.testing.assert_allclose(fn(4, 3, -1), -1)
394+
395+
396+
def test_ScalarLoop_while():
397+
n_steps = int64("n_steps")
398+
x0 = float64("x0")
399+
x = x0 + 1
400+
until = x >= 10
401+
402+
op = ScalarLoop(init=[x0], update=[x], until=until)
403+
fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode)
404+
np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
405+
np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
406+
np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])

0 commit comments

Comments
 (0)