@@ -54,27 +54,30 @@ def cast(x):
5454@pytorch_funcify .register (ScalarLoop )
5555def pytorch_funicify_ScalarLoop (op , node , ** kwargs ):
5656 update = pytorch_funcify (op .fgraph )
57+ state_length = op .nout
5758 if op .is_while :
5859
5960 def scalar_loop (steps , * start_and_constants ):
60- * carry , constants = start_and_constants
61- constants = constants .unsqueeze (0 )
61+ carry , constants = (
62+ start_and_constants [:state_length ],
63+ start_and_constants [state_length :],
64+ )
6265 done = True
6366 for _ in range (steps ):
6467 * carry , done = update (* carry , * constants )
65- constants = start_and_constants [len (carry ) :]
6668 if done :
6769 break
6870 return torch .stack ((* carry , done ))
6971 else :
7072
7173 def scalar_loop (* args ):
7274 steps , * start_and_constants = args
73- * carry , constants = start_and_constants
74- constants = constants .unsqueeze (0 )
75- for i in range (steps ):
75+ carry , constants = (
76+ start_and_constants [:state_length ],
77+ start_and_constants [state_length :],
78+ )
79+ for _ in range (steps ):
7680 carry = update (* carry , * constants )
77- constants = start_and_constants [len (carry ) :]
7881 return torch .stack (carry )
7982
8083 return scalar_loop
0 commit comments