We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9c1d897 commit 6a1b319Copy full SHA for 6a1b319
pytensor/link/pytorch/dispatch/scalar.py
@@ -68,7 +68,10 @@ def scalar_loop(steps, *start_and_constants):
68
*carry, done = update(*carry, *constants)
69
if done:
70
break
71
- return torch.stack(carry), torch.tensor([done])
+ if len(node.outputs) == 2:
72
+ return carry[0], done
73
+ else:
74
+ return carry, done
75
else:
76
77
def scalar_loop(steps, *start_and_constants):
@@ -78,6 +81,9 @@ def scalar_loop(steps, *start_and_constants):
78
81
)
79
82
for _ in range(steps):
80
83
carry = update(*carry, *constants)
- return torch.stack(carry)
84
+ if len(node.outputs) == 1:
85
+ return carry[0]
86
87
+ return carry
88
89
return torch.compiler.disable(scalar_loop)
0 commit comments