Skip to content

Commit 8c64daf

Browse files
author
Ian Schweer
committed
Pass all loop tests
1 parent ec88252 commit 8c64daf

File tree

1 file changed

+23
-15
lines changed

1 file changed

+23
-15
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,27 @@ def cast(x):
5454
@pytorch_funcify.register(ScalarLoop)
5555
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
5656
update = pytorch_funcify(op.fgraph)
57-
58-
def inner(steps, start, constant, update=update, is_while=op.is_while):
59-
# easiest way to do it is to loop
60-
c = start
61-
for i in range(steps):
62-
outs = update(c, constant)
63-
if is_while:
64-
n, done = outs
57+
if op.is_while:
58+
59+
def scalar_loop(steps, *start_and_constants):
60+
*carry, constants = start_and_constants
61+
constants = constants.unsqueeze(0)
62+
done = True
63+
for _ in range(steps):
64+
*carry, done = update(*carry, *constants)
65+
constants = start_and_constants[len(carry) :]
6566
if done:
66-
return n
67-
c = n
68-
else:
69-
c = outs[0]
70-
return c
71-
72-
return inner
67+
break
68+
return torch.stack((*carry, done))
69+
else:
70+
71+
def scalar_loop(*args):
72+
steps, *start_and_constants = args
73+
*carry, constants = start_and_constants
74+
constants = constants.unsqueeze(0)
75+
for i in range(steps):
76+
carry = update(*carry, *constants)
77+
constants = start_and_constants[len(carry) :]
78+
return torch.stack(carry)
79+
80+
return scalar_loop

0 commit comments

Comments
 (0)