@@ -66,27 +66,30 @@ def pytorch_funcify_Softplus(op, node, **kwargs):
6666@pytorch_funcify .register (ScalarLoop )
6767def pytorch_funicify_ScalarLoop (op , node , ** kwargs ):
6868 update = pytorch_funcify (op .fgraph )
69+ state_length = op .nout
6970 if op .is_while :
7071
7172 def scalar_loop (steps , * start_and_constants ):
72- * carry , constants = start_and_constants
73- constants = constants .unsqueeze (0 )
73+ carry , constants = (
74+ start_and_constants [:state_length ],
75+ start_and_constants [state_length :],
76+ )
7477 done = True
7578 for _ in range (steps ):
7679 * carry , done = update (* carry , * constants )
77- constants = start_and_constants [len (carry ) :]
7880 if done :
7981 break
8082 return torch .stack ((* carry , done ))
8183 else :
8284
8385 def scalar_loop (* args ):
8486 steps , * start_and_constants = args
85- * carry , constants = start_and_constants
86- constants = constants .unsqueeze (0 )
87- for i in range (steps ):
87+ carry , constants = (
88+ start_and_constants [:state_length ],
89+ start_and_constants [state_length :],
90+ )
91+ for _ in range (steps ):
8892 carry = update (* carry , * constants )
89- constants = start_and_constants [len (carry ) :]
9093 return torch .stack (carry )
9194
9295 return scalar_loop
0 commit comments