def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
    method_info = MINIMIZE_MODE_KWARGS[method].copy()

+    if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True. scipy.optimize.minimize never uses both at the '
+            'same time. Setting "use_hess" to False.'
+        )
+        use_hess = False
+
    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]

-    if use_hess and use_hessp:
-        use_hess = False
-
    return use_grad, use_hess, use_hessp

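For illustration, a minimal sketch of the new ordering, assuming MINIMIZE_MODE_KWARGS has a "trust-ncg" entry with uses_hessp enabled: requesting both Hessian modes explicitly now warns and drops use_hess before the per-method defaults are applied.

    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
        "trust-ncg", use_grad=True, use_hess=True, use_hessp=True
    )
    # The warning above is logged; the result is use_grad=True, use_hess=False, use_hessp=True.
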
@@ -97,7 +101,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
    return f_untransform(posterior_draws)


-def _compile_jax_gradients(
+def _compile_grad_and_hess_to_jax(
    f_loss: Function, use_hess: bool, use_hessp: bool
) -> tuple[Callable | None, Callable | None]:
    """
@@ -152,7 +156,7 @@ def f_hess_jax(x):
    return f_loss_and_grad, f_hess, f_hessp

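For orientation only, a hedged, generic sketch of the value-and-grad / Hessian / Hessian-vector-product wrapping that a helper like _compile_grad_and_hess_to_jax performs; the toy loss below stands in for the compiled, JAX-traceable model loss and is not the commit's implementation.

    import jax
    import jax.numpy as jnp

    def loss(x):  # stand-in for the compiled, JAX-traceable loss
        return jnp.sum(x**2)

    f_loss_and_grad = jax.jit(jax.value_and_grad(loss))  # returns (value, gradient)
    f_hess = jax.jit(jax.hessian(loss))                  # dense Hessian
    f_hessp = jax.jit(lambda x, p: jax.jvp(jax.grad(loss), (x,), (p,))[1])  # Hessian-vector product
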
-def _compile_functions(
+def _compile_functions_for_scipy_optimize(
    loss: TensorVariable,
    inputs: list[TensorVariable],
    compute_grad: bool,
@@ -177,7 +181,7 @@ def _compile_functions(
    compute_hessp: bool
        Whether to compile a function that computes the Hessian-vector product of the loss function.
    compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

    Returns
    -------
@@ -193,19 +197,19 @@ def _compile_functions(
    if compute_grad:
        grads = pytensor.gradient.grad(loss, inputs)
        grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
    else:
-        f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
        return [f_loss]

    if compute_hess:
        hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+        f_hess = pm.compile(inputs, hess, **compile_kwargs)

    if compute_hessp:
        p = pt.tensor("p", shape=inputs[0].type.shape)
        hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)

    return [f_loss_and_grad, f_hess, f_hessp]

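As a self-contained illustration of the compile pattern in this hunk, here is a toy quadratic loss pushed through the same calls; the variable names are illustrative, not from the commit.

    import pymc as pm
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    loss = (x**2).sum()

    grads = pytensor.gradient.grad(loss, [x])
    grad = pt.concatenate([g.ravel() for g in grads])
    f_loss_and_grad = pm.compile([x], [loss, grad])

    value, gradient = f_loss_and_grad([1.0, 2.0])  # value = 5.0, gradient = [2.0, 4.0]
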
@@ -240,7 +244,7 @@ def scipy_optimize_funcs_from_loss(
    gradient_backend: str, default "pytensor"
        Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
    compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

    Returns
    -------
@@ -285,7 +289,7 @@ def scipy_optimize_funcs_from_loss(
    compute_hess = use_hess and not use_jax_gradients
    compute_hessp = use_hessp and not use_jax_gradients

-    funcs = _compile_functions(
+    funcs = _compile_functions_for_scipy_optimize(
        loss=loss,
        inputs=[flat_input],
        compute_grad=compute_grad,
@@ -301,7 +305,7 @@ def scipy_optimize_funcs_from_loss(

    if use_jax_gradients:
        # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)

    return f_loss, f_hess, f_hessp

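For context, a hedged sketch of how the three returned callables are typically handed to scipy.optimize.minimize; x0 and the method here are illustrative placeholders, and jac=True reflects that f_loss returns a (value, gradient) pair when a gradient was compiled.

    from scipy.optimize import minimize

    result = minimize(f_loss, x0=x0, jac=True, hessp=f_hessp, method="trust-ncg")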