@@ -66,19 +66,27 @@ def pytorch_funcify_Softplus(op, node, **kwargs):
6666@pytorch_funcify .register (ScalarLoop )
6767def pytorch_funicify_ScalarLoop (op , node , ** kwargs ):
6868 update = pytorch_funcify (op .fgraph )
69-
70- def inner (steps , start , constant , update = update , is_while = op .is_while ):
71- # easiest way to do it is to loop
72- c = start
73- for i in range (steps ):
74- outs = update (c , constant )
75- if is_while :
76- n , done = outs
69+ if op .is_while :
70+
71+ def scalar_loop (steps , * start_and_constants ):
72+ * carry , constants = start_and_constants
73+ constants = constants .unsqueeze (0 )
74+ done = True
75+ for _ in range (steps ):
76+ * carry , done = update (* carry , * constants )
77+ constants = start_and_constants [len (carry ) :]
7778 if done :
78- return n
79- c = n
80- else :
81- c = outs [0 ]
82- return c
79+ break
80+ return torch .stack ((* carry , done ))
81+ else :
82+
83+ def scalar_loop (* args ):
84+ steps , * start_and_constants = args
85+ * carry , constants = start_and_constants
86+ constants = constants .unsqueeze (0 )
87+ for i in range (steps ):
88+ carry = update (* carry , * constants )
89+ constants = start_and_constants [len (carry ) :]
90+ return torch .stack (carry )
8391
84- return inner
92+ return scalar_loop
0 commit comments