4 | 4 | import pytest |
5 | 5 |
6 | 6 | import pytensor.tensor as pt |
7 | | -from pytensor import function, shared |
| 7 | +from pytensor import function, ifelse, shared |
8 | 8 | from pytensor.compile import get_mode |
9 | 9 | from pytensor.configdefaults import config |
10 | 10 | from pytensor.scan import until |
11 | 11 | from pytensor.scan.basic import scan |
12 | 12 | from pytensor.scan.op import Scan |
13 | 13 | from pytensor.tensor import random |
14 | 14 | from pytensor.tensor.math import gammaln, log |
15 | | -from pytensor.tensor.type import dmatrix, dvector, lscalar, matrix, scalar, vector |
| 15 | +from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, vector |
16 | 16 | from tests.link.jax.test_basic import compare_jax_and_py |
17 | 17 |
18 | 18 |
@@ -189,96 +189,6 @@ def test_scan_while(): |
189 | 189 | compare_jax_and_py([], [xs], []) |
190 | 190 |
191 | 191 |
192 | | -def test_scan_SEIR(): |
193 | | - """Test a scan implementation of a SEIR model. |
194 | | -
195 | | - SEIR model definition: |
196 | | - S[t+1] = S[t] - B[t] |
197 | | - E[t+1] = E[t] +B[t] - C[t] |
198 | | - I[t+1] = I[t+1] + C[t] - D[t] |
199 | | -
200 | | - B[t] ~ Binom(S[t], beta) |
201 | | - C[t] ~ Binom(E[t], gamma) |
202 | | - D[t] ~ Binom(I[t], delta) |
203 | | - """ |
204 | | - |
205 | | - def binomln(n, k): |
206 | | - return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1) |
207 | | - |
208 | | - def binom_log_prob(n, p, value): |
209 | | - return binomln(n, value) + value * log(p) + (n - value) * log(1 - p) |
210 | | - |
211 | | - # sequences |
212 | | - at_C = vector("C_t", dtype="int32", shape=(8,)) |
213 | | - at_D = vector("D_t", dtype="int32", shape=(8,)) |
214 | | - # outputs_info (initial conditions) |
215 | | - st0 = lscalar("s_t0") |
216 | | - et0 = lscalar("e_t0") |
217 | | - it0 = lscalar("i_t0") |
218 | | - logp_c = scalar("logp_c") |
219 | | - logp_d = scalar("logp_d") |
220 | | - # non_sequences |
221 | | - beta = scalar("beta") |
222 | | - gamma = scalar("gamma") |
223 | | - delta = scalar("delta") |
224 | | - |
225 | | - # TODO: Use random streams when their JAX conversions are implemented. |
226 | | - # trng = pytensor.tensor.random.RandomStream(1234) |
227 | | - |
228 | | - def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta): |
229 | | - # bt0 = trng.binomial(n=st0, p=beta) |
230 | | - bt0 = st0 * beta |
231 | | - bt0 = bt0.astype(st0.dtype) |
232 | | - |
233 | | - logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype) |
234 | | - logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype) |
235 | | - |
236 | | - st1 = st0 - bt0 |
237 | | - et1 = et0 + bt0 - ct0 |
238 | | - it1 = it0 + ct0 - dt0 |
239 | | - return st1, et1, it1, logp_c1, logp_d1 |
240 | | - |
241 | | - (st, et, it, logp_c_all, logp_d_all), _ = scan( |
242 | | - fn=seir_one_step, |
243 | | - sequences=[at_C, at_D], |
244 | | - outputs_info=[st0, et0, it0, logp_c, logp_d], |
245 | | - non_sequences=[beta, gamma, delta], |
246 | | - ) |
247 | | - st.name = "S_t" |
248 | | - et.name = "E_t" |
249 | | - it.name = "I_t" |
250 | | - logp_c_all.name = "C_t_logp" |
251 | | - logp_d_all.name = "D_t_logp" |
252 | | - |
253 | | - s0, e0, i0 = 100, 50, 25 |
254 | | - logp_c0 = np.array(0.0, dtype=config.floatX) |
255 | | - logp_d0 = np.array(0.0, dtype=config.floatX) |
256 | | - beta_val, gamma_val, delta_val = ( |
257 | | - np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753] |
258 | | - ) |
259 | | - C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32) |
260 | | - D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32) |
261 | | - |
262 | | - test_input_vals = [ |
263 | | - C, |
264 | | - D, |
265 | | - s0, |
266 | | - e0, |
267 | | - i0, |
268 | | - logp_c0, |
269 | | - logp_d0, |
270 | | - beta_val, |
271 | | - gamma_val, |
272 | | - delta_val, |
273 | | - ] |
274 | | - compare_jax_and_py( |
275 | | - [at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta], |
276 | | - [st, et, it, logp_c_all, logp_d_all], |
277 | | - test_input_vals, |
278 | | - jax_mode="JAX", |
279 | | - ) |
280 | | - |
281 | | - |
282 | 192 | def test_scan_mitsot_with_nonseq(): |
283 | 193 | a_pt = scalar("a") |
284 | 194 |
@@ -420,3 +330,240 @@ def test_dynamic_sequence_length(): |
420 | 330 | assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1 |
421 | 331 | np.testing.assert_allclose(f([]), []) |
422 | 332 | np.testing.assert_allclose(f([1, 2, 3]), np.array([2, 3, 4])) |
| 333 | + |
| 334 | + |
| 335 | +def SEIR_model_logp(): |
| 336 | + """Setup a Scan implementation of a SEIR model. |
| 337 | +
| 338 | + SEIR model definition: |
| 339 | + S[t+1] = S[t] - B[t] |
| 340 | + E[t+1] = E[t] + B[t] - C[t] |
| 341 | + I[t+1] = I[t] + C[t] - D[t] |
| 342 | +
| 343 | + B[t] ~ Binom(S[t], beta) |
| 344 | + C[t] ~ Binom(E[t], gamma) |
| 345 | + D[t] ~ Binom(I[t], delta) |
| 346 | + """ |
| 347 | + |
| 348 | + def binomln(n, k): |
| 349 | + return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1) |
| 350 | + |
| 351 | + def binom_log_prob(n, p, value): |
| 352 | + return binomln(n, value) + value * log(p) + (n - value) * log(1 - p) |
| 353 | + |
| 354 | + # sequences |
| 355 | + C_t = vector("C_t", dtype="int32", shape=(1200,)) |
| 356 | + D_t = vector("D_t", dtype="int32", shape=(1200,)) |
| 357 | + # outputs_info (initial conditions) |
| 358 | + st0 = scalar("s_t0") |
| 359 | + et0 = scalar("e_t0") |
| 360 | + it0 = scalar("i_t0") |
| 361 | + # non_sequences |
| 362 | + beta = scalar("beta") |
| 363 | + gamma = scalar("gamma") |
| 364 | + delta = scalar("delta") |
| 365 | + |
| 366 | + def seir_one_step(ct0, dt0, st0, et0, it0, beta, gamma, delta): |
| 367 | + # bt0 = trng.binomial(n=st0, p=beta) |
| 368 | + bt0 = st0 * beta |
| 369 | + bt0 = bt0.astype(st0.dtype) |
| 370 | + |
| 371 | + logp_c1 = binom_log_prob(et0, gamma, ct0) |
| 372 | + logp_d1 = binom_log_prob(it0, delta, dt0) |
| 373 | + |
| 374 | + st1 = st0 - bt0 |
| 375 | + et1 = et0 + bt0 - ct0 |
| 376 | + it1 = it0 + ct0 - dt0 |
| 377 | + return st1, et1, it1, logp_c1, logp_d1 |
| 378 | + |
| 379 | + (st, et, it, logp_c_all, logp_d_all), _ = scan( |
| 380 | + fn=seir_one_step, |
| 381 | + sequences=[C_t, D_t], |
| 382 | + outputs_info=[st0, et0, it0, None, None], |
| 383 | + non_sequences=[beta, gamma, delta], |
| 384 | + ) |
| 385 | + st.name = "S_t" |
| 386 | + et.name = "E_t" |
| 387 | + it.name = "I_t" |
| 388 | + logp_c_all.name = "C_t_logp" |
| 389 | + logp_d_all.name = "D_t_logp" |
| 390 | + |
| 391 | + st0_val, et0_val, it0_val = np.array(100.0), np.array(50.0), np.array(25.0) |
| 392 | + beta_val, gamma_val, delta_val = ( |
| 393 | + np.array(0.277792), |
| 394 | + np.array(0.135330), |
| 395 | + np.array(0.108753), |
| 396 | + ) |
| 397 | + C_t_val = np.array([3, 5, 8, 13, 21, 26, 10, 3] * 150, dtype=np.int32) |
| 398 | + D_t_val = np.array([1, 2, 3, 7, 9, 11, 5, 1] * 150, dtype=np.int32) |
| 399 | + assert C_t_val.shape == D_t_val.shape == C_t.type.shape == D_t.type.shape |
| 400 | + |
| 401 | + test_input_vals = [ |
| 402 | + C_t_val, |
| 403 | + D_t_val, |
| 404 | + st0_val, |
| 405 | + et0_val, |
| 406 | + it0_val, |
| 407 | + beta_val, |
| 408 | + gamma_val, |
| 409 | + delta_val, |
| 410 | + ] |
| 411 | + |
| 412 | + loss_graph = logp_c_all.sum() + logp_d_all.sum() |
| 413 | + |
| 414 | + return dict( |
| 415 | + graph_inputs=[C_t, D_t, st0, et0, it0, beta, gamma, delta], |
| 416 | + differentiable_vars=[st0, et0, it0, beta, gamma, delta], |
| 417 | + test_input_vals=test_input_vals, |
| 418 | + loss_graph=loss_graph, |
| 419 | + ) |
| 420 | + |
| 421 | + |
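For reference, a minimal plain-Python sketch (illustrative only, not part of the diff) of the recursion that the Scan above expresses. It assumes the same deterministic bt = st * beta stand-in for the Binomial draw used in seir_one_step; the names binom_logp and seir_logp are hypothetical helpers:

import math

def binom_logp(n, p, value):
    # log C(n, value) + value*log(p) + (n - value)*log(1 - p), via lgamma
    binomln = math.lgamma(n + 1) - math.lgamma(value + 1) - math.lgamma(n - value + 1)
    return binomln + value * math.log(p) + (n - value) * math.log(1 - p)

def seir_logp(C, D, s, e, i, beta, gamma, delta):
    # Accumulate the per-step log-probability terms that the Scan returns
    logp = 0.0
    for ct, dt in zip(C, D):
        bt = s * beta                      # deterministic stand-in for B[t] ~ Binom(S[t], beta)
        logp += binom_logp(e, gamma, ct)   # C[t] ~ Binom(E[t], gamma)
        logp += binom_logp(i, delta, dt)   # D[t] ~ Binom(I[t], delta)
        s, e, i = s - bt, e + bt - ct, i + ct - dt
    return logp

The returned value corresponds to the quantity summed into loss_graph above, logp_c_all.sum() + logp_d_all.sum().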
| 422 | +def cyclical_reduction(): |
| 423 | + """Setup a Scan implementation of the cyclical reduction algorithm. |
| 424 | +
| 425 | + This solves the matrix equation A @ X @ X + B @ X + C = 0 for X |
| 426 | +
| 427 | + Adapted from https://github.com/jessegrabowski/gEconpy/blob/da495b22ac383cb6cb5dec15f305506aebef7302/gEconpy/solvers/cycle_reduction.py#L187 |
| 428 | + """ |
| 429 | + |
| 430 | + def stabilize(x, jitter=1e-16): |
| 431 | + return x + jitter * pt.eye(x.shape[0]) |
| 432 | + |
| 433 | + def step(A0, A1, A2, A1_hat, norm, step_num, tol): |
| 434 | + def cycle_step(A0, A1, A2, A1_hat, _norm, step_num): |
| 435 | + tmp = pt.dot( |
| 436 | + pt.vertical_stack(A0, A2), |
| 437 | + pt.linalg.solve( |
| 438 | + stabilize(A1), |
| 439 | + pt.horizontal_stack(A0, A2), |
| 440 | + assume_a="gen", |
| 441 | + check_finite=False, |
| 442 | + ), |
| 443 | + ) |
| 444 | + |
| 445 | + n = A0.shape[0] |
| 446 | + idx_0 = pt.arange(n) |
| 447 | + idx_1 = idx_0 + n |
| 448 | + A1 = A1 - tmp[idx_0, :][:, idx_1] - tmp[idx_1, :][:, idx_0] |
| 449 | + A0 = -tmp[idx_0, :][:, idx_0] |
| 450 | + A2 = -tmp[idx_1, :][:, idx_1] |
| 451 | + A1_hat = A1_hat - tmp[idx_1, :][:, idx_0] |
| 452 | + |
| 453 | + A0_L1_norm = pt.linalg.norm(A0, ord=1) |
| 454 | + |
| 455 | + return A0, A1, A2, A1_hat, A0_L1_norm, step_num + 1 |
| 456 | + |
| 457 | + return ifelse( |
| 458 | + norm < tol, |
| 459 | + (A0, A1, A2, A1_hat, norm, step_num), |
| 460 | + cycle_step(A0, A1, A2, A1_hat, norm, step_num), |
| 461 | + ) |
| 462 | + |
| 463 | + A = pt.matrix("A", shape=(20, 20)) |
| 464 | + B = pt.matrix("B", shape=(20, 20)) |
| 465 | + C = pt.matrix("C", shape=(20, 20)) |
| 466 | + |
| 467 | + norm = np.array(1e9, dtype="float64") |
| 468 | + step_num = pt.zeros((), dtype="int32") |
| 469 | + max_iter = 100 |
| 470 | + tol = 1e-7 |
| 471 | + |
| 472 | + (*_, A1_hat, norm, _n_steps), _ = scan( |
| 473 | + step, |
| 474 | + outputs_info=[A, B, C, B, norm, step_num], |
| 475 | + non_sequences=[tol], |
| 476 | + n_steps=max_iter, |
| 477 | + ) |
| 478 | + A1_hat = A1_hat[-1] |
| 479 | + |
| 480 | + T = -pt.linalg.solve(stabilize(A1_hat), A, assume_a="gen", check_finite=False) |
| 481 | + |
| 482 | + rng = np.random.default_rng(sum(map(ord, "cycle_reduction"))) |
| 483 | + n = A.type.shape[0] |
| 484 | + A_test = rng.standard_normal(size=(n, n)) |
| 485 | + C_test = rng.standard_normal(size=(n, n)) |
| 486 | + # B must be invertible, so we make it symmetric positive-definite |
| 487 | + B_rand = rng.standard_normal(size=(n, n)) |
| 488 | + B_test = B_rand @ B_rand.T + np.eye(n) * 1e-3 |
| 489 | + |
| 490 | + return dict( |
| 491 | + graph_inputs=[A, B, C], |
| 492 | + differentiable_vars=[A, B, C], |
| 493 | + test_input_vals=[A_test, B_test, C_test], |
| 494 | + loss_graph=pt.sum(T), |
| 495 | + ) |
| 496 | + |
| 497 | + |
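The step function above runs a fixed number of Scan iterations but uses ifelse to freeze the state once the norm drops below tol, so the remaining iterations become cheap pass-throughs. A standalone sketch of that pattern (illustrative only, not part of the diff; it uses a scalar Newton iteration for sqrt(2) and imports ifelse from pytensor.ifelse):

import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.ifelse import ifelse
from pytensor.scan.basic import scan

def newton_step(x, tol):
    x_new = 0.5 * (x + 2.0 / x)
    # Once converged, keep returning the old value on every remaining iteration
    return ifelse(abs(x_new - x) < tol, x, x_new)

x0 = pt.scalar("x0")
xs, _ = scan(newton_step, outputs_info=[x0], non_sequences=[1e-12], n_steps=50)
f = function([x0], xs[-1])
np.testing.assert_allclose(f(1.0), np.sqrt(2.0))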
| 498 | +@pytest.mark.parametrize("gradient_backend", ["PYTENSOR", "JAX"]) |
| 499 | +@pytest.mark.parametrize("mode", ("0forward", "1backward", "2both")) |
| 500 | +@pytest.mark.parametrize("model", [cyclical_reduction, SEIR_model_logp]) |
| 501 | +def test_scan_benchmark(model, mode, gradient_backend, benchmark): |
| 502 | + if gradient_backend == "PYTENSOR" and mode in ("1backward", "2both"): |
| 503 | + pytest.skip("PYTENSOR backend does not support backward mode yet") |
| 504 | + |
| 505 | + model_dict = model() |
| 506 | + graph_inputs = model_dict["graph_inputs"] |
| 507 | + differentiable_vars = model_dict["differentiable_vars"] |
| 508 | + loss_graph = model_dict["loss_graph"] |
| 509 | + test_input_vals = model_dict["test_input_vals"] |
| 510 | + |
| 511 | + if gradient_backend == "PYTENSOR": |
| 512 | + backward_loss = pt.grad( |
| 513 | + loss_graph, |
| 514 | + wrt=differentiable_vars, |
| 515 | + ) |
| 516 | + |
| 517 | + match mode: |
| 518 | + # TODO: Restore original test separately |
| 519 | + case "0forward": |
| 520 | + graph_outputs = [loss_graph] |
| 521 | + case "1backward": |
| 522 | + graph_outputs = backward_loss |
| 523 | + case "2both": |
| 524 | + graph_outputs = [loss_graph, *backward_loss] |
| 525 | + case _: |
| 526 | + raise ValueError(f"Unknown mode: {mode}") |
| 527 | + |
| 528 | + jax_fn, _ = compare_jax_and_py( |
| 529 | + graph_inputs, |
| 530 | + graph_outputs, |
| 531 | + test_input_vals, |
| 532 | + jax_mode="JAX", |
| 533 | + ) |
| 534 | + jax_fn.trust_input = True |
| 535 | + |
| 536 | + else: # gradient_backend == "JAX" |
| 537 | + import jax |
| 538 | + |
| 539 | + loss_fn_tuple = function(graph_inputs, loss_graph, mode="JAX").vm.jit_fn |
| 540 | + |
| 541 | + def loss_fn(*args): |
| 542 | + return loss_fn_tuple(*args)[0] |
| 543 | + |
| 544 | + match mode: |
| 545 | + case "0forward": |
| 546 | + jax_fn = jax.jit(loss_fn_tuple) |
| 547 | + case "1backward": |
| 548 | + jax_fn = jax.jit( |
| 549 | + jax.grad(loss_fn, argnums=tuple(range(len(graph_inputs))[2:])) |
| 550 | + ) |
| 551 | + case "2both": |
| 552 | + value_and_grad_fn = jax.value_and_grad( |
| 553 | + loss_fn, argnums=tuple(range(len(graph_inputs))[2:]) |
| 554 | + ) |
| 555 | + |
| 556 | + @jax.jit |
| 557 | + def jax_fn(*args): |
| 558 | + loss, grads = value_and_grad_fn(*args) |
| 559 | + return loss, *grads |
| 560 | + |
| 561 | + case _: |
| 562 | + raise ValueError(f"Unknown mode: {mode}") |
| 563 | + |
| 564 | + def block_until_ready(*inputs, jax_fn=jax_fn): |
| 565 | + return [o.block_until_ready() for o in jax_fn(*inputs)] |
| 566 | + |
| 567 | + block_until_ready(*test_input_vals) # Warmup |
| 568 | + |
| 569 | + benchmark.pedantic(block_until_ready, test_input_vals, rounds=200, iterations=1) |
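In the JAX-gradient branch, argnums=tuple(range(len(graph_inputs))[2:]) restricts differentiation to everything after the first two positional inputs; for the SEIR model those first two are the integer sequences C_t and D_t. A small sketch of that argnums pattern (illustrative only, with made-up names):

import jax
import jax.numpy as jnp

def loss(c, d, beta):
    # c and d play the role of non-differentiable data sequences
    return jnp.sum(c * beta) + jnp.sum(d)

# Differentiate only with respect to the third positional argument
value_and_grad_fn = jax.value_and_grad(loss, argnums=(2,))
val, (g_beta,) = value_and_grad_fn(jnp.arange(3.0), jnp.ones(3), 0.5)
# g_beta equals jnp.sum(jnp.arange(3.0)) == 3.0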