Skip to content

Commit be39ef6

Browse files
Add regression test for #1550
1 parent 4185167 commit be39ef6

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

tests/tensor/test_optimize.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,65 @@ def f(b0, b1, b2, b3, x0, x1, x2, x3, y):
204204
utt.verify_grad(f, inputs, eps=1e-6)
205205

206206

207+
def test_minimize_mvn_logp_mu_and_cov():
    """Regression test for https://github.com/pymc-devs/pytensor/issues/1550

    Minimizes the negative log-probability of a multivariate normal with
    respect to ``mu``, then checks that gradients can be taken through the
    resulting minimize graph w.r.t. ``mu``, ``cov``, and ``data``.
    """
    d = 3

    def objective(mu, cov, data):
        # Negative MVN log-density. The quadratic form
        # (mu - data)^T cov^-1 (mu - data) is computed via a Cholesky
        # triangular solve as ||L^-1 (mu - data)||^2.
        L = pt.linalg.cholesky(cov)
        _, logdet = pt.linalg.slogdet(cov)

        v = mu - data
        y = pt.linalg.solve_triangular(L, v, lower=True)
        quad_term = (y**2).sum()

        return 0.5 * (d * pt.log(2 * np.pi) + logdet + quad_term)

    data = pt.vector("data", shape=(d,))
    mu = pt.vector("mu", shape=(d,))
    # Use pt.matrix (dtype follows floatX) rather than pt.dmatrix (always
    # float64), so the symbolic input matches the floatX-cast values below
    # and the float32 branch of verify_grad's eps is actually exercised.
    cov = pt.matrix("cov", shape=(d, d))

    neg_logp = objective(mu, cov, data)
    mu_star, success = minimize(
        objective=neg_logp,
        x=mu,
        method="BFGS",
        jac=True,
        hess=False,
        use_vectorized_jac=True,
    )

    # This replace + gradient was the original source of the error in #1550;
    # check that it no longer raises.
    y_star = pytensor.graph_replace(neg_logp, {mu: mu_star})
    _ = pt.grad(y_star, [mu, cov, data])

    # Seed the generator so the regression test is reproducible.
    rng = np.random.default_rng(1550)
    data_val = rng.normal(size=(d,)).astype(floatX)

    # Build a (almost surely) positive-definite covariance from a random factor.
    L = rng.normal(size=(d, d)).astype(floatX)
    cov_val = L @ L.T
    mu0_val = rng.normal(size=(d,)).astype(floatX)

    fn = pytensor.function([mu, cov, data], [mu_star, success])
    _, success_flag = fn(mu0_val, cov_val, data_val)
    assert success_flag

    def min_fn(mu, cov, data):
        # Rebuild the minimization in a fresh graph so verify_grad can
        # perturb all three inputs and differentiate through mu_star.
        mu_star, _ = minimize(
            objective=objective(mu, cov, data),
            x=mu,
            method="BFGS",
            jac=True,
            hess=False,
            use_vectorized_jac=True,
        )
        return mu_star.sum()

    utt.verify_grad(
        min_fn, [mu0_val, cov_val, data_val], eps=1e-3 if floatX == "float32" else 1e-6
    )
264+
265+
207266
@pytest.mark.parametrize(
208267
"method, jac, hess",
209268
[("secant", False, False), ("newton", True, False), ("halley", True, True)],

0 commit comments

Comments (0)