Skip to content

Commit 89b0b91

Browse files
committed
.wip
1 parent 5d4e9e0 commit 89b0b91

File tree

2 files changed

+101
-49
lines changed

2 files changed

+101
-49
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 96 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,6 @@ def jax_funcify_Scan(op: Scan, **kwargs):
1313
if info.as_while:
1414
raise NotImplementedError("While Scan cannot yet be converted to JAX")
1515

16-
if info.n_mit_mot:
17-
raise NotImplementedError(
18-
"Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
19-
)
20-
2116
# Optimize inner graph (exclude any default rewrites that are incompatible with JAX mode)
2217
rewriter = op.mode_instance.excluding(*JAX._optimizer.exclude).optimizer
2318
rewriter(op.fgraph)
@@ -29,50 +24,74 @@ def scan(*outer_inputs):
2924
n_steps = outer_inputs[0] # JAX `length`
3025
seqs = op.outer_seqs(outer_inputs) # JAX `xs`
3126

32-
mit_sot_init = []
33-
for tap, seq in zip(
27+
# MIT-MOT and MIT-SOT are provided from outside as a tape long enough to store the initial values and intermediate outputs
28+
# To bootstrap the inner function we need to slice the initial values
29+
mit_mot_inits = []
30+
for taps, seq in zip(
31+
op.info.mit_mot_in_slices, op.outer_mitmot(outer_inputs), strict=True
32+
):
33+
# mit-mot taps are non-negative
34+
init_slice = seq[: max(taps) + 1]
35+
mit_mot_inits.append(init_slice)
36+
37+
mit_sot_inits = []
38+
for taps, seq in zip(
3439
op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs), strict=True
3540
):
36-
init_slice = seq[: abs(min(tap))]
37-
mit_sot_init.append(init_slice)
41+
# mit-sot taps are negative
42+
init_slice = seq[: abs(min(taps))]
43+
mit_sot_inits.append(init_slice)
3844

39-
sit_sot_init = [seq[0] for seq in op.outer_sitsot(outer_inputs)]
45+
sit_sot_inits = [seq[0] for seq in op.outer_sitsot(outer_inputs)]
4046

4147
init_carry = (
42-
mit_sot_init,
43-
sit_sot_init,
48+
mit_mot_inits,
49+
mit_sot_inits,
50+
sit_sot_inits,
4451
op.outer_shared(outer_inputs),
4552
op.outer_non_seqs(outer_inputs),
4653
) # JAX `init`
4754

4855
def jax_args_to_inner_func_args(carry, x):
4956
"""Convert JAX scan arguments into format expected by scan_inner_func.
5057
51-
scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs)
58+
scan(carry, x) -> scan_inner_func(seqs, mit_mot, mit_sot, sit_sot, shared, non_seqs)
5259
"""
5360

5461
# `carry` contains all inner taps, shared terms, and non_seqs
5562
(
56-
inner_mit_sot,
57-
inner_sit_sot,
58-
inner_shared,
63+
inner_mit_mots,
64+
inner_mit_sots,
65+
inner_sit_sots,
66+
inner_shareds,
5967
inner_non_seqs,
6068
) = carry
6169

6270
# `x` contains the inner sequences
6371
inner_seqs = x
6472

