Skip to content

Commit cdc06f9

Browse files
author
Ian Schweer
committed
Fetch constants from op
1 parent a3bb433 commit cdc06f9

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,27 +42,30 @@ def pytorch_func(*args):
4242
@pytorch_funcify.register(ScalarLoop)
4343
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
4444
update = pytorch_funcify(op.fgraph)
45+
state_length = op.nout
4546
if op.is_while:
4647

4748
def scalar_loop(steps, *start_and_constants):
48-
*carry, constants = start_and_constants
49-
constants = constants.unsqueeze(0)
49+
carry, constants = (
50+
start_and_constants[:state_length],
51+
start_and_constants[state_length:],
52+
)
5053
done = True
5154
for _ in range(steps):
5255
*carry, done = update(*carry, *constants)
53-
constants = start_and_constants[len(carry) :]
5456
if done:
5557
break
5658
return torch.stack((*carry, done))
5759
else:
5860

5961
def scalar_loop(*args):
6062
steps, *start_and_constants = args
61-
*carry, constants = start_and_constants
62-
constants = constants.unsqueeze(0)
63-
for i in range(steps):
63+
carry, constants = (
64+
start_and_constants[:state_length],
65+
start_and_constants[state_length:],
66+
)
67+
for _ in range(steps):
6468
carry = update(*carry, *constants)
65-
constants = start_and_constants[len(carry) :]
6669
return torch.stack(carry)
6770

6871
return scalar_loop

0 commit comments

Comments
 (0)