@@ -114,14 +114,14 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
 
 
 def _compile_grad_and_hess_to_jax(
-    f_loss: Function, use_hess: bool, use_hessp: bool
+    f_fused: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
     Compile loss function gradients using JAX.
 
     Parameters
     ----------
-    f_loss : Function
+    f_fused : Function
         The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss,
         compiled with mode="JAX".
     use_hess: bool
@@ -131,43 +131,40 @@ def _compile_grad_and_hess_to_jax(
 
     Returns
     -------
-    f_loss_and_grad: Callable
-        The compiled loss function and gradient function.
-    f_hess: Callable | None
-        The compiled hessian function, or None if use_hess is False.
+    f_fused: Callable
+        The compiled loss function and gradient function, which may also compute the hessian if requested.
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
     import jax
 
-    f_hess = None
     f_hessp = None
 
-    orig_loss_fn = f_loss.vm.jit_fn
+    orig_loss_fn = f_fused.vm.jit_fn
 
-    @jax.jit
-    def loss_fn_jax_grad(x):
-        return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
+    if use_hess:
+
+        @jax.jit
+        def loss_fn_fused(x):
+            loss_and_grad = jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
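+            # jax.hessian materializes the full dense Hessian, which can be costly for models with many parameters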
+            hess = jax.hessian(lambda x: orig_loss_fn(x)[0])(x)
+            return *loss_and_grad, hess
+
+    else:
 
-    f_loss_and_grad = loss_fn_jax_grad
+        @jax.jit
+        def loss_fn_fused(x):
+            return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
 
     if use_hessp:
 
         def f_hessp_jax(x, p):
-            y, u = jax.jvp(lambda x: f_loss_and_grad(x)[1], (x,), (p,))
+            y, u = jax.jvp(lambda x: loss_fn_fused(x)[1], (x,), (p,))
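+            # jax.jvp of the gradient returns (grad(x), H @ p); only the Hessian-vector product u is needed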
             return jax.numpy.stack(u)
 
         f_hessp = jax.jit(f_hessp_jax)
 
-    if use_hess:
-        _f_hess_jax = jax.jacfwd(lambda x: f_loss_and_grad(x)[1])
-
-        def f_hess_jax(x):
-            return jax.numpy.stack(_f_hess_jax(x))
-
-        f_hess = jax.jit(f_hess_jax)
-
-    return f_loss_and_grad, f_hess, f_hessp
+    return loss_fn_fused, f_hessp
 
 
 def _compile_functions_for_scipy_optimize(
@@ -199,33 +196,47 @@ def _compile_functions_for_scipy_optimize(
 
     Returns
     -------
-    f_loss: Function
-
-    f_hess: Function | None
+    f_fused: Function
+        The compiled loss function, which may also include gradients and hessian if requested.
     f_hessp: Function | None
+        The compiled hessian-vector product function, or None if compute_hessp is False.
     """
+    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+
     loss = pm.pytensorf.rewrite_pregrad(loss)
-    f_hess = None
     f_hessp = None
 
-    if compute_grad:
-        grads = pytensor.gradient.grad(loss, inputs)
-        grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
-    else:
+    # In the simplest case, we only compile the loss function. Return it as a list to keep the return type consistent
+    # with the case where we also compute gradients, hessians, or hessian-vector products.
+    if not (compute_grad or compute_hess or compute_hessp):
         f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]
 
-    if compute_hess:
-        hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile(inputs, hess, **compile_kwargs)
+    # Otherwise there are three cases. If the user only wants the loss function and gradients, we compile a single
+    # fused function and return it. If the user also wants the hessian, the fused function will return the loss,
+    # gradients, and hessian. If the user wants gradients and hessp, we return a fused function that returns the loss
+    # and gradients, plus a separate function for the hessian-vector product.
 
     if compute_hessp:
+        # Handle this first, since it can be compiled alone.
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
         f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)
 
-    return [f_loss_and_grad, f_hess, f_hessp]
+    outputs = [loss]
+
+    if compute_grad:
+        grads = pytensor.gradient.grad(loss, inputs)
+        grad = pt.concatenate([grad.ravel() for grad in grads])
+        outputs.append(grad)
+
+    if compute_hess:
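+        # The hessian is the jacobian of the flattened gradient with respect to the inputs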
+        hess = pytensor.gradient.jacobian(grad, inputs)[0]
+        outputs.append(hess)
+
+    f_fused = pm.compile(inputs, outputs, **compile_kwargs)
+
+    return [f_fused, f_hessp]
 
 
 def scipy_optimize_funcs_from_loss(
@@ -262,10 +273,8 @@ def scipy_optimize_funcs_from_loss(
 
     Returns
     -------
-    f_loss: Callable
-        The compiled loss function.
-    f_hess: Callable | None
-        The compiled hessian function, or None if use_hess is False.
+    f_fused: Callable
+        The compiled loss function, which may also include gradients and hessian if requested.
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
@@ -322,16 +331,15 @@ def scipy_optimize_funcs_from_loss(
         compile_kwargs=compile_kwargs,
     )
 
-    # f_loss here is f_loss_and_grad if compute_grad = True. The name is unchanged to simplify the return values
-    f_loss = funcs.pop(0)
-    f_hess = funcs.pop(0) if compute_grad else None
-    f_hessp = funcs.pop(0) if compute_grad else None
+    # Depending on the requested functions, f_fused will either be the loss function, the loss function with gradients,
+    # or the loss function with gradients and hessian.
+    f_fused = funcs.pop(0)
+    f_hessp = funcs.pop(0) if compute_hessp else None
 
     if use_jax_gradients:
-        # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)
+        f_fused, f_hessp = _compile_grad_and_hess_to_jax(f_fused, use_hess, use_hessp)
 
-    return f_loss, f_hess, f_hessp
+    return f_fused, f_hessp
 
 
 def find_MAP(
@@ -434,7 +442,7 @@ def find_MAP(
         method, use_grad, use_hess, use_hessp
     )
 
-    f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss(
+    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
         loss=-frozen_model.logp(jacobian=False),
         inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
         initial_point_dict=start_dict,
@@ -445,23 +453,21 @@ def find_MAP(
         compile_kwargs=compile_kwargs,
     )
 
-    args = optimizer_kwargs.pop("args", None)
+    args = optimizer_kwargs.pop("args", ())
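+    # scipy.optimize.minimize defaults to an empty tuple for extra args, so mirror that here instead of None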
 
     # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
     # if so. That is why the jac argument is not passed here in either branch.
 
     if do_basinhopping:
         if "args" not in minimizer_kwargs:
             minimizer_kwargs["args"] = args
-        if "hess" not in minimizer_kwargs:
-            minimizer_kwargs["hess"] = f_hess
         if "hessp" not in minimizer_kwargs:
             minimizer_kwargs["hessp"] = f_hessp
         if "method" not in minimizer_kwargs:
             minimizer_kwargs["method"] = method
 
         optimizer_result = basinhopping(
-            func=f_logp,
+            func=f_fused,
             x0=cast(np.ndarray[float], initial_params.data),
             progressbar=progressbar,
             minimizer_kwargs=minimizer_kwargs,
@@ -470,10 +476,9 @@ def find_MAP(
 
     else:
         optimizer_result = minimize(
-            f=f_logp,
+            f=f_fused,
             x0=cast(np.ndarray[float], initial_params.data),
             args=args,
-            hess=f_hess,
             hessp=f_hessp,
             progressbar=progressbar,
             method=method,
@@ -486,6 +491,33 @@ def find_MAP(
         DictToArrayBijection.rmap(raveled_optimized)
     )
 
+    # Downstream computation will probably want the covariance matrix at the optimized point, so we compute it here,
+    # while we still have access to the compiled function.
+    x_star = optimizer_result.x
+    n_vars = len(x_star)
+
+    if method == "BFGS":
+        # If we used BFGS, the optimizer result will contain the inverse Hessian -- we can just use that rather than
+        # recomputing it
+        H_inv = getattr(optimizer_result, "hess_inv", None)
+    elif method == "L-BFGS-B":
+        # Here we will have a LinearOperator representing the inverse Hessian-Vector product.
+        f_hessp_inv = optimizer_result.hess_inv
+        basis = np.eye(n_vars)
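+        # Apply the inverse-Hessian LinearOperator to each standard basis vector to materialize it as a dense matrix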
+        H_inv = np.stack([f_hessp_inv(basis[:, i]) for i in range(n_vars)], axis=-1)
+
+    elif f_hessp is not None:
+        # If hessp was used, the results object will not contain the inverse Hessian, so we compute the Hessian
+        # from the hessp function by applying it to each Euclidean basis vector.
+        basis = np.eye(n_vars)
+        H = np.stack([f_hessp(optimizer_result.x, basis[:, i]) for i in range(n_vars)], axis=-1)
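+        # H is built column-by-column from Hessian-vector products, so it may not be exactly symmetric; project it
+        # to the nearest positive semi-definite matrix before inverting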
+        H_inv = np.linalg.inv(get_nearest_psd(H))
+
+    elif use_hess:
+        # If we compiled a hessian function, just use it
+        _, _, H = f_fused(x_star)
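+        # When use_hess is True, f_fused returns (loss, grad, hess); only the hessian is needed here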
+        H_inv = np.linalg.inv(get_nearest_psd(H))
+
     optimized_point = {
         var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
     }