@@ -42,27 +42,30 @@ def pytorch_func(*args):
42
42
@pytorch_funcify .register (ScalarLoop )
43
43
def pytorch_funicify_ScalarLoop (op , node , ** kwargs ):
44
44
update = pytorch_funcify (op .fgraph )
45
+ state_length = op .nout
45
46
if op .is_while :
46
47
47
48
def scalar_loop (steps , * start_and_constants ):
48
- * carry , constants = start_and_constants
49
- constants = constants .unsqueeze (0 )
49
+ carry , constants = (
50
+ start_and_constants [:state_length ],
51
+ start_and_constants [state_length :],
52
+ )
50
53
done = True
51
54
for _ in range (steps ):
52
55
* carry , done = update (* carry , * constants )
53
- constants = start_and_constants [len (carry ) :]
54
56
if done :
55
57
break
56
58
return torch .stack ((* carry , done ))
57
59
else :
58
60
59
61
def scalar_loop (* args ):
60
62
steps , * start_and_constants = args
61
- * carry , constants = start_and_constants
62
- constants = constants .unsqueeze (0 )
63
- for i in range (steps ):
63
+ carry , constants = (
64
+ start_and_constants [:state_length ],
65
+ start_and_constants [state_length :],
66
+ )
67
+ for _ in range (steps ):
64
68
carry = update (* carry , * constants )
65
- constants = start_and_constants [len (carry ) :]
66
69
return torch .stack (carry )
67
70
68
71
return scalar_loop
0 commit comments