Skip to content

Commit 8eff3fe

Browse files
author
Ian Schweer
committed
Fetch constants from op
1 parent 12569f8 commit 8eff3fe

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,27 +66,30 @@ def pytorch_funcify_Softplus(op, node, **kwargs):
6666
@pytorch_funcify.register(ScalarLoop)
6767
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
6868
update = pytorch_funcify(op.fgraph)
69+
state_length = op.nout
6970
if op.is_while:
7071

7172
def scalar_loop(steps, *start_and_constants):
72-
*carry, constants = start_and_constants
73-
constants = constants.unsqueeze(0)
73+
carry, constants = (
74+
start_and_constants[:state_length],
75+
start_and_constants[state_length:],
76+
)
7477
done = True
7578
for _ in range(steps):
7679
*carry, done = update(*carry, *constants)
77-
constants = start_and_constants[len(carry) :]
7880
if done:
7981
break
8082
return torch.stack((*carry, done))
8183
else:
8284

8385
def scalar_loop(*args):
8486
steps, *start_and_constants = args
85-
*carry, constants = start_and_constants
86-
constants = constants.unsqueeze(0)
87-
for i in range(steps):
87+
carry, constants = (
88+
start_and_constants[:state_length],
89+
start_and_constants[state_length:],
90+
)
91+
for _ in range(steps):
8892
carry = update(*carry, *constants)
89-
constants = start_and_constants[len(carry) :]
9093
return torch.stack(carry)
9194

9295
return scalar_loop

0 commit comments

Comments
 (0)