Skip to content

Commit a0b23cd

Browse files
author
Ian Schweer
committed
Fix while loop and nasty stack over dtypes
1 parent aa00703 commit a0b23cd

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def scalar_loop(steps, *start_and_constants):
5555
*carry, done = update(*carry, *constants)
5656
if done:
5757
break
58-
return torch.stack((*carry, done))
58+
return torch.stack(carry), torch.tensor([done])
5959
else:
6060

6161
def scalar_loop(*args):

tests/link/pytorch/test_basic.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,9 @@ def test_ScalarLoop_while():
328328

329329
op = ScalarLoop(init=[x0], update=[x], until=until)
330330
fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode)
331-
np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
332-
np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
333-
np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])
331+
for res, expected in zip(
332+
[fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)],
333+
[[10, True], [10, True], [6, False]],
334+
):
335+
np.testing.assert_allclose(res[0], np.array(expected[0]))
336+
np.testing.assert_allclose(res[1], np.array(expected[1]))

0 commit comments

Comments
 (0)