Skip to content

Commit 6a1b319

Browse files
author
Ian Schweer
committed
Remove unnecessary torch stack
1 parent 9c1d897 commit 6a1b319

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ def scalar_loop(steps, *start_and_constants):
6868
*carry, done = update(*carry, *constants)
6969
if done:
7070
break
71-
return torch.stack(carry), torch.tensor([done])
71+
if len(node.outputs) == 2:
72+
return carry[0], done
73+
else:
74+
return carry, done
7275
else:
7376

7477
def scalar_loop(steps, *start_and_constants):
@@ -78,6 +81,9 @@ def scalar_loop(steps, *start_and_constants):
7881
)
7982
for _ in range(steps):
8083
carry = update(*carry, *constants)
81-
return torch.stack(carry)
84+
if len(node.outputs) == 1:
85+
return carry[0]
86+
else:
87+
return carry
8288

8389
return torch.compiler.disable(scalar_loop)

0 commit comments

Comments
 (0)