Skip to content

Commit ae1c9da

Browse files
author
Ian Schweer
committed
Fix while loop and nasty stack over dtypes
1 parent e0bbde8 commit ae1c9da

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def scalar_loop(steps, *start_and_constants):
7979
*carry, done = update(*carry, *constants)
8080
if done:
8181
break
82-
return torch.stack((*carry, done))
82+
return torch.stack(carry), torch.tensor([done])
8383
else:
8484

8585
def scalar_loop(*args):

tests/link/pytorch/test_basic.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,9 @@ def test_ScalarLoop_while():
414414

415415
op = ScalarLoop(init=[x0], update=[x], until=until)
416416
fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode)
417-
np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
418-
np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
419-
np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])
417+
for res, expected in zip(
418+
[fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)],
419+
[[10, True], [10, True], [6, False]],
420+
):
421+
np.testing.assert_allclose(res[0], np.array(expected[0]))
422+
np.testing.assert_allclose(res[1], np.array(expected[1]))

0 commit comments

Comments
 (0)