@@ -42,19 +42,27 @@ def pytorch_func(*args):
4242@pytorch_funcify .register (ScalarLoop )
4343def pytorch_funicify_ScalarLoop (op , node , ** kwargs ):
4444 update = pytorch_funcify (op .fgraph )
45+ if op .is_while :
4546
46- def inner (steps , start , constant , update = update , is_while = op . is_while ):
47- # easiest way to do it is to loop
48- c = start
49- for i in range ( steps ):
50- outs = update ( c , constant )
51- if is_while :
52- n , done = outs
47+ def scalar_loop (steps , * start_and_constants ):
48+ * carry , constants = start_and_constants
49+ constants = constants . unsqueeze ( 0 )
50+ done = True
51+ for _ in range ( steps ):
52+ * carry , done = update ( * carry , * constants )
53+ constants = start_and_constants [ len ( carry ) :]
5354 if done :
54- return n
55- c = n
56- else :
57- c = outs [0 ]
58- return c
55+ break
56+ return torch .stack ((* carry , done ))
57+ else :
5958
60- return inner
59+ def scalar_loop (* args ):
60+ steps , * start_and_constants = args
61+ * carry , constants = start_and_constants
62+ constants = constants .unsqueeze (0 )
63+ for i in range (steps ):
64+ carry = update (* carry , * constants )
65+ constants = start_and_constants [len (carry ) :]
66+ return torch .stack (carry )
67+
68+ return scalar_loop
0 commit comments