Skip to content

Commit f0507d5

Browse files
author
Ian Schweer
committed
Add elemwise test
1 parent eee63d4 commit f0507d5

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def scalar_loop(steps, *start_and_constants):
5454
done = True
5555
for _ in range(steps):
5656
*carry, done = update(*carry, *constants)
57-
if done:
57+
if torch.any(done):
5858
break
5959
if len(node.outputs) == 2:
6060
return carry[0], done

tests/link/pytorch/test_basic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pytensor.scalar import float64, int64
1818
from pytensor.scalar.loop import ScalarLoop
1919
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
20+
from pytensor.tensor.elemwise import Elemwise
2021
from pytensor.tensor.type import matrices, matrix, scalar, vector
2122

2223

@@ -351,3 +352,18 @@ def test_pytorch_OpFromGraph():
351352

352353
f = FunctionGraph([x, y, z], [out])
353354
compare_pytorch_and_py(f, [xv, yv, zv])
355+
356+
357+
def test_ScalarLoop_Elemwise():
358+
n_steps = int64("n_steps")
359+
x0 = float64("x0")
360+
x = x0 * 2
361+
until = x >= 10
362+
363+
op = ScalarLoop(init=[x0], update=[x], until=until)
364+
fn = function([n_steps, x0], Elemwise(op)(n_steps, x0), mode=pytorch_mode)
365+
366+
states, dones = fn(10, np.array(range(5)))
367+
368+
np.testing.assert_allclose(states, [0, 4, 8, 12, 16])
369+
np.testing.assert_allclose(dones, [False, False, False, True, True])

0 commit comments

Comments
 (0)