Skip to content

Commit d080317

Browse files
Add regression test for #1586
1 parent 94e6f8f commit d080317

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

tests/tensor/test_optimize.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pytensor.graph import Apply, Op, Type
99
from pytensor.tensor import alloc, scalar, scalar_from_tensor, tensor_from_scalar
1010
from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar
11+
from pytensor.tensor.special import betaln
1112
from tests import unittest_tools as utt
1213

1314

@@ -523,3 +524,44 @@ def L_op(self, inputs, outputs, output_gradients):
523524
_grad_wrt_num_theta = pt.grad(x_star, num_theta, disconnected_inputs="raise")
524525
# np.testing.assert_allclose(grad_wrt_num_theta.eval({x: np.pi, num_theta: np.e, str_theta: ":)"}), -1)
525526
# np.testing.assert_allclose(grad_wrt_num_theta.eval({x: np.pi, num_theta: np.e, str_theta: ":("}), 1)
527+
528+
529+
def test_vectorize_root_gradients():
530+
"""Regression test for https://github.com/pymc-devs/pytensor/issues/1586"""
531+
n, a, b = pt.scalars("n a b".split())
532+
w_min, w_max = pt.scalars("w_min w_max".split())
533+
534+
w_support = pt.linspace(w_min, w_max, n + 1)
535+
536+
k = pt.floor(w_support)
537+
ln_n_choose_k = -pt.log(n + 1) - betaln(n - k + 1, k + 1)
538+
q_probs = pt.exp(ln_n_choose_k + betaln(k + a, n - k + b) - betaln(a, b))
539+
540+
c = pt.dscalar("c") # Unemployment benefit
541+
β = pt.dscalar("β") # Discount rate
542+
543+
# initial value function guess
544+
v0 = pt.dvector("v0")
545+
546+
# Fixed-point operator
547+
T = pt.maximum(w_support / (1 - β), c + β * pt.dot(v0, q_probs))
548+
549+
v_star, _ = pt.optimize.root(equations=T - v0, variables=v0, method="hybr")
550+
551+
w_bar = (1 - β) * (c + β * pt.dot(v_star, q_probs))
552+
553+
# We want to study the impact of change in unemployment and patience on the reserve wage
554+
w_grads = pt.grad(w_bar, [c, β])
555+
556+
c_grid = pt.dmatrix("c_grid")
557+
β_grid = pt.dmatrix("β_grid")
558+
559+
w_bar_grid, *w_grad_grid = pytensor.graph.vectorize_graph(
560+
[w_bar, *w_grads], {β: β_grid, c: c_grid}
561+
)
562+
563+
_ = pytensor.function(
564+
[v0, c_grid, β_grid, n, a, b, w_min, w_max],
565+
[w_bar_grid, *w_grad_grid],
566+
on_unused_input="ignore",
567+
)

0 commit comments

Comments
 (0)