@@ -204,6 +204,65 @@ def f(b0, b1, b2, b3, x0, x1, x2, x3, y):
204204 utt .verify_grad (f , inputs , eps = 1e-6 )
205205
206206
207+ def test_minimize_mvn_logp_mu_and_cov ():
208+ """Regression test for https://github.com/pymc-devs/pytensor/issues/1550"""
209+ d = 3
210+
211+ def objective (mu , cov , data ):
212+ L = pt .linalg .cholesky (cov )
213+ _ , logdet = pt .linalg .slogdet (cov )
214+
215+ v = mu - data
216+ y = pt .linalg .solve_triangular (L , v , lower = True )
217+ quad_term = (y ** 2 ).sum ()
218+
219+ return 0.5 * (d * pt .log (2 * np .pi ) + logdet + quad_term )
220+
221+ data = pt .vector ("data" , shape = (d ,))
222+ mu = pt .vector ("mu" , shape = (d ,))
223+ cov = pt .dmatrix ("cov" , shape = (d , d ))
224+
225+ neg_logp = objective (mu , cov , data )
226+ mu_star , success = minimize (
227+ objective = neg_logp ,
228+ x = mu ,
229+ method = "BFGS" ,
230+ jac = True ,
231+ hess = False ,
232+ use_vectorized_jac = True ,
233+ )
234+
235+ # This replace + gradient was the original source of the error in #1550, check that no longer raises
236+ y_star = pytensor .graph_replace (neg_logp , {mu : mu_star })
237+ _ = pt .grad (y_star , [mu , cov , data ])
238+
239+ rng = np .random .default_rng ()
240+ data_val = rng .normal (size = (d ,)).astype (floatX )
241+
242+ L = rng .normal (size = (d , d )).astype (floatX )
243+ cov_val = L @ L .T
244+ mu0_val = rng .normal (size = (d ,)).astype (floatX )
245+
246+ fn = pytensor .function ([mu , cov , data ], [mu_star , success ])
247+ _ , success_flag = fn (mu0_val , cov_val , data_val )
248+ assert success_flag
249+
250+ def min_fn (mu , cov , data ):
251+ mu_star , _ = minimize (
252+ objective = objective (mu , cov , data ),
253+ x = mu ,
254+ method = "BFGS" ,
255+ jac = True ,
256+ hess = False ,
257+ use_vectorized_jac = True ,
258+ )
259+ return mu_star .sum ()
260+
261+ utt .verify_grad (
262+ min_fn , [mu0_val , cov_val , data_val ], eps = 1e-3 if floatX == "float32" else 1e-6
263+ )
264+
265+
207266@pytest .mark .parametrize (
208267 "method, jac, hess" ,
209268 [("secant" , False , False ), ("newton" , True , False ), ("halley" , True , True )],
0 commit comments