@@ -419,28 +419,42 @@ def sample_laplace_posterior(

def find_mode(
    inputs: list[TensorVariable],
+    params: dict,  # TODO Would be nice to automatically map this to inputs somehow: {k.name: ... for k in inputs}
+    x0: np.ndarray | None = None,  # TODO not a TensorVariable; numpy.typing.ArrayLike may be the more general annotation for numeric array-likes
    x: TensorVariable | None = None,
    model: pm.Model | None = None,
    method: minimize_method = "BFGS",
+    jac: bool = True,
+    hess: bool = False,
    optimizer_kwargs: dict | None = None,
-):  # Unsure of the return type, I'd assume it would be a list of pt tensors of some kind
+):  # Returns the mode and the Hessian of the NLL at the mode, with the same array type as x0
    model = pm.modelcontext(model)
    if x is None:
-        raise NotImplementedError("Currently assumes user specifies the Gaussian latent field x")
+        warnings.warn(
+            "Latent Gaussian field x unspecified; assuming it is the first entry in inputs. Specify which input to obtain the mode over using the argument x.",
+            UserWarning,
+        )
+        x = inputs[0]
+
+    if x0 is None:
+        # TODO Default to a random numpy array of the same shape as x (the shape is unknown until inputs are supplied)
+        raise NotImplementedError("An initial value x0 for the latent field x must be provided")

    # Minimise negative log likelihood
-    # TODO: NLL already computes jac by default. Need to check how to access
-    loss_x = -model.logp()
-    # TODO Need to think about how to get inputs (i.e. a collection of all the input variables) to go along with the specific
-    # variable x, i.e. f(x, *args). I assume I can't assume that the inputs arg will be ordered to have x first. May need to sort it somehow
-    loss = pytensor.function(inputs, loss_x)
-
-    grad = pytensor.gradient.grad(loss, inputs)
-    hess = pytensor.gradient.jacobian(grad, inputs)[0]
-
-    # Need to play around with scipy.optimize.minimize with pytensor a little so I can figure out if it's "x" or "inputs" that goes here
-    res = minimize(loss, x, method, grad, hess, optimizer_kwargs)
-    return res.x, res.hess
+    nll = -model.logp()
+    soln, _ = minimize(
+        objective=nll, x=x, method=method, jac=jac, hess=hess, optimizer_kwargs=optimizer_kwargs
+    )
+
+    get_mode = pytensor.function(inputs, soln)  # compiled over all inputs: x0 fills x's slot, params the rest by name
+    mode = get_mode(x0, **params)
+
+    # Calculate the value of the Hessian at the mode
+    # TODO check if we can't pull this out of the soln graph when jac or hess=True
+    hess_x = pytensor.gradient.hessian(nll, x)
+    get_hess = pytensor.function(inputs, hess_x)  # renamed from hess, which shadowed the hess argument
+
+    return mode, get_hess(mode, **params)


def fit_laplace(
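
For reference, a minimal standalone sketch of the symbolic optimization workflow the new code relies on, assuming minimize is pytensor.tensor.optimize.minimize with the keyword signature used in the hunk above (objective, x, method, jac, hess, optimizer_kwargs). The key behaviour, mirrored by get_mode, is that the value supplied for x when calling the compiled function serves as the optimizer's starting point:

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

# Minimise f(x) = (x - a)**2 over x, with a as an extra input (the analogue of params)
x = pt.scalar("x")
a = pt.scalar("a")
loss = (x - a) ** 2

soln, _ = minimize(objective=loss, x=x, method="BFGS", jac=True)

# The value passed for x is the optimizer's initial point, like x0 in find_mode
get_min = pytensor.function([x, a], soln)
print(get_min(0.0, 3.0))  # prints approximately 3.0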
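One possible way to resolve the x0-is-None TODO (an assumption, not something this PR implements): PyMC models expose initial_point(), which returns a dict of numeric initial values keyed by value-variable name, so when x is one of the model's value variables its entry could seed the optimizer:

if x0 is None:
    # Hypothetical default: seed the optimizer from the model's initial point
    # (the keys of initial_point() are value-variable names, e.g. "x")
    x0 = model.initial_point()[x.name]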
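And a hypothetical end-to-end call, under the assumptions baked into the TODOs above: inputs holds the model's value variables, x picks out the latent field among them, and params maps the names of any remaining inputs to fixed numeric values. The model, shapes, and values below are made up for illustration:

import numpy as np
import pymc as pm

rng = np.random.default_rng(0)
y_obs = rng.normal(loc=1.0, size=5)

with pm.Model() as m:
    x_rv = pm.MvNormal("x", mu=np.zeros(5), cov=np.eye(5))  # latent Gaussian field
    pm.Normal("y", mu=x_rv, sigma=0.5, observed=y_obs)

x_value = m.rvs_to_values[x_rv]  # model.logp() is a function of the value variables
mode, hess_at_mode = find_mode(
    inputs=[x_value],
    params={},  # no further inputs: sigma is fixed in this toy model
    x0=np.zeros(5),
    x=x_value,
    model=m,
)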