@@ -13,11 +13,6 @@ def jax_funcify_Scan(op: Scan, **kwargs):
1313 if info .as_while :
1414 raise NotImplementedError ("While Scan cannot yet be converted to JAX" )
1515
16- if info .n_mit_mot :
17- raise NotImplementedError (
18- "Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
19- )
20-
2116 # Optimize inner graph (exclude any default rewrites that are incompatible with JAX mode)
2217 rewriter = op .mode_instance .excluding (* JAX ._optimizer .exclude ).optimizer
2318 rewriter (op .fgraph )
@@ -29,50 +24,74 @@ def scan(*outer_inputs):
2924 n_steps = outer_inputs [0 ] # JAX `length`
3025 seqs = op .outer_seqs (outer_inputs ) # JAX `xs`
3126
32- mit_sot_init = []
33- for tap , seq in zip (
27+ # MIT-MOT and MIT-SOT are provided from outside as a tape long enough to store the initial values and intermediate outputs
28+ # To bootstrap the inner function we need to slice the initial values
29+ mit_mot_inits = []
30+ for taps , seq in zip (
31+ op .info .mit_mot_in_slices , op .outer_mitmot (outer_inputs ), strict = True
32+ ):
33+ # mit-mot taps are non-negative
34+ init_slice = seq [: max (taps ) + 1 ]
35+ mit_mot_inits .append (init_slice )
36+
37+ mit_sot_inits = []
38+ for taps , seq in zip (
3439 op .info .mit_sot_in_slices , op .outer_mitsot (outer_inputs ), strict = True
3540 ):
36- init_slice = seq [: abs (min (tap ))]
37- mit_sot_init .append (init_slice )
41+ # mit-sot taps are negative
42+ init_slice = seq [: abs (min (taps ))]
43+ mit_sot_inits .append (init_slice )
3844
39- sit_sot_init = [seq [0 ] for seq in op .outer_sitsot (outer_inputs )]
45+ sit_sot_inits = [seq [0 ] for seq in op .outer_sitsot (outer_inputs )]
4046
4147 init_carry = (
42- mit_sot_init ,
43- sit_sot_init ,
48+ mit_mot_inits ,
49+ mit_sot_inits ,
50+ sit_sot_inits ,
4451 op .outer_shared (outer_inputs ),
4552 op .outer_non_seqs (outer_inputs ),
4653 ) # JAX `init`
4754
4855 def jax_args_to_inner_func_args (carry , x ):
4956 """Convert JAX scan arguments into format expected by scan_inner_func.
5057
51- scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs)
58+ scan(carry, x) -> scan_inner_func(seqs, mit_mot, mit_sot, sit_sot, shared, non_seqs)
5259 """
5360
5461 # `carry` contains all inner taps, shared terms, and non_seqs
5562 (
56- inner_mit_sot ,
57- inner_sit_sot ,
58- inner_shared ,
63+ inner_mit_mots ,
64+ inner_mit_sots ,
65+ inner_sit_sots ,
66+ inner_shareds ,
5967 inner_non_seqs ,
6068 ) = carry
6169
6270 # `x` contains the inner sequences
6371 inner_seqs = x
6472
65- mit_sot_flatten = []
66- for array , index in zip (
67- inner_mit_sot , op .info .mit_sot_in_slices , strict = True
73+ # MIT-MOT and MIT-SOT are provided as unified tensors and should be split
74+ # into distinct entries for the inner function
75+ split_mit_mots = []
76+ for taps , seq in zip (
77+ op .info .mit_mot_in_slices , inner_mit_mots , strict = True
78+ ):
79+ for tap in taps :
80+ split_mit_mots .append (seq [tap ])
81+
82+ split_mit_sots = []
83+ for taps , seq in zip (
84+ op .info .mit_sot_in_slices , inner_mit_sots , strict = True
6885 ):
69- mit_sot_flatten .extend (array [jnp .array (index )])
86+ for tap in taps :
87+ split_mit_sots .append (seq [tap ])
7088
7189 inner_scan_inputs = [
7290 * inner_seqs ,
73- * mit_sot_flatten ,
74- * inner_sit_sot ,
75- * inner_shared ,
91+ * split_mit_mots , # TODO: Confirm ordering
92+ * split_mit_sots ,
93+ * inner_sit_sots ,
94+ * inner_shareds ,
7695 * inner_non_seqs ,
7796 ]
7897
@@ -84,44 +103,71 @@ def inner_func_outs_to_jax_outs(
84103 ):
85104 """Convert inner_scan_func outputs into format expected by JAX scan.
86105
87- old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
106+ old_carry + (mit_mot_outs, mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
88107 """
89108 (
90- inner_mit_sot ,
91- inner_sit_sot ,
92- inner_shared ,
109+ inner_mit_mots ,
110+ inner_mit_sots ,
111+ inner_sit_sots ,
112+ inner_shareds ,
93113 inner_non_seqs ,
94114 ) = old_carry
95115
116+ inner_mit_mot_outs = op .inner_mitmot_outs (inner_scan_outs )
96117 inner_mit_sot_outs = op .inner_mitsot_outs (inner_scan_outs )
97118 inner_sit_sot_outs = op .inner_sitsot_outs (inner_scan_outs )
98119 inner_nit_sot_outs = op .inner_nitsot_outs (inner_scan_outs )
99120 inner_shared_outs = op .inner_shared_outs (inner_scan_outs )
100121
101- # Replace the oldest mit_sot tap by the newest value
102- inner_mit_sot_new = [
103- jnp .concatenate ([old_mit_sot [1 :], new_val [None , ...]], axis = 0 )
104- for old_mit_sot , new_val in zip (
105- inner_mit_sot , inner_mit_sot_outs , strict = True
122+ # Group split mit_mot_outs into the respective groups
123+ start = 0
124+ grouped_inner_mit_mot_outs = []
125+ for mit_mot_out_slice in op .info .mit_mot_out_slices :
126+ end = start + len (mit_mot_out_slice )
127+ elements = inner_mit_mot_outs [start :end ]
128+ group = jnp .concatenate ([e [None ] for e in elements ], axis = 0 )
129+ grouped_inner_mit_mot_outs .append (group )
130+ start = end
131+
132+ # Drop the first n_outs entries of each mit-mot carry and append that group's new outputs at the end (NOTE(review): confirm tap ordering)
133+ new_inner_mit_mots = []
134+ for old_mit_mot , new_outs in zip (
135+ inner_mit_mots , grouped_inner_mit_mot_outs , strict = True
136+ ):
137+ n_outs = len (new_outs )
138+ inner_mit_mot_new = jnp .concatenate (
139+ [old_mit_mot [n_outs :], new_outs ], axis = 0
106140 )
107- ]
141+ new_inner_mit_mots .append (inner_mit_mot_new )
142+
143+ # Drop the oldest mit-sot tap (first entry) and append the newest value at end
144+ new_inner_mit_sots = []
145+ for old_mit_sot , new_out in zip (
146+ inner_mit_sots , inner_mit_sot_outs , strict = True
147+ ):
148+ inner_mit_sot_new = jnp .concatenate (
149+ [old_mit_sot [1 :], new_out [None , ...]], axis = 0
150+ )
151+ new_inner_mit_sots .append (inner_mit_sot_new )
108152
109153 # Nothing needs to be done with sit_sot
110- inner_sit_sot_new = inner_sit_sot_outs
154+ new_inner_sit_sots = inner_sit_sot_outs
111155
112- inner_shared_new = inner_shared
156+ new_inner_shareds = inner_shareds
113157 # Replace old shared inputs by new shared outputs
114- inner_shared_new [: len (inner_shared_outs )] = inner_shared_outs
158+ new_inner_shareds [: len (inner_shared_outs )] = inner_shared_outs
115159
116160 new_carry = (
117- inner_mit_sot_new ,
118- inner_sit_sot_new ,
119- inner_shared_new ,
161+ new_inner_mit_mots ,
162+ new_inner_mit_sots ,
163+ new_inner_sit_sots ,
164+ new_inner_shareds ,
120165 inner_non_seqs ,
121166 )
122167
123168 # Shared variables and non_seqs are not traced
124169 traced_outs = [
170+ * grouped_inner_mit_mot_outs ,
125171 * inner_mit_sot_outs ,
126172 * inner_sit_sot_outs ,
127173 * inner_nit_sot_outs ,
@@ -148,9 +194,15 @@ def get_partial_traces(traces):
148194 2. Slice final traces if Scan was instructed to only keep a portion
149195 """
150196
151- init_states = mit_sot_init + sit_sot_init + [None ] * op .info .n_nit_sot
197+ init_states = (
198+ mit_mot_inits
199+ + mit_sot_inits
200+ + sit_sot_inits
201+ + [None ] * op .info .n_nit_sot
202+ )
152203 buffers = (
153- op .outer_mitsot (outer_inputs )
204+ op .outer_mitmot (outer_inputs )
205+ + op .outer_mitsot (outer_inputs )
154206 + op .outer_sitsot (outer_inputs )
155207 + op .outer_nitsot (outer_inputs )
156208 )
@@ -159,11 +211,10 @@ def get_partial_traces(traces):
159211 init_states , traces , buffers , strict = True
160212 ):
161213 if init_state is not None :
162- # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
214+ # MIT-MOT, MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
163215 trace = jnp .atleast_1d (trace )
164- init_state = jnp .expand_dims (
165- init_state , range (trace .ndim - init_state .ndim )
166- )
216+ init_state = jnp .expand_dims (init_state , 1 )
217+ # TODO: delete this, shouldn't be needed?
167218 full_trace = jnp .concatenate ([init_state , trace ], axis = 0 )
168219 buffer_size = buffer .shape [0 ]
169220 else :
0 commit comments