Skip to content

Commit 56a6e93

Browse files
author
Ian Schweer
committed
Update test
1 parent cc2bfb4 commit 56a6e93

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def cast(x):
5252

5353
return cast
5454

55+
5556
@pytorch_funcify.register(ScalarLoop)
5657
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
5758
update = pytorch_funcify(op.fgraph)
@@ -68,7 +69,7 @@ def scalar_loop(steps, *start_and_constants):
6869
*carry, done = update(*carry, *constants)
6970
if torch.any(done):
7071
break
71-
return *carry, done
72+
return *carry, done
7273
else:
7374

7475
def scalar_loop(steps, *start_and_constants):

tests/link/pytorch/test_basic.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -384,16 +384,25 @@ def test_ScalarLoop_while():
384384
def test_ScalarLoop_Elemwise():
385385
n_steps = int64("n_steps")
386386
x0 = float64("x0")
387+
x1 = float64("x1")
387388
x = x0 * 2
389+
x1_n = x1 * 3
388390
until = x >= 10
389391

390-
scalarop = ScalarLoop(init=[x0], update=[x], until=until)
392+
scalarop = ScalarLoop(init=[x0, x1], update=[x, x1_n], until=until)
391393
op = Elemwise(scalarop)
392394

393395
n_steps = pt.scalar("n_steps", dtype="int32")
394396
x0 = pt.vector("x0", dtype="float32")
395-
state, done = op(n_steps, x0)
396-
397-
f = FunctionGraph([n_steps, x0], [state, done])
398-
args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
399-
compare_pytorch_and_py(f, args)
397+
x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1))
398+
*states, done = op(n_steps, x0, x1)
399+
400+
f = FunctionGraph([n_steps, x0, x1], [*states, done])
401+
args = [
402+
np.array(10).astype("int32"),
403+
np.arange(0, 5).astype("float32"),
404+
np.random.rand(7, 3, 1).astype("float32"),
405+
]
406+
compare_pytorch_and_py(
407+
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
408+
)

0 commit comments

Comments
 (0)