Skip to content

Commit ebaf641

Browse files
author
Ian Schweer
committed
Remove unnecessary torch stack
1 parent e4c2b9d commit ebaf641

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ def scalar_loop(steps, *start_and_constants):
5656
*carry, done = update(*carry, *constants)
5757
if done:
5858
break
59-
return torch.stack(carry), torch.tensor([done])
59+
if len(node.outputs) == 2:
60+
return carry[0], done
61+
else:
62+
return carry, done
6063
else:
6164

6265
def scalar_loop(steps, *start_and_constants):
@@ -66,6 +69,9 @@ def scalar_loop(steps, *start_and_constants):
6669
)
6770
for _ in range(steps):
6871
carry = update(*carry, *constants)
69-
return torch.stack(carry)
72+
if len(node.outputs) == 1:
73+
return carry[0]
74+
else:
75+
return carry
7076

7177
return torch.compiler.disable(scalar_loop)

0 commit comments

Comments
 (0)