@@ -54,19 +54,27 @@ def cast(x):
5454@pytorch_funcify .register (ScalarLoop )
5555def pytorch_funicify_ScalarLoop (op , node , ** kwargs ):
5656 update = pytorch_funcify (op .fgraph )
57-
58- def inner (steps , start , constant , update = update , is_while = op .is_while ):
59- # easiest way to do it is to loop
60- c = start
61- for i in range (steps ):
62- outs = update (c , constant )
63- if is_while :
64- n , done = outs
57+ if op .is_while :
58+
59+ def scalar_loop (steps , * start_and_constants ):
60+ * carry , constants = start_and_constants
61+ constants = constants .unsqueeze (0 )
62+ done = True
63+ for _ in range (steps ):
64+ * carry , done = update (* carry , * constants )
65+ constants = start_and_constants [len (carry ) :]
6566 if done :
66- return n
67- c = n
68- else :
69- c = outs [0 ]
70- return c
71-
72- return inner
67+ break
68+ return torch .stack ((* carry , done ))
69+ else :
70+
71+ def scalar_loop (* args ):
72+ steps , * start_and_constants = args
73+ * carry , constants = start_and_constants
74+ constants = constants .unsqueeze (0 )
75+ for i in range (steps ):
76+ carry = update (* carry , * constants )
77+ constants = start_and_constants [len (carry ) :]
78+ return torch .stack (carry )
79+
80+ return scalar_loop
0 commit comments