Skip to content

Commit 9aeed89

Browse files
author
Ian Schweer
committed
Add single carry test
1 parent 56a6e93 commit 9aeed89

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,30 @@ def test_ScalarLoop_while():
381381
np.testing.assert_allclose(res[1], np.array(expected[1]))
382382

383383

384-
def test_ScalarLoop_Elemwise():
384+
def test_ScalarLoop_Elemwise_single_carries():
385+
n_steps = int64("n_steps")
386+
x0 = float64("x0")
387+
x = x0 * 2
388+
until = x >= 10
389+
390+
scalarop = ScalarLoop(init=[x0], update=[x], until=until)
391+
op = Elemwise(scalarop)
392+
393+
n_steps = pt.scalar("n_steps", dtype="int32")
394+
x0 = pt.vector("x0", dtype="float32")
395+
state, done = op(n_steps, x0)
396+
397+
f = FunctionGraph([n_steps, x0], [state, done])
398+
args = [
399+
np.array(10).astype("int32"),
400+
np.arange(0, 5).astype("float32"),
401+
]
402+
compare_pytorch_and_py(
403+
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
404+
)
405+
406+
407+
def test_ScalarLoop_Elemwise_multi_carries():
385408
n_steps = int64("n_steps")
386409
x0 = float64("x0")
387410
x1 = float64("x1")

0 commit comments

Comments
 (0)