|
8 | 8 | from pytensor.graph import Apply, Op, Type |
9 | 9 | from pytensor.tensor import alloc, scalar, scalar_from_tensor, tensor_from_scalar |
10 | 10 | from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar |
| 11 | +from pytensor.tensor.special import betaln |
11 | 12 | from tests import unittest_tools as utt |
12 | 13 |
|
13 | 14 |
|
@@ -523,3 +524,44 @@ def L_op(self, inputs, outputs, output_gradients): |
523 | 524 | _grad_wrt_num_theta = pt.grad(x_star, num_theta, disconnected_inputs="raise") |
524 | 525 | # np.testing.assert_allclose(grad_wrt_num_theta.eval({x: np.pi, num_theta: np.e, str_theta: ":)"}), -1) |
525 | 526 | # np.testing.assert_allclose(grad_wrt_num_theta.eval({x: np.pi, num_theta: np.e, str_theta: ":("}), 1) |
| 527 | + |
| 528 | + |
| 529 | +def test_vectorize_root_gradients(): |
| 530 | + """Regression test for https://github.com/pymc-devs/pytensor/issues/1586""" |
| 531 | + n, a, b = pt.scalars("n a b".split()) |
| 532 | + w_min, w_max = pt.scalars("w_min w_max".split()) |
| 533 | + |
| 534 | + w_support = pt.linspace(w_min, w_max, n + 1) |
| 535 | + |
| 536 | + k = pt.floor(w_support) |
| 537 | + ln_n_choose_k = -pt.log(n + 1) - betaln(n - k + 1, k + 1) |
| 538 | + q_probs = pt.exp(ln_n_choose_k + betaln(k + a, n - k + b) - betaln(a, b)) |
| 539 | + |
| 540 | + c = pt.dscalar("c") # Unemployment benefit |
| 541 | + β = pt.dscalar("β") # Discount rate |
| 542 | + |
| 543 | + # initial value function guess |
| 544 | + v0 = pt.dvector("v0") |
| 545 | + |
| 546 | + # Fixed-point operator |
| 547 | + T = pt.maximum(w_support / (1 - β), c + β * pt.dot(v0, q_probs)) |
| 548 | + |
| 549 | + v_star, _ = pt.optimize.root(equations=T - v0, variables=v0, method="hybr") |
| 550 | + |
| 551 | + w_bar = (1 - β) * (c + β * pt.dot(v_star, q_probs)) |
| 552 | + |
| 553 | + # We want to study the impact of change in unemployment and patience on the reserve wage |
| 554 | + w_grads = pt.grad(w_bar, [c, β]) |
| 555 | + |
| 556 | + c_grid = pt.dmatrix("c_grid") |
| 557 | + β_grid = pt.dmatrix("β_grid") |
| 558 | + |
| 559 | + w_bar_grid, *w_grad_grid = pytensor.graph.vectorize_graph( |
| 560 | + [w_bar, *w_grads], {β: β_grid, c: c_grid} |
| 561 | + ) |
| 562 | + |
| 563 | + _ = pytensor.function( |
| 564 | + [v0, c_grid, β_grid, n, a, b, w_min, w_max], |
| 565 | + [w_bar_grid, *w_grad_grid], |
| 566 | + on_unused_input="ignore", |
| 567 | + ) |
0 commit comments