We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e4c2b9d commit ebaf641Copy full SHA for ebaf641
pytensor/link/pytorch/dispatch/scalar.py
@@ -56,7 +56,10 @@ def scalar_loop(steps, *start_and_constants):
56
*carry, done = update(*carry, *constants)
57
if done:
58
break
59
- return torch.stack(carry), torch.tensor([done])
+ if len(node.outputs) == 2:
60
+ return carry[0], done
61
+ else:
62
+ return carry, done
63
else:
64
65
def scalar_loop(steps, *start_and_constants):
@@ -66,6 +69,9 @@ def scalar_loop(steps, *start_and_constants):
66
69
)
67
70
for _ in range(steps):
68
71
carry = update(*carry, *constants)
- return torch.stack(carry)
72
+ if len(node.outputs) == 1:
73
+ return carry[0]
74
75
+ return carry
76
77
return torch.compiler.disable(scalar_loop)
0 commit comments