Skip to content

Commit 977d98d

Browse files
author
Ian Schweer
committed
Add elemwise test
1 parent e06994f commit 977d98d

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def scalar_loop(steps, *start_and_constants):
7878
done = True
7979
for _ in range(steps):
8080
*carry, done = update(*carry, *constants)
81-
if done:
81+
if torch.any(done):
8282
break
8383
if len(node.outputs) == 2:
8484
return carry[0], done

tests/link/pytorch/test_basic.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from pytensor.scalar import float64, int64
2121
from pytensor.scalar.loop import ScalarLoop
2222
from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
23+
from pytensor.tensor.elemwise import Elemwise
2324
from pytensor.tensor.type import matrices, matrix, scalar, vector
2425

2526

@@ -420,3 +421,33 @@ def test_ScalarLoop_while():
420421
):
421422
np.testing.assert_allclose(res[0], np.array(expected[0]))
422423
np.testing.assert_allclose(res[1], np.array(expected[1]))
424+
425+
def test_pytorch_OpFromGraph():
426+
x, y, z = matrices("xyz")
427+
ofg_1 = OpFromGraph([x, y], [x + y])
428+
ofg_2 = OpFromGraph([x, y], [x * y, x - y])
429+
430+
o1, o2 = ofg_2(y, z)
431+
out = ofg_1(x, o1) + o2
432+
433+
xv = np.ones((2, 2), dtype=config.floatX)
434+
yv = np.ones((2, 2), dtype=config.floatX) * 3
435+
zv = np.ones((2, 2), dtype=config.floatX) * 5
436+
437+
f = FunctionGraph([x, y, z], [out])
438+
compare_pytorch_and_py(f, [xv, yv, zv])
439+
440+
441+
def test_ScalarLoop_Elemwise():
442+
n_steps = int64("n_steps")
443+
x0 = float64("x0")
444+
x = x0 * 2
445+
until = x >= 10
446+
447+
op = ScalarLoop(init=[x0], update=[x], until=until)
448+
fn = function([n_steps, x0], Elemwise(op)(n_steps, x0), mode=pytorch_mode)
449+
450+
states, dones = fn(10, np.array(range(5)))
451+
452+
np.testing.assert_allclose(states, [0, 4, 8, 12, 16])
453+
np.testing.assert_allclose(dones, [False, False, False, True, True])

0 commit comments

Comments
 (0)