@@ -42,27 +42,30 @@ def pytorch_func(*args):
4242@pytorch_funcify .register (ScalarLoop )
4343def pytorch_funicify_ScalarLoop (op , node , ** kwargs ):
4444 update = pytorch_funcify (op .fgraph )
45+ state_length = op .nout
4546 if op .is_while :
4647
4748 def scalar_loop (steps , * start_and_constants ):
48- * carry , constants = start_and_constants
49- constants = constants .unsqueeze (0 )
49+ carry , constants = (
50+ start_and_constants [:state_length ],
51+ start_and_constants [state_length :],
52+ )
5053 done = True
5154 for _ in range (steps ):
5255 * carry , done = update (* carry , * constants )
53- constants = start_and_constants [len (carry ) :]
5456 if done :
5557 break
5658 return torch .stack ((* carry , done ))
5759 else :
5860
5961 def scalar_loop (* args ):
6062 steps , * start_and_constants = args
61- * carry , constants = start_and_constants
62- constants = constants .unsqueeze (0 )
63- for i in range (steps ):
63+ carry , constants = (
64+ start_and_constants [:state_length ],
65+ start_and_constants [state_length :],
66+ )
67+ for _ in range (steps ):
6468 carry = update (* carry , * constants )
65- constants = start_and_constants [len (carry ) :]
6669 return torch .stack (carry )
6770
6871 return scalar_loop
0 commit comments