Skip to content

Commit 714759c

Browse files
author
Ian Schweer
committed
Remove unnecessary torch stack
1 parent 39ff3de commit 714759c

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ def scalar_loop(steps, *start_and_constants):
8080
*carry, done = update(*carry, *constants)
8181
if done:
8282
break
83-
return torch.stack(carry), torch.tensor([done])
83+
if len(node.outputs) == 2:
84+
return carry[0], done
85+
else:
86+
return carry, done
8487
else:
8588

8689
def scalar_loop(steps, *start_and_constants):
@@ -90,6 +93,9 @@ def scalar_loop(steps, *start_and_constants):
9093
)
9194
for _ in range(steps):
9295
carry = update(*carry, *constants)
93-
return torch.stack(carry)
96+
if len(node.outputs) == 1:
97+
return carry[0]
98+
else:
99+
return carry
94100

95101
return torch.compiler.disable(scalar_loop)

0 commit comments

Comments
 (0)