Commit 9779797

Benchmark scan in JAX backend
Co-authored-by: Jesse Grabowski <[email protected]>
1 parent 545e58f commit 9779797

File tree

1 file changed: +239 -92 lines changed

tests/link/jax/test_scan.py

Lines changed: 239 additions & 92 deletions
@@ -4,15 +4,15 @@
 import pytest
 
 import pytensor.tensor as pt
-from pytensor import function, shared
+from pytensor import function, ifelse, shared
 from pytensor.compile import get_mode
 from pytensor.configdefaults import config
 from pytensor.scan import until
 from pytensor.scan.basic import scan
 from pytensor.scan.op import Scan
 from pytensor.tensor import random
 from pytensor.tensor.math import gammaln, log
-from pytensor.tensor.type import dmatrix, dvector, lscalar, matrix, scalar, vector
+from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, vector
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
@@ -189,96 +189,6 @@ def test_scan_while():
     compare_jax_and_py([], [xs], [])
 
 
-def test_scan_SEIR():
-    """Test a scan implementation of a SEIR model.
-
-    SEIR model definition:
-    S[t+1] = S[t] - B[t]
-    E[t+1] = E[t] +B[t] - C[t]
-    I[t+1] = I[t+1] + C[t] - D[t]
-
-    B[t] ~ Binom(S[t], beta)
-    C[t] ~ Binom(E[t], gamma)
-    D[t] ~ Binom(I[t], delta)
-    """
-
-    def binomln(n, k):
-        return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
-
-    def binom_log_prob(n, p, value):
-        return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)
-
-    # sequences
-    at_C = vector("C_t", dtype="int32", shape=(8,))
-    at_D = vector("D_t", dtype="int32", shape=(8,))
-    # outputs_info (initial conditions)
-    st0 = lscalar("s_t0")
-    et0 = lscalar("e_t0")
-    it0 = lscalar("i_t0")
-    logp_c = scalar("logp_c")
-    logp_d = scalar("logp_d")
-    # non_sequences
-    beta = scalar("beta")
-    gamma = scalar("gamma")
-    delta = scalar("delta")
-
-    # TODO: Use random streams when their JAX conversions are implemented.
-    # trng = pytensor.tensor.random.RandomStream(1234)
-
-    def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
-        # bt0 = trng.binomial(n=st0, p=beta)
-        bt0 = st0 * beta
-        bt0 = bt0.astype(st0.dtype)
-
-        logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype)
-        logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype)
-
-        st1 = st0 - bt0
-        et1 = et0 + bt0 - ct0
-        it1 = it0 + ct0 - dt0
-        return st1, et1, it1, logp_c1, logp_d1
-
-    (st, et, it, logp_c_all, logp_d_all), _ = scan(
-        fn=seir_one_step,
-        sequences=[at_C, at_D],
-        outputs_info=[st0, et0, it0, logp_c, logp_d],
-        non_sequences=[beta, gamma, delta],
-    )
-    st.name = "S_t"
-    et.name = "E_t"
-    it.name = "I_t"
-    logp_c_all.name = "C_t_logp"
-    logp_d_all.name = "D_t_logp"
-
-    s0, e0, i0 = 100, 50, 25
-    logp_c0 = np.array(0.0, dtype=config.floatX)
-    logp_d0 = np.array(0.0, dtype=config.floatX)
-    beta_val, gamma_val, delta_val = (
-        np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753]
-    )
-    C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32)
-    D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32)
-
-    test_input_vals = [
-        C,
-        D,
-        s0,
-        e0,
-        i0,
-        logp_c0,
-        logp_d0,
-        beta_val,
-        gamma_val,
-        delta_val,
-    ]
-    compare_jax_and_py(
-        [at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
-        [st, et, it, logp_c_all, logp_d_all],
-        test_input_vals,
-        jax_mode="JAX",
-    )
-
-
 def test_scan_mitsot_with_nonseq():
     a_pt = scalar("a")
 
@@ -420,3 +330,240 @@ def test_dynamic_sequence_length():
     assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
     np.testing.assert_allclose(f([]), [])
     np.testing.assert_allclose(f([1, 2, 3]), np.array([2, 3, 4]))
+
+
+def SEIR_model_logp():
+    """Setup a Scan implementation of a SEIR model.
+
+    SEIR model definition:
+    S[t+1] = S[t] - B[t]
+    E[t+1] = E[t] + B[t] - C[t]
+    I[t+1] = I[t] + C[t] - D[t]
+
+    B[t] ~ Binom(S[t], beta)
+    C[t] ~ Binom(E[t], gamma)
+    D[t] ~ Binom(I[t], delta)
+    """
+
+    def binomln(n, k):
+        return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
+
+    def binom_log_prob(n, p, value):
+        return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)
+
+    # sequences
+    C_t = vector("C_t", dtype="int32", shape=(1200,))
+    D_t = vector("D_t", dtype="int32", shape=(1200,))
+    # outputs_info (initial conditions)
+    st0 = scalar("s_t0")
+    et0 = scalar("e_t0")
+    it0 = scalar("i_t0")
+    # non_sequences
+    beta = scalar("beta")
+    gamma = scalar("gamma")
+    delta = scalar("delta")
+
+    def seir_one_step(ct0, dt0, st0, et0, it0, beta, gamma, delta):
+        # bt0 = trng.binomial(n=st0, p=beta)
+        bt0 = st0 * beta
+        bt0 = bt0.astype(st0.dtype)
+
+        logp_c1 = binom_log_prob(et0, gamma, ct0)
+        logp_d1 = binom_log_prob(it0, delta, dt0)
+
+        st1 = st0 - bt0
+        et1 = et0 + bt0 - ct0
+        it1 = it0 + ct0 - dt0
+        return st1, et1, it1, logp_c1, logp_d1
+
+    (st, et, it, logp_c_all, logp_d_all), _ = scan(
+        fn=seir_one_step,
+        sequences=[C_t, D_t],
+        outputs_info=[st0, et0, it0, None, None],
+        non_sequences=[beta, gamma, delta],
+    )
+    st.name = "S_t"
+    et.name = "E_t"
+    it.name = "I_t"
+    logp_c_all.name = "C_t_logp"
+    logp_d_all.name = "D_t_logp"
+
+    st0_val, et0_val, it0_val = np.array(100.0), np.array(50.0), np.array(25.0)
+    beta_val, gamma_val, delta_val = (
+        np.array(0.277792),
+        np.array(0.135330),
+        np.array(0.108753),
+    )
+    C_t_val = np.array([3, 5, 8, 13, 21, 26, 10, 3] * 150, dtype=np.int32)
+    D_t_val = np.array([1, 2, 3, 7, 9, 11, 5, 1] * 150, dtype=np.int32)
+    assert C_t_val.shape == D_t_val.shape == C_t.type.shape == D_t.type.shape
+
+    test_input_vals = [
+        C_t_val,
+        D_t_val,
+        st0_val,
+        et0_val,
+        it0_val,
+        beta_val,
+        gamma_val,
+        delta_val,
+    ]
+
+    loss_graph = logp_c_all.sum() + logp_d_all.sum()
+
+    return dict(
+        graph_inputs=[C_t, D_t, st0, et0, it0, beta, gamma, delta],
+        differentiable_vars=[st0, et0, it0, beta, gamma, delta],
+        test_input_vals=test_input_vals,
+        loss_graph=loss_graph,
+    )
+
+
+def cyclical_reduction():
+    """Setup a Scan implementation of the cyclical reduction algorithm.
+
+    This solves the matrix equation A @ X @ X + B @ X + C = 0 for X
+
+    Adapted from https://github.com/jessegrabowski/gEconpy/blob/da495b22ac383cb6cb5dec15f305506aebef7302/gEconpy/solvers/cycle_reduction.py#L187
+    """
+
+    def stabilize(x, jitter=1e-16):
+        return x + jitter * pt.eye(x.shape[0])
+
+    def step(A0, A1, A2, A1_hat, norm, step_num, tol):
+        def cycle_step(A0, A1, A2, A1_hat, _norm, step_num):
+            tmp = pt.dot(
+                pt.vertical_stack(A0, A2),
+                pt.linalg.solve(
+                    stabilize(A1),
+                    pt.horizontal_stack(A0, A2),
+                    assume_a="gen",
+                    check_finite=False,
+                ),
+            )
+
+            n = A0.shape[0]
+            idx_0 = pt.arange(n)
+            idx_1 = idx_0 + n
+            A1 = A1 - tmp[idx_0, :][:, idx_1] - tmp[idx_1, :][:, idx_0]
+            A0 = -tmp[idx_0, :][:, idx_0]
+            A2 = -tmp[idx_1, :][:, idx_1]
+            A1_hat = A1_hat - tmp[idx_1, :][:, idx_0]
+
+            A0_L1_norm = pt.linalg.norm(A0, ord=1)
+
+            return A0, A1, A2, A1_hat, A0_L1_norm, step_num + 1
+
+        return ifelse(
+            norm < tol,
+            (A0, A1, A2, A1_hat, norm, step_num),
+            cycle_step(A0, A1, A2, A1_hat, norm, step_num),
+        )
+
+    A = pt.matrix("A", shape=(20, 20))
+    B = pt.matrix("B", shape=(20, 20))
+    C = pt.matrix("C", shape=(20, 20))
+
+    norm = np.array(1e9, dtype="float64")
+    step_num = pt.zeros((), dtype="int32")
+    max_iter = 100
+    tol = 1e-7
+
+    (*_, A1_hat, norm, _n_steps), _ = scan(
+        step,
+        outputs_info=[A, B, C, B, norm, step_num],
+        non_sequences=[tol],
+        n_steps=max_iter,
+    )
+    A1_hat = A1_hat[-1]
+
+    T = -pt.linalg.solve(stabilize(A1_hat), A, assume_a="gen", check_finite=False)
+
+    rng = np.random.default_rng(sum(map(ord, "cycle_reduction")))
+    n = A.type.shape[0]
+    A_test = rng.standard_normal(size=(n, n))
+    C_test = rng.standard_normal(size=(n, n))
+    # B must be invertible, so we make it symmetric positive-definite
+    B_rand = rng.standard_normal(size=(n, n))
+    B_test = B_rand @ B_rand.T + np.eye(n) * 1e-3
+
+    return dict(
+        graph_inputs=[A, B, C],
+        differentiable_vars=[A, B, C],
+        test_input_vals=[A_test, B_test, C_test],
+        loss_graph=pt.sum(T),
+    )
+
+
+@pytest.mark.parametrize("gradient_backend", ["PYTENSOR", "JAX"])
+@pytest.mark.parametrize("mode", ("0forward", "1backward", "2both"))
+@pytest.mark.parametrize("model", [cyclical_reduction, SEIR_model_logp])
+def test_scan_benchmark(model, mode, gradient_backend, benchmark):
+    if gradient_backend == "PYTENSOR" and mode in ("1backward", "2both"):
+        pytest.skip("PYTENSOR backend does not support backward mode yet")
+
+    model_dict = model()
+    graph_inputs = model_dict["graph_inputs"]
+    differentiable_vars = model_dict["differentiable_vars"]
+    loss_graph = model_dict["loss_graph"]
+    test_input_vals = model_dict["test_input_vals"]
+
+    if gradient_backend == "PYTENSOR":
+        backward_loss = pt.grad(
+            loss_graph,
+            wrt=differentiable_vars,
+        )
+
+        match mode:
+            # TODO: Restore original test separately
+            case "0forward":
+                graph_outputs = [loss_graph]
+            case "1backward":
+                graph_outputs = backward_loss
+            case "2both":
+                graph_outputs = [loss_graph, *backward_loss]
+            case _:
+                raise ValueError(f"Unknown mode: {mode}")
+
+        jax_fn, _ = compare_jax_and_py(
+            graph_inputs,
+            graph_outputs,
+            test_input_vals,
+            jax_mode="JAX",
+        )
+        jax_fn.trust_input = True
+
+    else:  # gradient_backend == "JAX"
+        import jax
+
+        loss_fn_tuple = function(graph_inputs, loss_graph, mode="JAX").vm.jit_fn
+
+        def loss_fn(*args):
+            return loss_fn_tuple(*args)[0]
+
+        match mode:
+            case "0forward":
+                jax_fn = jax.jit(loss_fn_tuple)
+            case "1backward":
+                jax_fn = jax.jit(
+                    jax.grad(loss_fn, argnums=tuple(range(len(graph_inputs))[2:]))
+                )
+            case "2both":
+                value_and_grad_fn = jax.value_and_grad(
+                    loss_fn, argnums=tuple(range(len(graph_inputs))[2:])
+                )
+
+                @jax.jit
+                def jax_fn(*args):
+                    loss, grads = value_and_grad_fn(*args)
+                    return loss, *grads
+
+            case _:
+                raise ValueError(f"Unknown mode: {mode}")
+
+    def block_until_ready(*inputs, jax_fn=jax_fn):
+        return [o.block_until_ready() for o in jax_fn(*inputs)]
+
+    block_until_ready(*test_input_vals)  # Warmup
+
+    benchmark.pedantic(block_until_ready, test_input_vals, rounds=200, iterations=1)
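Editor's note: the loss that the SEIR benchmark asks Scan to compute can also be written as a plain NumPy loop. The sketch below is added for exposition only and is not part of the commit; the names binom_logpmf and seir_loss are illustrative. It mirrors seir_one_step and the summed loss_graph above, using scipy.special.gammaln for the Binomial log-probability.

import numpy as np
from scipy.special import gammaln


def binom_logpmf(n, p, value):
    # Same expression as binom_log_prob in the test above
    binomln = gammaln(n + 1) - gammaln(value + 1) - gammaln(n - value + 1)
    return binomln + value * np.log(p) + (n - value) * np.log(1 - p)


def seir_loss(C, D, s0, e0, i0, beta, gamma, delta):
    # Plain-loop equivalent of the Scan: step the S/E/I states forward and
    # accumulate the Binomial log-probabilities of the observed C/D counts.
    s, e, i = float(s0), float(e0), float(i0)
    loss = 0.0
    for c, d in zip(C, D):
        b = s * beta  # deterministic stand-in for B[t] ~ Binom(S[t], beta)
        loss += binom_logpmf(e, gamma, c)  # C[t] ~ Binom(E[t], gamma)
        loss += binom_logpmf(i, delta, d)  # D[t] ~ Binom(I[t], delta)
        s, e, i = s - b, e + b - c, i + c - d
    return loss

Evaluating this loop on the test_input_vals defined in SEIR_model_logp should agree with the compiled loss_graph up to floating-point differences.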

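The new test relies on pytest-benchmark, which supplies the benchmark fixture and benchmark.pedantic used above. A possible way to run only these benchmarks is sketched below; the exact invocation is an assumption, not something stated in the commit.

import pytest

# Select only the scan benchmarks; --benchmark-only is a pytest-benchmark flag
# that skips tests which do not use the benchmark fixture.
pytest.main(["tests/link/jax/test_scan.py", "-k", "test_scan_benchmark", "--benchmark-only"])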