Skip to content

Commit a3bb433

Browse files
author
Ian Schweer
committed
Pass all loop tests
1 parent 5934e09 commit a3bb433

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,27 @@ def pytorch_func(*args):
4242
@pytorch_funcify.register(ScalarLoop)
4343
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
4444
update = pytorch_funcify(op.fgraph)
45+
if op.is_while:
4546

46-
def inner(steps, start, constant, update=update, is_while=op.is_while):
47-
# easiest way to do it is to loop
48-
c = start
49-
for i in range(steps):
50-
outs = update(c, constant)
51-
if is_while:
52-
n, done = outs
47+
def scalar_loop(steps, *start_and_constants):
48+
*carry, constants = start_and_constants
49+
constants = constants.unsqueeze(0)
50+
done = True
51+
for _ in range(steps):
52+
*carry, done = update(*carry, *constants)
53+
constants = start_and_constants[len(carry) :]
5354
if done:
54-
return n
55-
c = n
56-
else:
57-
c = outs[0]
58-
return c
55+
break
56+
return torch.stack((*carry, done))
57+
else:
5958

60-
return inner
59+
def scalar_loop(*args):
60+
steps, *start_and_constants = args
61+
*carry, constants = start_and_constants
62+
constants = constants.unsqueeze(0)
63+
for i in range(steps):
64+
carry = update(*carry, *constants)
65+
constants = start_and_constants[len(carry) :]
66+
return torch.stack(carry)
67+
68+
return scalar_loop

0 commit comments

Comments
 (0)