We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 39ff3de commit 714759cCopy full SHA for 714759c
pytensor/link/pytorch/dispatch/scalar.py
@@ -80,7 +80,10 @@ def scalar_loop(steps, *start_and_constants):
80
*carry, done = update(*carry, *constants)
81
if done:
82
break
83
- return torch.stack(carry), torch.tensor([done])
+ if len(node.outputs) == 2:
84
+ return carry[0], done
85
+ else:
86
+ return carry, done
87
else:
88
89
def scalar_loop(steps, *start_and_constants):
@@ -90,6 +93,9 @@ def scalar_loop(steps, *start_and_constants):
90
93
)
91
94
for _ in range(steps):
92
95
carry = update(*carry, *constants)
- return torch.stack(carry)
96
+ if len(node.outputs) == 1:
97
+ return carry[0]
98
99
+ return carry
100
101
return torch.compiler.disable(scalar_loop)
0 commit comments