65-
mit_sot_flatten = []
66-
for array, index in zip(
67-
inner_mit_sot, op.info.mit_sot_in_slices, strict=True
73+
# MIT-MOT and MIT-SOT are provided as unified tensors and should be split
74+
# into distinct entries for the inner function
75+
split_mit_mots = []
76+
for taps, seq in zip(
77+
op.info.mit_mot_in_slices, inner_mit_mots, strict=True
78+
):
79+
for tap in taps:
80+
split_mit_mots.append(seq[tap])
81+
82+
split_mit_sots = []
83+
for taps, seq in zip(
84+
op.info.mit_sot_in_slices, inner_mit_sots, strict=True
6885
):
69-
mit_sot_flatten.extend(array[jnp.array(index)])
86+
for tap in taps:
87+
split_mit_sots.append(seq[tap])
7088

7189
inner_scan_inputs = [
7290
*inner_seqs,
73-
*mit_sot_flatten,
74-
*inner_sit_sot,
75-
*inner_shared,
91+
*split_mit_mots, # TODO: Confirm ordering
92+
*split_mit_sots,
93+
*inner_sit_sots,
94+
*inner_shareds,
7695
*inner_non_seqs,
7796
]
7897

@@ -84,44 +103,71 @@ def inner_func_outs_to_jax_outs(
84103
):
85104
"""Convert inner_scan_func outputs into format expected by JAX scan.
86105
87-
old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
106+
old_carry + (mit_mot_outs, mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
88107
"""
89108
(
90-
inner_mit_sot,
91-
inner_sit_sot,
92-
inner_shared,
109+
inner_mit_mots,
110+
inner_mit_sots,
111+
inner_sit_sots,
112+
inner_shareds,
93113
inner_non_seqs,
94114
) = old_carry
95115

116+
inner_mit_mot_outs = op.inner_mitmot_outs(inner_scan_outs)
96117
inner_mit_sot_outs = op.inner_mitsot_outs(inner_scan_outs)
97118
inner_sit_sot_outs = op.inner_sitsot_outs(inner_scan_outs)
98119
inner_nit_sot_outs = op.inner_nitsot_outs(inner_scan_outs)
99120
inner_shared_outs = op.inner_shared_outs(inner_scan_outs)
100121

101-
# Replace the oldest mit_sot tap by the newest value
102-
inner_mit_sot_new = [
103-
jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0)
104-
for old_mit_sot, new_val in zip(
105-
inner_mit_sot, inner_mit_sot_outs, strict=True
122+
# Group split mit_mot_outs into the respective groups
123+
start = 0
124+
grouped_inner_mit_mot_outs = []
125+
for mit_mot_out_slice in op.info.mit_mot_out_slices:
126+
end = start + len(mit_mot_out_slice)
127+
elements = inner_mit_mot_outs[start:end]
128+
group = jnp.concatenate([e[None] for e in elements], axis=0)
129+
grouped_inner_mit_mot_outs.append(group)
130+
start = end
131+
132+
# Replace the oldest mit-mot taps (last entries) and prepend the newest values
133+
new_inner_mit_mots = []
134+
for old_mit_mot, new_outs in zip(
135+
inner_mit_mots, grouped_inner_mit_mot_outs, strict=True
136+
):
137+
n_outs = len(new_outs)
138+
inner_mit_mot_new = jnp.concatenate(
139+
[old_mit_mot[n_outs:], group], axis=0
106140
)
107-
]
141+
new_inner_mit_mots.append(inner_mit_mot_new)
142+
143+
# Drop the oldest mit-sot tap (first entry) and append the newest value at end
144+
new_inner_mit_sots = []
145+
for old_mit_sot, new_out in zip(
146+
inner_mit_sots, inner_mit_sot_outs, strict=True
147+
):
148+
inner_mit_sot_new = jnp.concatenate(
149+
[old_mit_sot[1:], new_out[None, ...]], axis=0
150+
)
151+
new_inner_mit_mots.append(inner_mit_sot_new)
108152

109153
# Nothing needs to be done with sit_sot
110-
inner_sit_sot_new = inner_sit_sot_outs
154+
new_inner_sit_sots = inner_sit_sot_outs
111155

112-
inner_shared_new = inner_shared
156+
new_inner_shareds = inner_shareds
113157
# Replace old shared inputs by new shared outputs
114-
inner_shared_new[: len(inner_shared_outs)] = inner_shared_outs
158+
new_inner_shareds[: len(inner_shared_outs)] = inner_shared_outs
115159

116160
new_carry = (
117-
inner_mit_sot_new,
118-
inner_sit_sot_new,
119-
inner_shared_new,
161+
new_inner_mit_mots,
162+
new_inner_mit_sots,
163+
new_inner_sit_sots,
164+
new_inner_shareds,
120165
inner_non_seqs,
121166
)
122167

123168
# Shared variables and non_seqs are not traced
124169
traced_outs = [
170+
*grouped_inner_mit_mot_outs,
125171
*inner_mit_sot_outs,
126172
*inner_sit_sot_outs,
127173
*inner_nit_sot_outs,
@@ -148,9 +194,15 @@ def get_partial_traces(traces):
148194
2. Slice final traces if Scan was instructed to only keep a portion
149195
"""
150196

151-
init_states = mit_sot_init + sit_sot_init + [None] * op.info.n_nit_sot
197+
init_states = (
198+
mit_mot_inits
199+
+ mit_sot_inits
200+
+ sit_sot_inits
201+
+ [None] * op.info.n_nit_sot
202+
)
152203
buffers = (
153-
op.outer_mitsot(outer_inputs)
204+
op.outer_mitmot(outer_inputs)
205+
+ op.outer_mitsot(outer_inputs)
154206
+ op.outer_sitsot(outer_inputs)
155207
+ op.outer_nitsot(outer_inputs)
156208
)
@@ -159,11 +211,10 @@ def get_partial_traces(traces):
159211
init_states, traces, buffers, strict=True
160212
):
161213
if init_state is not None:
162-
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
214+
# MIT-MOT, MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
163215
trace = jnp.atleast_1d(trace)
164-
init_state = jnp.expand_dims(
165-
init_state, range(trace.ndim - init_state.ndim)
166-
)
216+
init_state = jnp.expand_dims(init_state, 1)
217+
# TODO: delete this, shouldn't be needed?
167218
full_trace = jnp.concatenate([init_state, trace], axis=0)
168219
buffer_size = buffer.shape[0]
169220
else:

tests/link/jax/test_scan.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,17 @@ def test_scan_nit_sot(view):
9898
assert len(scan_nodes) == 1
9999

100100

101-
@pytest.mark.xfail(raises=NotImplementedError)
102101
def test_scan_mit_mot():
103-
xs = pt.vector("xs", shape=(10,))
102+
xs = pt.tensor("xs", shape=(2, 2))
104103
ys, _ = scan(
105104
lambda xtm2, xtm1: (xtm2 + xtm1),
106105
outputs_info=[{"initial": xs, "taps": [-2, -1]}],
107-
n_steps=10,
106+
n_steps=4,
108107
)
109108
grads_wrt_xs = pt.grad(ys.sum(), wrt=xs)
110-
compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(10)])
109+
f = function([xs], grads_wrt_xs, mode="JAX")
110+
f(np.arange(4).reshape((2, 2)))
111+
# compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(2)])
111112

112113

113114
def test_scan_update():

0 commit comments

Comments
 (0)