Skip to content

Commit 12569f8

Browse files
author
Ian Schweer
committed
Pass all loop tests
1 parent 7804b90 commit 12569f8

File tree

1 file changed

+22
-14
lines changed

1 file changed

+22
-14
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,27 @@ def pytorch_funcify_Softplus(op, node, **kwargs):
6666
@pytorch_funcify.register(ScalarLoop)
6767
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
6868
update = pytorch_funcify(op.fgraph)
69-
70-
def inner(steps, start, constant, update=update, is_while=op.is_while):
71-
# easiest way to do it is to loop
72-
c = start
73-
for i in range(steps):
74-
outs = update(c, constant)
75-
if is_while:
76-
n, done = outs
69+
if op.is_while:
70+
71+
def scalar_loop(steps, *start_and_constants):
72+
*carry, constants = start_and_constants
73+
constants = constants.unsqueeze(0)
74+
done = True
75+
for _ in range(steps):
76+
*carry, done = update(*carry, *constants)
77+
constants = start_and_constants[len(carry) :]
7778
if done:
78-
return n
79-
c = n
80-
else:
81-
c = outs[0]
82-
return c
79+
break
80+
return torch.stack((*carry, done))
81+
else:
82+
83+
def scalar_loop(*args):
84+
steps, *start_and_constants = args
85+
*carry, constants = start_and_constants
86+
constants = constants.unsqueeze(0)
87+
for i in range(steps):
88+
carry = update(*carry, *constants)
89+
constants = start_and_constants[len(carry) :]
90+
return torch.stack(carry)
8391

84-
return inner
92+
return scalar_loop

0 commit comments

Comments
 (0)