Skip to content

Commit aa00703

Browse files
committed
Add while loop test
1 parent cdc06f9 commit aa00703

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,3 +318,16 @@ def test_ScalarLoop():
318318
np.testing.assert_allclose(fn(5, 0, 1), 5)
319319
np.testing.assert_allclose(fn(5, 0, 2), 10)
320320
np.testing.assert_allclose(fn(4, 3, -1), -1)
321+
322+
323+
def test_ScalarLoop_while():
324+
n_steps = int64("n_steps")
325+
x0 = float64("x0")
326+
x = x0 + 1
327+
until = x >= 10
328+
329+
op = ScalarLoop(init=[x0], update=[x], until=until)
330+
fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode)
331+
np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
332+
np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
333+
np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])

0 commit comments

Comments
 (0)