Skip to content

Commit 03f46a6

Browse files
author
Ian Schweer
committed
Fetch constants from op
1 parent 8c64daf commit 03f46a6

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,27 +54,30 @@ def cast(x):
5454
@pytorch_funcify.register(ScalarLoop)
5555
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
5656
update = pytorch_funcify(op.fgraph)
57+
state_length = op.nout
5758
if op.is_while:
5859

5960
def scalar_loop(steps, *start_and_constants):
60-
*carry, constants = start_and_constants
61-
constants = constants.unsqueeze(0)
61+
carry, constants = (
62+
start_and_constants[:state_length],
63+
start_and_constants[state_length:],
64+
)
6265
done = True
6366
for _ in range(steps):
6467
*carry, done = update(*carry, *constants)
65-
constants = start_and_constants[len(carry) :]
6668
if done:
6769
break
6870
return torch.stack((*carry, done))
6971
else:
7072

7173
def scalar_loop(*args):
7274
steps, *start_and_constants = args
73-
*carry, constants = start_and_constants
74-
constants = constants.unsqueeze(0)
75-
for i in range(steps):
75+
carry, constants = (
76+
start_and_constants[:state_length],
77+
start_and_constants[state_length:],
78+
)
79+
for _ in range(steps):
7680
carry = update(*carry, *constants)
77-
constants = start_and_constants[len(carry) :]
7881
return torch.stack(carry)
7982

8083
return scalar_loop

0 commit comments

Comments
 (0)