Skip to content

Commit 4990e3a

Browse files
author
Ian Schweer
committed
Fix while loop and nasty stack over dtypes
1 parent b4c3773 commit 4990e3a

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def scalar_loop(steps, *start_and_constants):
6767
*carry, done = update(*carry, *constants)
6868
if done:
6969
break
70-
return torch.stack((*carry, done))
70+
return torch.stack(carry), torch.tensor([done])
7171
else:
7272

7373
def scalar_loop(*args):

tests/link/pytorch/test_basic.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,9 @@ def test_ScalarLoop_while():
401401

402402
op = ScalarLoop(init=[x0], update=[x], until=until)
403403
fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode)
404-
np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
405-
np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
406-
np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])
404+
for res, expected in zip(
405+
[fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)],
406+
[[10, True], [10, True], [6, False]],
407+
):
408+
np.testing.assert_allclose(res[0], np.array(expected[0]))
409+
np.testing.assert_allclose(res[1], np.array(expected[1]))

0 commit comments

Comments
 (0)