From 30090ed836b1ada2bda3e3c1adad0f9dd81344c7 Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Tue, 10 Jun 2025 20:19:50 +1000 Subject: [PATCH 01/16] set up skeleton for find_mode --- pymc_extras/inference/laplace.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index d86f1fee5..2c0d220b0 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -39,6 +39,8 @@ from pymc.model.transform.conditioning import remove_value_transforms from pymc.model.transform.optimization import freeze_dims_and_data from pymc.util import get_default_varnames +from pytensor.tensor import TensorVariable +from pytensor.tensor.optimize import minimize from scipy import stats from pymc_extras.inference.find_map import ( @@ -415,6 +417,31 @@ def sample_laplace_posterior( return idata +def find_mode( + inputs: list[TensorVariable], + x: TensorVariable | None = None, + model: pm.Model | None = None, + method: minimize_method = "BFGS", + optimizer_kwargs: dict | None = None, +): # Unsure of the return type, I'd assume it would be a list of pt tensors of some kind + model = pm.modelcontext(model) + if x is None: + raise NotImplementedError("Currently assumes user specifies the Gaussian latent field x") + + # Minimise negative log likelihood + loss_x = -model.logp() + # TODO Need to think about how to get inputs (i.e. a collection of all the input variables) to go along with the specific + # variable x, i.e. f(x, *args). I assume I can't assume that the inputs arg will be ordered to have x first. May need to sort it somehow + loss = pytensor.function(inputs, loss_x) + + grad = pytensor.gradient.grad(loss, inputs) + hess = pytensor.gradient.jacobian(grad, inputs)[0] + + # Need to play around with scipy.optimize.minimize with pytensor a little so I can figure out if it's "x" or "inputs" that goes here + res = minimize(loss, x, method, grad, hess, optimizer_kwargs) + return res.x, res.hess + + def fit_laplace( optimize_method: minimize_method | Literal["basinhopping"] = "BFGS", *, From a54c7b25a036cf1fa2ab55c22b94c8ed91b86795 Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Thu, 12 Jun 2025 22:48:45 +1000 Subject: [PATCH 02/16] added TODO --- pymc_extras/inference/laplace.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index 2c0d220b0..cc614781d 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -429,6 +429,7 @@ def find_mode( raise NotImplementedError("Currently assumes user specifies the Gaussian latent field x") # Minimise negative log likelihood + # TODO: NLL already computs jac by default. Need to check how to access loss_x = -model.logp() # TODO Need to think about how to get inputs (i.e. a collection of all the input variables) to go along with the specific # variable x, i.e. f(x, *args). I assume I can't assume that the inputs arg will be ordered to have x first. 
May need to sort it somehow From 23b4970424c0ae31234f08d22cc90a98d95febc5 Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Sat, 14 Jun 2025 23:29:30 +1000 Subject: [PATCH 03/16] moved notebook testing code into find_mode --- pymc_extras/inference/laplace.py | 42 +++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index cc614781d..36b31f8b1 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -419,28 +419,42 @@ def sample_laplace_posterior( def find_mode( inputs: list[TensorVariable], + params: dict, # TODO Would be nice to automatically map this to inputs somehow: {k.name: ... for k in inputs} + x0: TensorVariable + | None = None, # TODO This isn't a TensorVariable, not sure what the general datatype for numeric arraylikes is x: TensorVariable | None = None, model: pm.Model | None = None, method: minimize_method = "BFGS", + jac: bool = True, + hess: bool = False, optimizer_kwargs: dict | None = None, -): # Unsure of the return type, I'd assume it would be a list of pt tensors of some kind +): # TODO Output type is list of same type as x0 model = pm.modelcontext(model) if x is None: - raise NotImplementedError("Currently assumes user specifies the Gaussian latent field x") + raise UserWarning( + "Latent Gaussian field x unspecified. Assuming it is the first entry in inputs. Specify which input to obtain the mode over using the input x." + ) + x = inputs[0] + + if x0 is None: + # Should return a random numpy array of the same shape as x0 - not sure how to get the shape of x0 + raise NotImplementedError # Minimise negative log likelihood - # TODO: NLL already computs jac by default. Need to check how to access - loss_x = -model.logp() - # TODO Need to think about how to get inputs (i.e. a collection of all the input variables) to go along with the specific - # variable x, i.e. f(x, *args). I assume I can't assume that the inputs arg will be ordered to have x first. 
May need to sort it somehow - loss = pytensor.function(inputs, loss_x) - - grad = pytensor.gradient.grad(loss, inputs) - hess = pytensor.gradient.jacobian(grad, inputs)[0] - - # Need to play around with scipy.optimize.minimize with pytensor a little so I can figure out if it's "x" or "inputs" that goes here - res = minimize(loss, x, method, grad, hess, optimizer_kwargs) - return res.x, res.hess + nll = -model.logp() + soln, _ = minimize( + objective=nll, x=x, method=method, jac=jac, hess=hess, optimizer_kwargs=optimizer_kwargs + ) + + get_mode = pytensor.function(inputs, soln) + mode = get_mode(x0, **params) + + # Calculate the value of the Hessian at the mode + # TODO check if we can't pull this out of the soln graph when jac or hess=True + hess_x = pytensor.gradient.hessian(nll, x) + hess = pytensor.function(inputs, hess_x) + + return mode, hess(mode, **params) def fit_laplace( From dac0096872032b4d1d70851b1b701bc9c218be7e Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Mon, 16 Jun 2025 21:48:58 +1000 Subject: [PATCH 04/16] added test case, removed inputs as required arg --- pymc_extras/inference/laplace.py | 77 ++++++++++----- tests/test_laplace.py | 158 +++++++++++++++++++++++++++++++ 2 files changed, 213 insertions(+), 22 deletions(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index 36b31f8b1..6c462373b 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -418,43 +418,76 @@ def sample_laplace_posterior( def find_mode( - inputs: list[TensorVariable], - params: dict, # TODO Would be nice to automatically map this to inputs somehow: {k.name: ... for k in inputs} + x: TensorVariable, + args: dict, + inputs: list[TensorVariable] | None = None, x0: TensorVariable | None = None, # TODO This isn't a TensorVariable, not sure what the general datatype for numeric arraylikes is - x: TensorVariable | None = None, model: pm.Model | None = None, method: minimize_method = "BFGS", - jac: bool = True, - hess: bool = False, + use_jac: bool = True, + use_hess: bool = False, optimizer_kwargs: dict | None = None, ): # TODO Output type is list of same type as x0 model = pm.modelcontext(model) - if x is None: - raise UserWarning( - "Latent Gaussian field x unspecified. Assuming it is the first entry in inputs. Specify which input to obtain the mode over using the input x." 
- ) - x = inputs[0] - if x0 is None: - # Should return a random numpy array of the same shape as x0 - not sure how to get the shape of x0 - raise NotImplementedError + # if x0 is None: + # #TODO Issue with X not being an RV + # print(model.initial_point()) + + # from pymc.initial_point import make_initial_point_fn + # frozen_model = freeze_dims_and_data(model) + # ipfn = make_initial_point_fn( + # model=model, + # jitter_rvs=set(),#(jitter_rvs), + # return_transformed=True, + # overrides=args, + # ) + + # random_seed = None + # start_dict = ipfn(random_seed) + # vars_dict = {var.name: var for var in frozen_model.continuous_value_vars} + # initial_params = DictToArrayBijection.map( + # {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict} + # ) + # print(initial_params) # Minimise negative log likelihood nll = -model.logp() soln, _ = minimize( - objective=nll, x=x, method=method, jac=jac, hess=hess, optimizer_kwargs=optimizer_kwargs + objective=nll, + x=x, + method=method, + jac=use_jac, + hess=use_hess, + optimizer_kwargs=optimizer_kwargs, ) - get_mode = pytensor.function(inputs, soln) - mode = get_mode(x0, **params) - - # Calculate the value of the Hessian at the mode - # TODO check if we can't pull this out of the soln graph when jac or hess=True - hess_x = pytensor.gradient.hessian(nll, x) - hess = pytensor.function(inputs, hess_x) + # Get input variables + # TODO issue when this is nll + if inputs is None: + inputs = [ + pytensor.graph.basic.get_var_by_name(model.basic_RVs[1], target_var_id=var)[0] + for var in args + ] + for i, var in enumerate(inputs): + try: + inputs[i] = model.rvs_to_values[var] + except KeyError: + pass + inputs.insert(0, x) + + # Obtain the Hessian (re-use graph if already computed in minimize) + if use_hess: + hess = soln.owner.op.inner_outputs[-1] + hess = pytensor.graph.replace.graph_replace( + hess, {x: soln} + ) # TODO: x here is 'beta', soln is a MinimizeOp. There's no instance of MinimizeOp in the hessian graph + else: + hess = pytensor.gradient.hessian(nll, x) - return mode, hess(mode, **params) + get_mode_and_hessian = pytensor.function(inputs, [soln, hess]) + return get_mode_and_hessian(x0, **args) def fit_laplace( diff --git a/tests/test_laplace.py b/tests/test_laplace.py index 8f7a4c017..3098d7065 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -15,6 +15,7 @@ import numpy as np import pymc as pm +import pytensor as pt import pytest import pymc_extras as pmx @@ -279,3 +280,160 @@ def test_laplace_scalar(): assert idata_laplace.fit.covariance_matrix.shape == (1, 1) np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1) + + +def test_find_mode(): + k = 10 + N = 10000 + y = pt.vector("y", dtype="int64") + X = pt.matrix("X", shape=(N, k)) + + # Pre-commit did this. Quite ugly. Should compute hess in code rather than storing a hardcoded array. 
+ true_hess = np.array( + [ + [ + 2.50100000e03, + -1.78838742e00, + 1.59484217e01, + -9.78343803e00, + 2.86125467e01, + -7.38071788e00, + -4.97729126e01, + 3.53243810e01, + 1.69071769e01, + -1.30755942e01, + ], + [ + -1.78838742e00, + 2.54687995e03, + 8.99456512e-02, + -1.33603390e01, + -2.37641179e01, + 4.57780742e01, + -1.22640681e01, + 2.70879664e01, + 4.04435512e01, + 2.08826556e00, + ], + [ + 1.59484217e01, + 8.99456512e-02, + 2.46908384e03, + -1.80358232e01, + 1.14131535e01, + 2.21632317e01, + 1.25443469e00, + 1.50344618e01, + -3.59940488e01, + -1.05191328e01, + ], + [ + -9.78343803e00, + -1.33603390e01, + -1.80358232e01, + 2.50546496e03, + 3.27545028e01, + -3.33517501e01, + -2.68735672e01, + -2.69114305e01, + -1.20464337e01, + 9.02338622e00, + ], + [ + 2.86125467e01, + -2.37641179e01, + 1.14131535e01, + 3.27545028e01, + 2.49959736e03, + -3.98220135e00, + -4.09495199e00, + -1.51115257e01, + -5.77436126e01, + -2.98600447e00, + ], + [ + -7.38071788e00, + 4.57780742e01, + 2.21632317e01, + -3.33517501e01, + -3.98220135e00, + 2.48169432e03, + -1.26885014e01, + -3.53524089e01, + 5.89656794e00, + 1.67164400e01, + ], + [ + -4.97729126e01, + -1.22640681e01, + 1.25443469e00, + -2.68735672e01, + -4.09495199e00, + -1.26885014e01, + 2.47216241e03, + 8.16935659e00, + -4.89399152e01, + -1.11646138e01, + ], + [ + 3.53243810e01, + 2.70879664e01, + 1.50344618e01, + -2.69114305e01, + -1.51115257e01, + -3.53524089e01, + 8.16935659e00, + 2.52940405e03, + 3.07751540e00, + -8.60023392e00, + ], + [ + 1.69071769e01, + 4.04435512e01, + -3.59940488e01, + -1.20464337e01, + -5.77436126e01, + 5.89656794e00, + -4.89399152e01, + 3.07751540e00, + 2.49452594e03, + 6.06984410e01, + ], + [ + -1.30755942e01, + 2.08826556e00, + -1.05191328e01, + 9.02338622e00, + -2.98600447e00, + 1.67164400e01, + -1.11646138e01, + -8.60023392e00, + 6.06984410e01, + 2.49290175e03, + ], + ] + ) + + with pm.Model() as model: + beta = pm.MvNormal("beta", mu=np.zeros(k), cov=np.identity(k), shape=(k,)) + p = pm.math.invlogit(beta @ X.T) + y = pm.Bernoulli("y", p) + + rng = np.random.default_rng(123) + Xval = rng.normal(size=(10000, 9)) + Xval = np.c_[np.ones(10000), Xval] + + true_beta = rng.normal(scale=0.1, size=(10,)) + true_p = pm.math.invlogit(Xval @ true_beta).eval() + ynum = rng.binomial(1, true_p) + + beta_val = model.rvs_to_values[beta] + x0 = np.zeros(k) + args = {"y": ynum, "X": Xval} + + beta_mode, beta_hess = pmx.inference.laplace.find_mode( + x=beta_val, x0=x0, args=args, method="BFGS", optimizer_kwargs={"tol": 1e-8} + ) + + np.testing.assert_allclose(beta_mode, true_beta, atol=0.1, rtol=0.1) + np.testing.assert_allclose(beta_hess, true_hess, atol=0.1, rtol=0.1) From e7b22ac42f8cf69777d4c4e4582bbbe8850de3f8 Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Mon, 16 Jun 2025 21:55:38 +1000 Subject: [PATCH 05/16] imported find_mode into test cases --- tests/test_laplace.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_laplace.py b/tests/test_laplace.py index 3098d7065..1622b525a 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -22,6 +22,7 @@ from pymc_extras.inference.find_map import GradientBackend, find_MAP from pymc_extras.inference.laplace import ( + find_mode, fit_laplace, fit_mvn_at_MAP, sample_laplace_posterior, @@ -431,7 +432,7 @@ def test_find_mode(): x0 = np.zeros(k) args = {"y": ynum, "X": Xval} - beta_mode, beta_hess = pmx.inference.laplace.find_mode( + beta_mode, beta_hess = find_mode( x=beta_val, x0=x0, args=args, method="BFGS", optimizer_kwargs={"tol": 1e-8} ) 
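Note on the test above: the TODO in test_find_mode says the hardcoded true_hess array should be computed in code rather than stored verbatim. A minimal sketch of what that could look like, assuming the model used in the test (standard-normal prior on beta, Bernoulli likelihood with p = invlogit(X @ beta)), for which the Hessian of the negative log posterior is I + X^T diag(p(1 - p)) X. The helper name analytic_hessian and its use with Xval / true_beta are illustrative only, not part of the patches.

import numpy as np

def analytic_hessian(X: np.ndarray, beta: np.ndarray) -> np.ndarray:
    # Illustrative sketch (not from the patches): Hessian of the negative log
    # posterior for logistic regression with a standard-normal prior, i.e.
    # identity (prior term) + X^T diag(p * (1 - p)) X (likelihood term).
    p = 1.0 / (1.0 + np.exp(-X @ beta))  # invlogit(X @ beta)
    w = p * (1.0 - p)                    # Bernoulli variance terms
    return np.eye(X.shape[1]) + (X * w[:, None]).T @ X

# e.g. in the test: true_hess = analytic_hessian(Xval, true_beta)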
From b7eceb40c249992acebded25996f1cbaa9dcf632 Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Tue, 17 Jun 2025 17:14:13 +1000 Subject: [PATCH 06/16] made TODOs more verbose --- pymc_extras/inference/laplace.py | 109 ++++++++++++++++++++++--------- tests/test_laplace.py | 2 +- 2 files changed, 80 insertions(+), 31 deletions(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index 6c462373b..0d28de0a0 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -39,7 +39,7 @@ from pymc.model.transform.conditioning import remove_value_transforms from pymc.model.transform.optimization import freeze_dims_and_data from pymc.util import get_default_varnames -from pytensor.tensor import TensorVariable +from pytensor.tensor import TensorLike, TensorVariable from pytensor.tensor.optimize import minimize from scipy import stats @@ -420,28 +420,60 @@ def sample_laplace_posterior( def find_mode( x: TensorVariable, args: dict, - inputs: list[TensorVariable] | None = None, - x0: TensorVariable - | None = None, # TODO This isn't a TensorVariable, not sure what the general datatype for numeric arraylikes is + x0: TensorLike | None = None, model: pm.Model | None = None, method: minimize_method = "BFGS", use_jac: bool = True, - use_hess: bool = False, + use_hess: bool = False, # TODO Tbh we can probably just remove this arg and pass True to the minimizer all the time, but if this is the case, it will throw a warning when the hessian doesn't need to be computed for a particular optimisation routine. optimizer_kwargs: dict | None = None, -): # TODO Output type is list of same type as x0 +) -> list[TensorLike]: + """ + Estimates the mode and hessian of a model by minimizing negative log likelihood. Wrapper for (pytensor-native) scipy.optimize.minimize. + + Parameters + ---------- + x: TensorVariable + The parameter with which to minimize wrt (that is, find the mode in x). + args: dict + A dictionary of the form {tensorvariable_name: TensorLike}, where tensorvariable_name is the (exact) name of TensorVariable which is to be provided some numerical value. Same usage as args in scipy.optimize.minimize. + x0: TensorLike + Initial guess for the mode (in x). Initialised over a uniform distribution if unspecified. + model: Model + PyMC model to use. + method: minimize_method + Which minimization algorithm to use. + use_jac: bool + If true, the minimizer will compute and store the Jacobian. + use_hess: bool + If true, the minimizer will compute and store the Hessian (note that the Hessian will be computed explicitely even if this is False). + optimizer_kwargs: dict + Kwargs to pass to scipy.optimize.minimize. + + Returns + ------- + mu: TensorLike + The mode of the model. + hess: + Hessian evalulated at mu. + """ model = pm.modelcontext(model) - # if x0 is None: - # #TODO Issue with X not being an RV - # print(model.initial_point()) + # # TODO I would like to generate a random initialisation for x0 if set to None. Ideally this would be done using something like np.random.rand(x.shape), however I don't believe x.shape is + # # immediately accessible in pytensor. model.initial_point() throws the following error: + + # # MissingInputError: Input 0 (X) of the graph (indices start from 0), used to compute Transpose{axes=[1, 0]}(X), was not provided and not given a value. Use the PyTensor flag exception_verbosity='high', for more information on this error. 
+ # # Instead I've tried to follow what find_MAP does (below), but this doesn't really get me anywhere unfortunately. + # # if x0 is None: + # # Yes ik this is here, just for debugging purposes # from pymc.initial_point import make_initial_point_fn + # frozen_model = freeze_dims_and_data(model) # ipfn = make_initial_point_fn( - # model=model, - # jitter_rvs=set(),#(jitter_rvs), + # model=frozen_model, + # jitter_rvs=set(), # (jitter_rvs), # return_transformed=True, - # overrides=args, + # overrides={x.name: x0}, # x0 is here for debugging purposes # ) # random_seed = None @@ -450,6 +482,7 @@ def find_mode( # initial_params = DictToArrayBijection.map( # {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict} # ) + # # Printing this should return {'name of x TensorVariable': initialised values for x} # print(initial_params) # Minimise negative log likelihood @@ -463,30 +496,46 @@ def find_mode( optimizer_kwargs=optimizer_kwargs, ) - # Get input variables - # TODO issue when this is nll - if inputs is None: - inputs = [ - pytensor.graph.basic.get_var_by_name(model.basic_RVs[1], target_var_id=var)[0] - for var in args - ] - for i, var in enumerate(inputs): - try: - inputs[i] = model.rvs_to_values[var] - except KeyError: - pass - inputs.insert(0, x) + # TODO To prevent needing to user to pass in the TensorVariables alongside their names (e.g. args = {'X': [X, [0,0]], 'beta': [beta_val, [1]], ...}) the codeblock below digs up the + # TensorVariables associated with the names listed in args from the graph. The sensible graph to use here would be nll, because that is what is going into the minimize function, however doing + # so results in the following error: + # + # TypeError: TensorType does not support iteration. + # Did you pass a PyTensor variable to a function that expects a list? + # Maybe you are using builtins.sum instead of pytensor.tensor.sum? + # + # Using model.basic_RVs[1] instead works, but note that this is a hardcoded fix because the model I'm testing on happens to have all of the relevant TensorVariables in the graph of model.basic_RVs[1], + # but this isn't true in general. + + # Get arg TensorVariables + arg_tensorvars = [ + pytensor.graph.basic.get_var_by_name(model.basic_RVs[1], target_var_id=var)[0] + # pytensor.graph.basic.get_var_by_name(nll, target_var_id=var)[0] + for var in args + ] + for i, var in enumerate(arg_tensorvars): + try: + arg_tensorvars[i] = model.rvs_to_values[var] + except KeyError: + pass + arg_tensorvars.insert(0, x) + + # TODO: Jesse suggested I use this graph_replace function, but it seems that "mode" here is a different type to soln: + # + # TypeError: Cannot convert Type Vector(float64, shape=(10,)) (of Variable MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0) into Type Scalar(float64, shape=()). You can try to manually convert MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0 into a Scalar(float64, shape=()). + # + # My understanding here is that for some function which evaluates the hessian at x, we're replacing "x" in the hess graph with the subgraph that computes "x" (i.e. soln)? # Obtain the Hessian (re-use graph if already computed in minimize) if use_hess: - hess = soln.owner.op.inner_outputs[-1] - hess = pytensor.graph.replace.graph_replace( - hess, {x: soln} - ) # TODO: x here is 'beta', soln is a MinimizeOp. 
There's no instance of MinimizeOp in the hessian graph + mode, _, hess = ( + soln.owner.op.inner_outputs + ) # Note that this mode, _, hess will need to be slightly more elaborate for when use_jac is False (2 items to unpack instead of 3). Just a few if-blocks, but not implemented for now while we're debugging + hess = pytensor.graph.replace.graph_replace(hess, {mode: soln}) else: hess = pytensor.gradient.hessian(nll, x) - get_mode_and_hessian = pytensor.function(inputs, [soln, hess]) + get_mode_and_hessian = pytensor.function(arg_tensorvars, [soln, hess]) return get_mode_and_hessian(x0, **args) diff --git a/tests/test_laplace.py b/tests/test_laplace.py index 1622b525a..be59596fd 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -289,7 +289,7 @@ def test_find_mode(): y = pt.vector("y", dtype="int64") X = pt.matrix("X", shape=(N, k)) - # Pre-commit did this. Quite ugly. Should compute hess in code rather than storing a hardcoded array. + # TODO Pre-commit formatted it like this. Quite ugly. Should compute hess in code rather than storing a hardcoded array. true_hess = np.array( [ [ From 40f27e0a7b4e1e89f26078fc708c1b90971b1a91 Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Thu, 19 Jun 2025 20:52:16 +1000 Subject: [PATCH 07/16] refactor: find_mode_and_hess returns a function --- pymc_extras/inference/laplace.py | 76 ++----------- tests/test_laplace.py | 176 +++++-------------------------- 2 files changed, 34 insertions(+), 218 deletions(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index 0d28de0a0..94eb08fba 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -15,6 +15,7 @@ import logging +from collections.abc import Callable from functools import reduce from importlib.util import find_spec from itertools import product @@ -39,7 +40,7 @@ from pymc.model.transform.conditioning import remove_value_transforms from pymc.model.transform.optimization import freeze_dims_and_data from pymc.util import get_default_varnames -from pytensor.tensor import TensorLike, TensorVariable +from pytensor.tensor import TensorVariable from pytensor.tensor.optimize import minimize from scipy import stats @@ -417,27 +418,21 @@ def sample_laplace_posterior( return idata -def find_mode( +def find_mode_and_hess( x: TensorVariable, - args: dict, - x0: TensorLike | None = None, model: pm.Model | None = None, method: minimize_method = "BFGS", use_jac: bool = True, use_hess: bool = False, # TODO Tbh we can probably just remove this arg and pass True to the minimizer all the time, but if this is the case, it will throw a warning when the hessian doesn't need to be computed for a particular optimisation routine. optimizer_kwargs: dict | None = None, -) -> list[TensorLike]: +) -> Callable: """ - Estimates the mode and hessian of a model by minimizing negative log likelihood. Wrapper for (pytensor-native) scipy.optimize.minimize. + Returns a function to estimate the mode and hessian of a model by minimizing negative log likelihood. Wrapper for (pytensor-native) scipy.optimize.minimize. Parameters ---------- x: TensorVariable The parameter with which to minimize wrt (that is, find the mode in x). - args: dict - A dictionary of the form {tensorvariable_name: TensorLike}, where tensorvariable_name is the (exact) name of TensorVariable which is to be provided some numerical value. Same usage as args in scipy.optimize.minimize. - x0: TensorLike - Initial guess for the mode (in x). 
Initialised over a uniform distribution if unspecified. model: Model PyMC model to use. method: minimize_method @@ -451,40 +446,11 @@ def find_mode( Returns ------- - mu: TensorLike - The mode of the model. - hess: - Hessian evalulated at mu. + f: Callable + A function which accepts the values of the model RVs as args and returns [mu, hess(mu)], where mu is the mode. The TensorVariable x is specified as an initial guess for mu in args. """ model = pm.modelcontext(model) - # # TODO I would like to generate a random initialisation for x0 if set to None. Ideally this would be done using something like np.random.rand(x.shape), however I don't believe x.shape is - # # immediately accessible in pytensor. model.initial_point() throws the following error: - - # # MissingInputError: Input 0 (X) of the graph (indices start from 0), used to compute Transpose{axes=[1, 0]}(X), was not provided and not given a value. Use the PyTensor flag exception_verbosity='high', for more information on this error. - - # # Instead I've tried to follow what find_MAP does (below), but this doesn't really get me anywhere unfortunately. - # # if x0 is None: - # # Yes ik this is here, just for debugging purposes - # from pymc.initial_point import make_initial_point_fn - - # frozen_model = freeze_dims_and_data(model) - # ipfn = make_initial_point_fn( - # model=frozen_model, - # jitter_rvs=set(), # (jitter_rvs), - # return_transformed=True, - # overrides={x.name: x0}, # x0 is here for debugging purposes - # ) - - # random_seed = None - # start_dict = ipfn(random_seed) - # vars_dict = {var.name: var for var in frozen_model.continuous_value_vars} - # initial_params = DictToArrayBijection.map( - # {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict} - # ) - # # Printing this should return {'name of x TensorVariable': initialised values for x} - # print(initial_params) - # Minimise negative log likelihood nll = -model.logp() soln, _ = minimize( @@ -496,30 +462,6 @@ def find_mode( optimizer_kwargs=optimizer_kwargs, ) - # TODO To prevent needing to user to pass in the TensorVariables alongside their names (e.g. args = {'X': [X, [0,0]], 'beta': [beta_val, [1]], ...}) the codeblock below digs up the - # TensorVariables associated with the names listed in args from the graph. The sensible graph to use here would be nll, because that is what is going into the minimize function, however doing - # so results in the following error: - # - # TypeError: TensorType does not support iteration. - # Did you pass a PyTensor variable to a function that expects a list? - # Maybe you are using builtins.sum instead of pytensor.tensor.sum? - # - # Using model.basic_RVs[1] instead works, but note that this is a hardcoded fix because the model I'm testing on happens to have all of the relevant TensorVariables in the graph of model.basic_RVs[1], - # but this isn't true in general. 
- - # Get arg TensorVariables - arg_tensorvars = [ - pytensor.graph.basic.get_var_by_name(model.basic_RVs[1], target_var_id=var)[0] - # pytensor.graph.basic.get_var_by_name(nll, target_var_id=var)[0] - for var in args - ] - for i, var in enumerate(arg_tensorvars): - try: - arg_tensorvars[i] = model.rvs_to_values[var] - except KeyError: - pass - arg_tensorvars.insert(0, x) - # TODO: Jesse suggested I use this graph_replace function, but it seems that "mode" here is a different type to soln: # # TypeError: Cannot convert Type Vector(float64, shape=(10,)) (of Variable MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0) into Type Scalar(float64, shape=()). You can try to manually convert MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0 into a Scalar(float64, shape=()). @@ -535,8 +477,8 @@ def find_mode( else: hess = pytensor.gradient.hessian(nll, x) - get_mode_and_hessian = pytensor.function(arg_tensorvars, [soln, hess]) - return get_mode_and_hessian(x0, **args) + args = model.continuous_value_vars + model.discrete_value_vars + return pytensor.function(args, [soln, hess]) def fit_laplace( diff --git a/tests/test_laplace.py b/tests/test_laplace.py index be59596fd..c59334812 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -15,14 +15,13 @@ import numpy as np import pymc as pm -import pytensor as pt import pytest import pymc_extras as pmx from pymc_extras.inference.find_map import GradientBackend, find_MAP from pymc_extras.inference.laplace import ( - find_mode, + find_mode_and_hess, fit_laplace, fit_mvn_at_MAP, sample_laplace_posterior, @@ -283,158 +282,33 @@ def test_laplace_scalar(): np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1) -def test_find_mode(): - k = 10 - N = 10000 - y = pt.vector("y", dtype="int64") - X = pt.matrix("X", shape=(N, k)) - - # TODO Pre-commit formatted it like this. Quite ugly. Should compute hess in code rather than storing a hardcoded array. 
- true_hess = np.array( - [ - [ - 2.50100000e03, - -1.78838742e00, - 1.59484217e01, - -9.78343803e00, - 2.86125467e01, - -7.38071788e00, - -4.97729126e01, - 3.53243810e01, - 1.69071769e01, - -1.30755942e01, - ], - [ - -1.78838742e00, - 2.54687995e03, - 8.99456512e-02, - -1.33603390e01, - -2.37641179e01, - 4.57780742e01, - -1.22640681e01, - 2.70879664e01, - 4.04435512e01, - 2.08826556e00, - ], - [ - 1.59484217e01, - 8.99456512e-02, - 2.46908384e03, - -1.80358232e01, - 1.14131535e01, - 2.21632317e01, - 1.25443469e00, - 1.50344618e01, - -3.59940488e01, - -1.05191328e01, - ], - [ - -9.78343803e00, - -1.33603390e01, - -1.80358232e01, - 2.50546496e03, - 3.27545028e01, - -3.33517501e01, - -2.68735672e01, - -2.69114305e01, - -1.20464337e01, - 9.02338622e00, - ], - [ - 2.86125467e01, - -2.37641179e01, - 1.14131535e01, - 3.27545028e01, - 2.49959736e03, - -3.98220135e00, - -4.09495199e00, - -1.51115257e01, - -5.77436126e01, - -2.98600447e00, - ], - [ - -7.38071788e00, - 4.57780742e01, - 2.21632317e01, - -3.33517501e01, - -3.98220135e00, - 2.48169432e03, - -1.26885014e01, - -3.53524089e01, - 5.89656794e00, - 1.67164400e01, - ], - [ - -4.97729126e01, - -1.22640681e01, - 1.25443469e00, - -2.68735672e01, - -4.09495199e00, - -1.26885014e01, - 2.47216241e03, - 8.16935659e00, - -4.89399152e01, - -1.11646138e01, - ], - [ - 3.53243810e01, - 2.70879664e01, - 1.50344618e01, - -2.69114305e01, - -1.51115257e01, - -3.53524089e01, - 8.16935659e00, - 2.52940405e03, - 3.07751540e00, - -8.60023392e00, - ], - [ - 1.69071769e01, - 4.04435512e01, - -3.59940488e01, - -1.20464337e01, - -5.77436126e01, - 5.89656794e00, - -4.89399152e01, - 3.07751540e00, - 2.49452594e03, - 6.06984410e01, - ], - [ - -1.30755942e01, - 2.08826556e00, - -1.05191328e01, - 9.02338622e00, - -2.98600447e00, - 1.67164400e01, - -1.11646138e01, - -8.60023392e00, - 6.06984410e01, - 2.49290175e03, - ], - ] - ) +def test_find_mode_and_hess(): + rng = np.random.default_rng(42) + n = 100 + sigma_obs = rng.random() + sigma_mu = rng.random() - with pm.Model() as model: - beta = pm.MvNormal("beta", mu=np.zeros(k), cov=np.identity(k), shape=(k,)) - p = pm.math.invlogit(beta @ X.T) - y = pm.Bernoulli("y", p) + coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(n)} + with pm.Model(coords=coords) as model: + obs_val = rng.normal(loc=3, scale=1.5, size=(n, 3)) - rng = np.random.default_rng(123) - Xval = rng.normal(size=(10000, 9)) - Xval = np.c_[np.ones(10000), Xval] + mu = pm.Normal("mu", mu=1, sigma=sigma_mu, dims=["city"]) + obs = pm.Normal( + "obs", + mu=mu, + sigma=sigma_obs, + observed=obs_val, + dims=["obs_idx", "city"], + ) - true_beta = rng.normal(scale=0.1, size=(10,)) - true_p = pm.math.invlogit(Xval @ true_beta).eval() - ynum = rng.binomial(1, true_p) + get_mode_and_hessian = find_mode_and_hess( + use_hess=False, x=model.rvs_to_values[mu], method="BFGS", optimizer_kwargs={"tol": 1e-8} + ) - beta_val = model.rvs_to_values[beta] - x0 = np.zeros(k) - args = {"y": ynum, "X": Xval} + mode, hess = get_mode_and_hessian(**{"mu": [1, 1, 1]}) - beta_mode, beta_hess = find_mode( - x=beta_val, x0=x0, args=args, method="BFGS", optimizer_kwargs={"tol": 1e-8} - ) + true_mode = obs_val.mean(axis=0) + true_hess = np.diag((1 / sigma_mu**2 + n / sigma_obs**2) * np.ones(3)) - np.testing.assert_allclose(beta_mode, true_beta, atol=0.1, rtol=0.1) - np.testing.assert_allclose(beta_hess, true_hess, atol=0.1, rtol=0.1) + np.testing.assert_allclose(mode, true_mode, atol=0.1, rtol=0.1) + np.testing.assert_allclose(hess, true_hess, atol=0.1, rtol=0.1) From 
a1292eabcc01b3333be013d7d2f575245554b462 Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Mon, 23 Jun 2025 00:28:57 +1000 Subject: [PATCH 08/16] WIP: Should find root of conditional_gaussian_approx not minimize nll --- pymc_extras/inference/laplace.py | 189 ++++++++++++++++++++----------- tests/test_laplace.py | 16 +-- 2 files changed, 133 insertions(+), 72 deletions(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index 94eb08fba..973d47c02 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -29,7 +29,7 @@ import xarray as xr from arviz import dict_to_dataset -from better_optimize.constants import minimize_method +from better_optimize.constants import minimize_method, root_method from pymc import DictToArrayBijection from pymc.backends.arviz import ( coords_and_dims_for_inferencedata, @@ -41,7 +41,7 @@ from pymc.model.transform.optimization import freeze_dims_and_data from pymc.util import get_default_varnames from pytensor.tensor import TensorVariable -from pytensor.tensor.optimize import minimize +from pytensor.tensor.optimize import root from scipy import stats from pymc_extras.inference.find_map import ( @@ -55,6 +55,128 @@ _log = logging.getLogger(__name__) +def find_mode_jac_hess( + x: TensorVariable, # Should be vector specifically + Q: TensorVariable, # Matrix # TODO tensorinv doesn't have grad implemented yet + mu: TensorVariable, # Vector + model: pm.Model | None = None, + method: root_method = "hybr", + use_jac: bool = True, + # use_hess: bool = False, + optimizer_kwargs: dict | None = None, +) -> Callable: + """ + Returns a function to estimate the mode and both the first and second derivatives of a model at that point by minimizing negative log likelihood. Wrapper for (pytensor-native) scipy.optimize.minimize. + + Parameters + ---------- + x: TensorVariable + The parameter with which to minimize wrt (that is, find the mode in x). + model: Model + PyMC model to use. + method: minimize_method + Which minimization algorithm to use. + use_jac: bool + If true, the minimizer will compute and store the Jacobian. + use_hess: bool + If true, the minimizer will compute and store the Hessian (note that the Hessian will be computed explicitely even if this is False). + optimizer_kwargs: dict + Kwargs to pass to scipy.optimize.minimize. + + Returns + ------- + f: Callable + A function which accepts the values of the model RVs as args and returns [mu, jac(mu) hess(mu)], where mu is the mode. The TensorVariable x is specified as an initial guess for mu in args. 
+ """ + model = pm.modelcontext(model) + + # f = log(p(y | x, params)) + f = model.logp() + jac = pytensor.gradient.grad(f, x) + hess = pytensor.gradient.jacobian(jac.flatten(), x) + + # Component of log(p(x | y, params)) which depends on x (for rootfinding) + conditional_gaussian_approx = -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x) + + x0, _ = root( + equations=pt.stack([conditional_gaussian_approx]), + variables=x, + method=method, + jac=use_jac, + optimizer_kwargs=optimizer_kwargs, + ) + + # require f'(x0) and f''(x0) for Laplace approx + jac = pytensor.graph.replace.graph_replace(jac, {x: x0}) + hess = pytensor.graph.replace.graph_replace( + hess, {x: x0} + ) # Possibly unecessary because jac already does this replace + + # Full log(p(x | y, params)) + _, logdetQ = pt.nlinalg.slogdet(Q) + conditional_gaussian_approx = ( + -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ + ) # TODO does doing this change the graph in root before if changed before it's compiled? + + args = model.continuous_value_vars + model.discrete_value_vars + return pytensor.function( + args, [x0, conditional_gaussian_approx] + ) # Currently x being passed in as an initial guess for x0 AND then also going to the true value of x + + # Minimise negative log likelihood + # nll = -model.logp() + # soln, _ = minimize( + # objective=nll, + # x=x, + # method=method, + # jac=use_jac, + # hess=use_hess, + # optimizer_kwargs=optimizer_kwargs, + # ) + + # TODO: Jesse suggested I use this graph_replace function, but it seems that "mode" here is a different type to soln: + # + # TypeError: Cannot convert Type Vector(float64, shape=(10,)) (of Variable MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0) into Type Scalar(float64, shape=()). You can try to manually convert MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0 into a Scalar(float64, shape=()). + # + # My understanding here is that for some function which evaluates the hessian at x, we're replacing "x" in the hess graph with the subgraph that computes "x" (i.e. soln)? + + # Obtain the Hessian (re-use graph if already computed in minimize) + # if use_hess: + # mode, _, hess = ( + # soln.owner.op.inner_outputs + # ) # Note that this mode, _, hess will need to be slightly more elaborate for when use_jac is False (2 items to unpack instead of 3). 
Just a few if-blocks, but not implemented for now while we're debugging + # hess = pytensor.graph.replace.graph_replace(hess, {mode: soln}) + # else: + # hess = pytensor.gradient.hessian(nll, x) + + # Obtain the gradient and Hessian (re-use graphs if already computed in minimize) + # res = soln.owner.op.inner_outputs + # mode = res[0] + + # print(res) + + # if use_jac: + # # jac = pytensor.gradient.grad(nll, x) + # jac = res.pop(1) + # else: + # jac = pytensor.gradient.grad(nll, x) + # jac = pytensor.graph.replace.graph_replace(jac, {x: soln}) + + # print(x) + # # jac = pytensor.graph.replace.graph_replace(jac, {x: soln}) + + # jac = -jac # We subsequently want the gradients wrt log(p(y | x)) rather than the negative of this (nll) + + # if use_hess: + # hess = res.pop(1) + # else: + # hess = pytensor.gradient.jacobian(jac.flatten(), soln) + # # hess = pytensor.graph.replace.graph_replace(hess, {x: soln}) + + # args = model.continuous_value_vars + model.discrete_value_vars + # return pytensor.function(args, [soln, jac, hess]) + + def laplace_draws_to_inferencedata( posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None ) -> az.InferenceData: @@ -418,69 +540,6 @@ def sample_laplace_posterior( return idata -def find_mode_and_hess( - x: TensorVariable, - model: pm.Model | None = None, - method: minimize_method = "BFGS", - use_jac: bool = True, - use_hess: bool = False, # TODO Tbh we can probably just remove this arg and pass True to the minimizer all the time, but if this is the case, it will throw a warning when the hessian doesn't need to be computed for a particular optimisation routine. - optimizer_kwargs: dict | None = None, -) -> Callable: - """ - Returns a function to estimate the mode and hessian of a model by minimizing negative log likelihood. Wrapper for (pytensor-native) scipy.optimize.minimize. - - Parameters - ---------- - x: TensorVariable - The parameter with which to minimize wrt (that is, find the mode in x). - model: Model - PyMC model to use. - method: minimize_method - Which minimization algorithm to use. - use_jac: bool - If true, the minimizer will compute and store the Jacobian. - use_hess: bool - If true, the minimizer will compute and store the Hessian (note that the Hessian will be computed explicitely even if this is False). - optimizer_kwargs: dict - Kwargs to pass to scipy.optimize.minimize. - - Returns - ------- - f: Callable - A function which accepts the values of the model RVs as args and returns [mu, hess(mu)], where mu is the mode. The TensorVariable x is specified as an initial guess for mu in args. - """ - model = pm.modelcontext(model) - - # Minimise negative log likelihood - nll = -model.logp() - soln, _ = minimize( - objective=nll, - x=x, - method=method, - jac=use_jac, - hess=use_hess, - optimizer_kwargs=optimizer_kwargs, - ) - - # TODO: Jesse suggested I use this graph_replace function, but it seems that "mode" here is a different type to soln: - # - # TypeError: Cannot convert Type Vector(float64, shape=(10,)) (of Variable MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0) into Type Scalar(float64, shape=()). You can try to manually convert MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0 into a Scalar(float64, shape=()). - # - # My understanding here is that for some function which evaluates the hessian at x, we're replacing "x" in the hess graph with the subgraph that computes "x" (i.e. soln)? 
- - # Obtain the Hessian (re-use graph if already computed in minimize) - if use_hess: - mode, _, hess = ( - soln.owner.op.inner_outputs - ) # Note that this mode, _, hess will need to be slightly more elaborate for when use_jac is False (2 items to unpack instead of 3). Just a few if-blocks, but not implemented for now while we're debugging - hess = pytensor.graph.replace.graph_replace(hess, {mode: soln}) - else: - hess = pytensor.gradient.hessian(nll, x) - - args = model.continuous_value_vars + model.discrete_value_vars - return pytensor.function(args, [soln, hess]) - - def fit_laplace( optimize_method: minimize_method | Literal["basinhopping"] = "BFGS", *, diff --git a/tests/test_laplace.py b/tests/test_laplace.py index c59334812..2098d7a6c 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -21,7 +21,7 @@ from pymc_extras.inference.find_map import GradientBackend, find_MAP from pymc_extras.inference.laplace import ( - find_mode_and_hess, + find_mode_jac_hess, fit_laplace, fit_mvn_at_MAP, sample_laplace_posterior, @@ -282,17 +282,19 @@ def test_laplace_scalar(): np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1) -def test_find_mode_and_hess(): +def test_find_mode_jac_hess(): rng = np.random.default_rng(42) n = 100 sigma_obs = rng.random() sigma_mu = rng.random() + true_mu = rng.random() + mu_val = rng.random() coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(n)} with pm.Model(coords=coords) as model: - obs_val = rng.normal(loc=3, scale=1.5, size=(n, 3)) + obs_val = rng.normal(loc=true_mu, scale=1.5, size=(n, 3)) - mu = pm.Normal("mu", mu=1, sigma=sigma_mu, dims=["city"]) + mu = pm.Normal("mu", mu=mu_val, sigma=sigma_mu, dims=["city"]) obs = pm.Normal( "obs", mu=mu, @@ -301,14 +303,14 @@ def test_find_mode_and_hess(): dims=["obs_idx", "city"], ) - get_mode_and_hessian = find_mode_and_hess( + get_mode_and_hessian = find_mode_jac_hess( use_hess=False, x=model.rvs_to_values[mu], method="BFGS", optimizer_kwargs={"tol": 1e-8} ) - mode, hess = get_mode_and_hessian(**{"mu": [1, 1, 1]}) + mode, jac, hess = get_mode_and_hessian(mu=[1, 1, 1]) true_mode = obs_val.mean(axis=0) - true_hess = np.diag((1 / sigma_mu**2 + n / sigma_obs**2) * np.ones(3)) + true_hess = -np.diag((1 / sigma_mu**2 + n / sigma_obs**2) * np.ones(3)) np.testing.assert_allclose(mode, true_mode, atol=0.1, rtol=0.1) np.testing.assert_allclose(hess, true_hess, atol=0.1, rtol=0.1) From bb50b23f486bb48655de04bb623bced338e4acc7 Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Tue, 24 Jun 2025 19:46:04 +1000 Subject: [PATCH 09/16] in a working state --- pymc_extras/inference/laplace.py | 33 ++++++++++-------------- tests/test_laplace.py | 44 ++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 20 deletions(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index 973d47c02..9f0e6af33 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -29,7 +29,7 @@ import xarray as xr from arviz import dict_to_dataset -from better_optimize.constants import minimize_method, root_method +from better_optimize.constants import minimize_method from pymc import DictToArrayBijection from pymc.backends.arviz import ( coords_and_dims_for_inferencedata, @@ -41,7 +41,7 @@ from pymc.model.transform.optimization import freeze_dims_and_data from pymc.util import get_default_varnames from pytensor.tensor import TensorVariable -from pytensor.tensor.optimize import root +from pytensor.tensor.optimize import minimize 
from scipy import stats from pymc_extras.inference.find_map import ( @@ -55,14 +55,15 @@ _log = logging.getLogger(__name__) -def find_mode_jac_hess( +def get_conditional_gaussian_approximation( x: TensorVariable, # Should be vector specifically Q: TensorVariable, # Matrix # TODO tensorinv doesn't have grad implemented yet mu: TensorVariable, # Vector + args: list[TensorVariable] | None = None, model: pm.Model | None = None, - method: root_method = "hybr", + method: minimize_method = "BFGS", use_jac: bool = True, - # use_hess: bool = False, + use_hess: bool = False, optimizer_kwargs: dict | None = None, ) -> Callable: """ @@ -98,11 +99,12 @@ def find_mode_jac_hess( # Component of log(p(x | y, params)) which depends on x (for rootfinding) conditional_gaussian_approx = -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x) - x0, _ = root( - equations=pt.stack([conditional_gaussian_approx]), - variables=x, + x0, _ = minimize( + objective=-conditional_gaussian_approx, + x=x, method=method, jac=use_jac, + hess=use_hess, optimizer_kwargs=optimizer_kwargs, ) @@ -118,22 +120,13 @@ def find_mode_jac_hess( -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ ) # TODO does doing this change the graph in root before if changed before it's compiled? - args = model.continuous_value_vars + model.discrete_value_vars + if args is None: + args = model.continuous_value_vars + model.discrete_value_vars + return pytensor.function( args, [x0, conditional_gaussian_approx] ) # Currently x being passed in as an initial guess for x0 AND then also going to the true value of x - # Minimise negative log likelihood - # nll = -model.logp() - # soln, _ = minimize( - # objective=nll, - # x=x, - # method=method, - # jac=use_jac, - # hess=use_hess, - # optimizer_kwargs=optimizer_kwargs, - # ) - # TODO: Jesse suggested I use this graph_replace function, but it seems that "mode" here is a different type to soln: # # TypeError: Cannot convert Type Vector(float64, shape=(10,)) (of Variable MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0) into Type Scalar(float64, shape=()). You can try to manually convert MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0 into a Scalar(float64, shape=()). 
diff --git a/tests/test_laplace.py b/tests/test_laplace.py index 2098d7a6c..bb24835ca 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -282,6 +282,50 @@ def test_laplace_scalar(): np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1) +# rng = np.random.default_rng(42) +# n = 100 +# d = 3 +# k = 10 +# mu_true = 3*np.ones(d) #rng.random(d) +# cov_true = np.diag(np.ones(d))#rng.random(d) +# Q_val = np.diag(np.ones(d))#np.diag(rng.random(d)) + +# with pm.Model() as model: +# y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n) + +# mu_param = pm.MvNormal("mu_param", mu=3*np.ones(d), cov=np.diag(np.ones(d))) +# cov_param = pm.MvNormal("cov_param", mu=np.zeros(d**2), cov=np.diag(np.ones(d**2))) +# Q = pm.MvNormal("Q", mu=np.zeros(d**2), cov=np.diag(np.ones(d**2))) + +# x = pm.MvNormal("x", mu=mu_param, cov=np.linalg.inv(Q_val)) + +# y = pm.MvNormal( +# "y", +# mu=x, +# cov=cov_param.reshape((d,d)), +# observed=y_obs, +# ) + +# # logp(x | y, params) +# cga = get_conditional_gaussian_approximation( +# x=model.rvs_to_values[x], +# Q=Q.reshape((d,d)), +# mu=mu_param, +# # args=[model.rvs_to_values[x], Q, model.rvs_to_values[mu_param], model.rvs_to_values[cov_param]], +# optimizer_kwargs={"tol": 1e-9} +# ) + +# x = 4*np.array([1,1,1]) +# sigma_inv = np.linalg.inv(np.diag(np.ones(d))) +# x0 = np.linalg.inv(n*sigma_inv - Q_val) @ (sigma_inv@y_obs.sum(axis=0) - Q_val@x) +# print(x0) + +# x = 4*np.array([1,1,1]) + +# res = cga(x=x, mu_param=x, cov_param=np.diag(np.ones(d)).flatten(), Q=Q_val.flatten()) +# print(res) + + def test_find_mode_jac_hess(): rng = np.random.default_rng(42) n = 100 From 84a86eeaf4f8637f43c95e2eb67da57443c4a075 Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Tue, 24 Jun 2025 20:50:01 +1000 Subject: [PATCH 10/16] test case passes --- pymc_extras/inference/laplace.py | 42 ----------- tests/test_laplace.py | 119 ++++++++++++++----------------- 2 files changed, 53 insertions(+), 108 deletions(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index 9f0e6af33..af62cbd18 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -127,48 +127,6 @@ def get_conditional_gaussian_approximation( args, [x0, conditional_gaussian_approx] ) # Currently x being passed in as an initial guess for x0 AND then also going to the true value of x - # TODO: Jesse suggested I use this graph_replace function, but it seems that "mode" here is a different type to soln: - # - # TypeError: Cannot convert Type Vector(float64, shape=(10,)) (of Variable MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0) into Type Scalar(float64, shape=()). You can try to manually convert MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0 into a Scalar(float64, shape=()). - # - # My understanding here is that for some function which evaluates the hessian at x, we're replacing "x" in the hess graph with the subgraph that computes "x" (i.e. soln)? - - # Obtain the Hessian (re-use graph if already computed in minimize) - # if use_hess: - # mode, _, hess = ( - # soln.owner.op.inner_outputs - # ) # Note that this mode, _, hess will need to be slightly more elaborate for when use_jac is False (2 items to unpack instead of 3). 
Just a few if-blocks, but not implemented for now while we're debugging - # hess = pytensor.graph.replace.graph_replace(hess, {mode: soln}) - # else: - # hess = pytensor.gradient.hessian(nll, x) - - # Obtain the gradient and Hessian (re-use graphs if already computed in minimize) - # res = soln.owner.op.inner_outputs - # mode = res[0] - - # print(res) - - # if use_jac: - # # jac = pytensor.gradient.grad(nll, x) - # jac = res.pop(1) - # else: - # jac = pytensor.gradient.grad(nll, x) - # jac = pytensor.graph.replace.graph_replace(jac, {x: soln}) - - # print(x) - # # jac = pytensor.graph.replace.graph_replace(jac, {x: soln}) - - # jac = -jac # We subsequently want the gradients wrt log(p(y | x)) rather than the negative of this (nll) - - # if use_hess: - # hess = res.pop(1) - # else: - # hess = pytensor.gradient.jacobian(jac.flatten(), soln) - # # hess = pytensor.graph.replace.graph_replace(hess, {x: soln}) - - # args = model.continuous_value_vars + model.discrete_value_vars - # return pytensor.function(args, [soln, jac, hess]) - def laplace_draws_to_inferencedata( posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None diff --git a/tests/test_laplace.py b/tests/test_laplace.py index bb24835ca..dc7fb78d7 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -21,9 +21,9 @@ from pymc_extras.inference.find_map import GradientBackend, find_MAP from pymc_extras.inference.laplace import ( - find_mode_jac_hess, fit_laplace, fit_mvn_at_MAP, + get_conditional_gaussian_approximation, sample_laplace_posterior, ) @@ -282,79 +282,66 @@ def test_laplace_scalar(): np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1) -# rng = np.random.default_rng(42) -# n = 100 -# d = 3 -# k = 10 -# mu_true = 3*np.ones(d) #rng.random(d) -# cov_true = np.diag(np.ones(d))#rng.random(d) -# Q_val = np.diag(np.ones(d))#np.diag(rng.random(d)) - -# with pm.Model() as model: -# y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n) - -# mu_param = pm.MvNormal("mu_param", mu=3*np.ones(d), cov=np.diag(np.ones(d))) -# cov_param = pm.MvNormal("cov_param", mu=np.zeros(d**2), cov=np.diag(np.ones(d**2))) -# Q = pm.MvNormal("Q", mu=np.zeros(d**2), cov=np.diag(np.ones(d**2))) - -# x = pm.MvNormal("x", mu=mu_param, cov=np.linalg.inv(Q_val)) - -# y = pm.MvNormal( -# "y", -# mu=x, -# cov=cov_param.reshape((d,d)), -# observed=y_obs, -# ) - -# # logp(x | y, params) -# cga = get_conditional_gaussian_approximation( -# x=model.rvs_to_values[x], -# Q=Q.reshape((d,d)), -# mu=mu_param, -# # args=[model.rvs_to_values[x], Q, model.rvs_to_values[mu_param], model.rvs_to_values[cov_param]], -# optimizer_kwargs={"tol": 1e-9} -# ) - -# x = 4*np.array([1,1,1]) -# sigma_inv = np.linalg.inv(np.diag(np.ones(d))) -# x0 = np.linalg.inv(n*sigma_inv - Q_val) @ (sigma_inv@y_obs.sum(axis=0) - Q_val@x) -# print(x0) - -# x = 4*np.array([1,1,1]) +def test_get_conditional_gaussian_approximation(): + rng = np.random.default_rng(42) + n = 100 + d = 3 + mu_true = rng.random(d) + cov_true = np.diag(rng.random(d)) + Q_val = np.diag(rng.random(d)) -# res = cga(x=x, mu_param=x, cov_param=np.diag(np.ones(d)).flatten(), Q=Q_val.flatten()) -# print(res) + sigma_params = rng.random(d**2).reshape((d, d)) + x_val = rng.random(d) + mu_val = rng.random(d) + with pm.Model() as model: + y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n) -def test_find_mode_jac_hess(): - rng = np.random.default_rng(42) - n = 100 - sigma_obs = rng.random() - sigma_mu = rng.random() - true_mu = 
rng.random() - mu_val = rng.random() + mu_param = pm.MvNormal("mu_param", mu=np.zeros(d), cov=np.diag(np.ones(d))) + cov_param = pm.MvNormal("cov_param", mu=np.zeros(d**2), cov=np.diag(np.ones(d**2))) + Q = pm.MvNormal("Q", mu=np.zeros(d**2), cov=np.diag(np.ones(d**2))) - coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(n)} - with pm.Model(coords=coords) as model: - obs_val = rng.normal(loc=true_mu, scale=1.5, size=(n, 3)) + x = pm.MvNormal("x", mu=mu_param, cov=np.linalg.inv(Q_val)) - mu = pm.Normal("mu", mu=mu_val, sigma=sigma_mu, dims=["city"]) - obs = pm.Normal( - "obs", - mu=mu, - sigma=sigma_obs, - observed=obs_val, - dims=["obs_idx", "city"], + y = pm.MvNormal( + "y", + mu=x, + cov=cov_param.reshape((d, d)), + observed=y_obs, ) - get_mode_and_hessian = find_mode_jac_hess( - use_hess=False, x=model.rvs_to_values[mu], method="BFGS", optimizer_kwargs={"tol": 1e-8} + # logp(x | y, params) + cga = get_conditional_gaussian_approximation( + x=model.rvs_to_values[x], + Q=Q.reshape((d, d)), + mu=mu_param, + optimizer_kwargs={"tol": 1e-25}, ) - mode, jac, hess = get_mode_and_hessian(mu=[1, 1, 1]) + x0, log_x_posterior = cga( + x=x_val, mu_param=mu_val, cov_param=sigma_params.flatten(), Q=Q_val.flatten() + ) - true_mode = obs_val.mean(axis=0) - true_hess = -np.diag((1 / sigma_mu**2 + n / sigma_obs**2) * np.ones(3)) + sigma_inv = np.linalg.inv(sigma_params) - np.testing.assert_allclose(mode, true_mode, atol=0.1, rtol=0.1) - np.testing.assert_allclose(hess, true_hess, atol=0.1, rtol=0.1) + x0_true = np.linalg.inv(n * sigma_inv - 2 * Q_val) @ ( + sigma_inv @ y_obs.sum(axis=0) - 2 * Q_val @ mu_val + ) + + log_x_posterior_true = ( + -0.5 * x_val.T @ (-n * sigma_inv + Q_val) @ x_val + + x_val.T + @ (Q_val @ mu_val - sigma_inv @ (y_obs - x0_true).sum(axis=0) - n * sigma_inv @ x0_true) + + 0.5 * np.log(np.linalg.det(Q_val)) + + -0.5 * sigma_params.flatten().T @ np.diag(np.ones(d**2)) @ sigma_params.flatten() + - ((d**2) / 2) * np.log(2 * np.pi) + - 0.5 * np.log(np.linalg.det(np.diag(np.ones(d**2)))) + + -0.5 * mu_val.T @ np.diag(np.ones(d)) @ mu_val + - (d / 2) * np.log(2 * np.pi) + - 0.5 * np.log(np.linalg.det(np.diag(np.ones(d)))) + + -0.5 * (x_val - mu_val).T @ Q_val @ (x_val - mu_val) + - (d / 2) * np.log(2 * np.pi) + - 0.5 * np.log(np.linalg.det(np.diag(np.ones(d)))) + ) + np.testing.assert_allclose(x0, x0_true, atol=0.1, rtol=0.1) + np.testing.assert_allclose(log_x_posterior, log_x_posterior_true, atol=0.1, rtol=0.1) From cf23feba46f79388d106de43ddb6f9550918e40f Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Tue, 24 Jun 2025 20:59:52 +1000 Subject: [PATCH 11/16] add nontrivial values in test case --- tests/test_laplace.py | 44 +++++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/tests/test_laplace.py b/tests/test_laplace.py index dc7fb78d7..1c8e06e6e 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -286,20 +286,28 @@ def test_get_conditional_gaussian_approximation(): rng = np.random.default_rng(42) n = 100 d = 3 + mu_true = rng.random(d) cov_true = np.diag(rng.random(d)) Q_val = np.diag(rng.random(d)) + cov_param_val = rng.random(d**2).reshape((d, d)) - sigma_params = rng.random(d**2).reshape((d, d)) x_val = rng.random(d) mu_val = rng.random(d) + mu_mu = rng.random(d) + mu_cov = np.diag(rng.random(d)) + cov_mu = rng.random(d**2) + cov_cov = np.diag(rng.random(d**2)) + Q_mu = rng.random(d**2) + Q_cov = np.diag(rng.random(d**2)) + with pm.Model() as model: y_obs = rng.multivariate_normal(mean=mu_true, 
cov=cov_true, size=n) - mu_param = pm.MvNormal("mu_param", mu=np.zeros(d), cov=np.diag(np.ones(d))) - cov_param = pm.MvNormal("cov_param", mu=np.zeros(d**2), cov=np.diag(np.ones(d**2))) - Q = pm.MvNormal("Q", mu=np.zeros(d**2), cov=np.diag(np.ones(d**2))) + mu_param = pm.MvNormal("mu_param", mu=mu_mu, cov=mu_cov) + cov_param = pm.MvNormal("cov_param", mu=cov_mu, cov=cov_cov) + Q = pm.MvNormal("Q", mu=Q_mu, cov=Q_cov) x = pm.MvNormal("x", mu=mu_param, cov=np.linalg.inv(Q_val)) @@ -319,29 +327,33 @@ def test_get_conditional_gaussian_approximation(): ) x0, log_x_posterior = cga( - x=x_val, mu_param=mu_val, cov_param=sigma_params.flatten(), Q=Q_val.flatten() + x=x_val, mu_param=mu_val, cov_param=cov_param_val.flatten(), Q=Q_val.flatten() ) - sigma_inv = np.linalg.inv(sigma_params) + cov_param_inv = np.linalg.inv(cov_param_val) - x0_true = np.linalg.inv(n * sigma_inv - 2 * Q_val) @ ( - sigma_inv @ y_obs.sum(axis=0) - 2 * Q_val @ mu_val + x0_true = np.linalg.inv(n * cov_param_inv - 2 * Q_val) @ ( + cov_param_inv @ y_obs.sum(axis=0) - 2 * Q_val @ mu_val ) log_x_posterior_true = ( - -0.5 * x_val.T @ (-n * sigma_inv + Q_val) @ x_val + -0.5 * x_val.T @ (-n * cov_param_inv + Q_val) @ x_val + x_val.T - @ (Q_val @ mu_val - sigma_inv @ (y_obs - x0_true).sum(axis=0) - n * sigma_inv @ x0_true) + @ ( + Q_val @ mu_val + - cov_param_inv @ (y_obs - x0_true).sum(axis=0) + - n * cov_param_inv @ x0_true + ) + 0.5 * np.log(np.linalg.det(Q_val)) - + -0.5 * sigma_params.flatten().T @ np.diag(np.ones(d**2)) @ sigma_params.flatten() + + -0.5 * cov_param_val.flatten().T @ np.linalg.inv(cov_cov) @ cov_param_val.flatten() - ((d**2) / 2) * np.log(2 * np.pi) - - 0.5 * np.log(np.linalg.det(np.diag(np.ones(d**2)))) - + -0.5 * mu_val.T @ np.diag(np.ones(d)) @ mu_val + - 0.5 * np.log(np.linalg.det(cov_cov)) + + -0.5 * mu_val.T @ np.linalg.inv(mu_cov) @ mu_val - (d / 2) * np.log(2 * np.pi) - - 0.5 * np.log(np.linalg.det(np.diag(np.ones(d)))) - + -0.5 * (x_val - mu_val).T @ Q_val @ (x_val - mu_val) + - 0.5 * np.log(np.linalg.det(mu_cov)) + + -0.5 * (x_val - mu_val).T @ np.linalg.inv(Q_val) @ (x_val - mu_val) - (d / 2) * np.log(2 * np.pi) - - 0.5 * np.log(np.linalg.det(np.diag(np.ones(d)))) + - 0.5 * np.log(np.linalg.det(Q_val)) ) np.testing.assert_allclose(x0, x0_true, atol=0.1, rtol=0.1) np.testing.assert_allclose(log_x_posterior, log_x_posterior_true, atol=0.1, rtol=0.1) From 3af3fd511d989c8e68ac5c3be179f238a591bb3a Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Thu, 26 Jun 2025 22:03:43 +1000 Subject: [PATCH 12/16] Updated docstrings + comments --- pymc_extras/inference/laplace.py | 35 ++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index af62cbd18..1800f24ca 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -30,6 +30,7 @@ from arviz import dict_to_dataset from better_optimize.constants import minimize_method +from numpy.typing import ArrayLike from pymc import DictToArrayBijection from pymc.backends.arviz import ( coords_and_dims_for_inferencedata, @@ -56,9 +57,9 @@ def get_conditional_gaussian_approximation( - x: TensorVariable, # Should be vector specifically - Q: TensorVariable, # Matrix # TODO tensorinv doesn't have grad implemented yet - mu: TensorVariable, # Vector + x: TensorVariable, + Q: TensorVariable | ArrayLike, + mu: TensorVariable | ArrayLike, args: list[TensorVariable] | None = None, model: pm.Model | None = None, method: minimize_method = "BFGS", 
@@ -67,27 +68,33 @@ def get_conditional_gaussian_approximation( optimizer_kwargs: dict | None = None, ) -> Callable: """ - Returns a function to estimate the mode and both the first and second derivatives of a model at that point by minimizing negative log likelihood. Wrapper for (pytensor-native) scipy.optimize.minimize. + Returns a function to estimate log(p(x | y, params)) and its mode x0 using the Laplace approximation. Parameters ---------- x: TensorVariable - The parameter with which to minimize wrt (that is, find the mode in x). + The parameter with which to maximize wrt (that is, find the mode in x). In INLA, this is the latent field x~N(mu,Q^-1). + Q: TensorVariable | ArrayLike + The precision matrix of the latent field x. + mu: TensorVariable | ArrayLike + The mean of the latent field x. + args: list[TensorVariable] + Args to supply to the compiled function. That is, (x0, logp) = f(x, *args). If set to None, assumes the model RVs are args. model: Model PyMC model to use. method: minimize_method Which minimization algorithm to use. use_jac: bool - If true, the minimizer will compute and store the Jacobian. + If true, the minimizer will compute the gradient of log(p(x | y, params)). use_hess: bool - If true, the minimizer will compute and store the Hessian (note that the Hessian will be computed explicitely even if this is False). + If true, the minimizer will compute the Hessian log(p(x | y, params)). optimizer_kwargs: dict Kwargs to pass to scipy.optimize.minimize. Returns ------- f: Callable - A function which accepts the values of the model RVs as args and returns [mu, jac(mu) hess(mu)], where mu is the mode. The TensorVariable x is specified as an initial guess for mu in args. + A function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. x is currently both the point at which to evaluate logp and the initial guess for the minimizer. """ model = pm.modelcontext(model) @@ -99,6 +106,7 @@ def get_conditional_gaussian_approximation( # Component of log(p(x | y, params)) which depends on x (for rootfinding) conditional_gaussian_approx = -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x) + # Maximize log(p(x | y, params)) wrt x x0, _ = minimize( objective=-conditional_gaussian_approx, x=x, @@ -110,22 +118,19 @@ def get_conditional_gaussian_approximation( # require f'(x0) and f''(x0) for Laplace approx jac = pytensor.graph.replace.graph_replace(jac, {x: x0}) - hess = pytensor.graph.replace.graph_replace( - hess, {x: x0} - ) # Possibly unecessary because jac already does this replace + hess = pytensor.graph.replace.graph_replace(hess, {x: x0}) # Full log(p(x | y, params)) _, logdetQ = pt.nlinalg.slogdet(Q) conditional_gaussian_approx = ( -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ - ) # TODO does doing this change the graph in root before if changed before it's compiled? 
+ ) if args is None: args = model.continuous_value_vars + model.discrete_value_vars - return pytensor.function( - args, [x0, conditional_gaussian_approx] - ) # Currently x being passed in as an initial guess for x0 AND then also going to the true value of x + # TODO Currently x being passed in as an initial guess for x0 AND then also going to the true value of x + return pytensor.function(args, [x0, conditional_gaussian_approx]) def laplace_draws_to_inferencedata( From 1d53b1aeb9b00f6562848f95ecda0d07a68837dc Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Fri, 27 Jun 2025 14:44:59 +1000 Subject: [PATCH 13/16] made test case more rigorous --- pymc_extras/inference/laplace.py | 18 +++++----- tests/test_laplace.py | 59 +++++++++++++++++--------------- 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index 1800f24ca..17a661792 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -98,17 +98,20 @@ def get_conditional_gaussian_approximation( """ model = pm.modelcontext(model) + if args is None: + args = model.continuous_value_vars + model.discrete_value_vars + # f = log(p(y | x, params)) - f = model.logp() - jac = pytensor.gradient.grad(f, x) + f_x = model.logp() + jac = pytensor.gradient.grad(f_x, x) hess = pytensor.gradient.jacobian(jac.flatten(), x) - # Component of log(p(x | y, params)) which depends on x (for rootfinding) - conditional_gaussian_approx = -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x) + # Component of log(p(x | y, params)) which depends on x + log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu) # Maximize log(p(x | y, params)) wrt x x0, _ = minimize( - objective=-conditional_gaussian_approx, + objective=-log_x_posterior, x=x, method=method, jac=use_jac, @@ -120,15 +123,12 @@ def get_conditional_gaussian_approximation( jac = pytensor.graph.replace.graph_replace(jac, {x: x0}) hess = pytensor.graph.replace.graph_replace(hess, {x: x0}) - # Full log(p(x | y, params)) + # Full log(p(x | y, params)) using Laplace approximation (up to a constant) _, logdetQ = pt.nlinalg.slogdet(Q) conditional_gaussian_approx = ( -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ ) - if args is None: - args = model.continuous_value_vars + model.discrete_value_vars - # TODO Currently x being passed in as an initial guess for x0 AND then also going to the true value of x return pytensor.function(args, [x0, conditional_gaussian_approx]) diff --git a/tests/test_laplace.py b/tests/test_laplace.py index 1c8e06e6e..72ff3e937 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -283,24 +283,37 @@ def test_laplace_scalar(): def test_get_conditional_gaussian_approximation(): - rng = np.random.default_rng(42) - n = 100 - d = 3 + """ + Consider the trivial case of: + y | x ~ N(x, cov_param) + x | param ~ N(mu_param, Q^-1) + + cov_param ~ N(cov_mu, cov_cov) + mu_param ~ N(mu_mu, mu_cov) + Q ~ N(Q_mu, Q_cov) + + This has an analytic solution at the mode which we can compare against. 
+ """ + rng = np.random.default_rng(12345) + n = 10000 + d = 10 + + # Initialise arrays mu_true = rng.random(d) cov_true = np.diag(rng.random(d)) Q_val = np.diag(rng.random(d)) - cov_param_val = rng.random(d**2).reshape((d, d)) + cov_param_val = np.diag(rng.random(d)) x_val = rng.random(d) mu_val = rng.random(d) mu_mu = rng.random(d) - mu_cov = np.diag(rng.random(d)) + mu_cov = np.diag(np.ones(d)) cov_mu = rng.random(d**2) - cov_cov = np.diag(rng.random(d**2)) + cov_cov = np.diag(np.ones(d**2)) Q_mu = rng.random(d**2) - Q_cov = np.diag(rng.random(d**2)) + Q_cov = np.diag(np.ones(d**2)) with pm.Model() as model: y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n) @@ -309,6 +322,7 @@ def test_get_conditional_gaussian_approximation(): cov_param = pm.MvNormal("cov_param", mu=cov_mu, cov=cov_cov) Q = pm.MvNormal("Q", mu=Q_mu, cov=Q_cov) + # Pytensor currently doesn't support autograd for pt inverses, so we use a numeric Q instead x = pm.MvNormal("x", mu=mu_param, cov=np.linalg.inv(Q_val)) y = pm.MvNormal( @@ -330,30 +344,21 @@ def test_get_conditional_gaussian_approximation(): x=x_val, mu_param=mu_val, cov_param=cov_param_val.flatten(), Q=Q_val.flatten() ) + # Get analytic values of the mode and Laplace-approximated log posterior cov_param_inv = np.linalg.inv(cov_param_val) - x0_true = np.linalg.inv(n * cov_param_inv - 2 * Q_val) @ ( - cov_param_inv @ y_obs.sum(axis=0) - 2 * Q_val @ mu_val + x0_true = np.linalg.inv(n * cov_param_inv + 2 * Q_val) @ ( + cov_param_inv @ y_obs.sum(axis=0) + 2 * Q_val @ mu_val ) - log_x_posterior_true = ( - -0.5 * x_val.T @ (-n * cov_param_inv + Q_val) @ x_val - + x_val.T - @ ( - Q_val @ mu_val - - cov_param_inv @ (y_obs - x0_true).sum(axis=0) - - n * cov_param_inv @ x0_true - ) + jac_true = cov_param_inv @ (y_obs - x0_true).sum(axis=0) - Q_val @ (x0_true - mu_val) + hess_true = -n * cov_param_inv - Q_val + + log_x_posterior_laplace_true = ( + -0.5 * x_val.T @ (-hess_true + Q_val) @ x_val + + x_val.T @ (Q_val @ mu_val + jac_true - hess_true @ x0_true) + 0.5 * np.log(np.linalg.det(Q_val)) - + -0.5 * cov_param_val.flatten().T @ np.linalg.inv(cov_cov) @ cov_param_val.flatten() - - ((d**2) / 2) * np.log(2 * np.pi) - - 0.5 * np.log(np.linalg.det(cov_cov)) - + -0.5 * mu_val.T @ np.linalg.inv(mu_cov) @ mu_val - - (d / 2) * np.log(2 * np.pi) - - 0.5 * np.log(np.linalg.det(mu_cov)) - + -0.5 * (x_val - mu_val).T @ np.linalg.inv(Q_val) @ (x_val - mu_val) - - (d / 2) * np.log(2 * np.pi) - - 0.5 * np.log(np.linalg.det(Q_val)) ) + np.testing.assert_allclose(x0, x0_true, atol=0.1, rtol=0.1) - np.testing.assert_allclose(log_x_posterior, log_x_posterior_true, atol=0.1, rtol=0.1) + np.testing.assert_allclose(log_x_posterior, log_x_posterior_laplace_true, atol=0.1, rtol=0.1) From 6b70a5fd9f86ce3e5696aed106fd0eca1e10c6ec Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Fri, 27 Jun 2025 16:00:02 +1000 Subject: [PATCH 14/16] refactor: unsqueeze H --- pymc_extras/inference/laplace.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index 17a661792..7e5e52cb5 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -389,6 +389,8 @@ def fit_mvn_at_MAP( ) H = -f_hess(mu.data) + if H.ndim == 1: + H = np.expand_dims(H, axis=1) H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H)) def stabilize(x, jitter): From bfa1a1aec3bc76d18881478ea653bba81df5342f Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Fri, 27 Jun 2025 21:45:53 +1000 Subject: [PATCH 15/16] 
refactor: change comments for clarity --- pymc_extras/inference/laplace.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index 7e5e52cb5..d0b0044ad 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -106,7 +106,7 @@ def get_conditional_gaussian_approximation( jac = pytensor.gradient.grad(f_x, x) hess = pytensor.gradient.jacobian(jac.flatten(), x) - # Component of log(p(x | y, params)) which depends on x + # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x) log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu) # Maximize log(p(x | y, params)) wrt x @@ -129,7 +129,8 @@ def get_conditional_gaussian_approximation( -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ ) - # TODO Currently x being passed in as an initial guess for x0 AND then also going to the true value of x + # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is + # far from the mode x0 or in a neighbourhood which results in poor convergence. return pytensor.function(args, [x0, conditional_gaussian_approx]) From a266a2e930175a1d3405e557af7878d2b114e72d Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Fri, 27 Jun 2025 22:49:27 +1000 Subject: [PATCH 16/16] detailed docstring --- pymc_extras/inference/laplace.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index d0b0044ad..d64d2adab 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -68,7 +68,25 @@ def get_conditional_gaussian_approximation( optimizer_kwargs: dict | None = None, ) -> Callable: """ - Returns a function to estimate log(p(x | y, params)) and its mode x0 using the Laplace approximation. + Returns a function to estimate the a posteriori log probability of a latent Gaussian field x and its mode x0 using the Laplace approximation. + + That is: + y | x, sigma ~ N(Ax, sigma^2 W) + x | params ~ N(mu, Q(params)^-1) + + We seek to estimate log(p(x | y, params)): + + log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const + + Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). + + This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode. + + Thus: + + 1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0. + + 2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q). 
Parameters ---------- @@ -109,7 +127,7 @@ def get_conditional_gaussian_approximation( # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x) log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu) - # Maximize log(p(x | y, params)) wrt x + # Maximize log(p(x | y, params)) wrt x to find mode x0 x0, _ = minimize( objective=-log_x_posterior, x=x, @@ -123,7 +141,7 @@ def get_conditional_gaussian_approximation( jac = pytensor.graph.replace.graph_replace(jac, {x: x0}) hess = pytensor.graph.replace.graph_replace(hess, {x: x0}) - # Full log(p(x | y, params)) using Laplace approximation (up to a constant) + # Full log(p(x | y, params)) using the Laplace approximation (up to a constant) _, logdetQ = pt.nlinalg.slogdet(Q) conditional_gaussian_approx = ( -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
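A minimal, self-contained numerical sketch of the two-step procedure described in the PATCH 16 docstring above, written in plain NumPy/SciPy rather than the PyTensor graph built by get_conditional_gaussian_approximation. It uses the conjugate toy model from test_get_conditional_gaussian_approximation (y_i | x ~ N(x, Sigma), x ~ N(mu, Q^-1)) and takes f(x) = log p(y | x) exactly as defined in the docstring; all names below are illustrative only and are not part of pymc_extras.

import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
n, d = 1000, 3
mu = rng.random(d)                    # prior mean of the latent field x
Q = np.diag(1.0 + rng.random(d))      # prior precision of x
Sigma = np.diag(1.0 + rng.random(d))  # observation covariance
Sigma_inv = np.linalg.inv(Sigma)
y = rng.multivariate_normal(mean=mu, cov=Sigma, size=n)

def f(x):
    # f(x) = log p(y | x) up to an additive constant (Gaussian likelihood)
    r = y - x
    return -0.5 * np.einsum("ni,ij,nj->", r, Sigma_inv, r)

def neg_log_x_posterior(x):
    # Step 1 objective: -(f(x) - 0.5*(x - mu)^T Q (x - mu)); logdet(Q) is constant wrt x
    return -(f(x) - 0.5 * (x - mu) @ Q @ (x - mu))

def neg_log_x_posterior_grad(x):
    # Analytic gradient of the objective, analogous to use_jac=True in the PyTensor version
    return -(Sigma_inv @ (y - x).sum(axis=0) - Q @ (x - mu))

res = minimize(neg_log_x_posterior, np.zeros(d), jac=neg_log_x_posterior_grad, method="BFGS")
x0 = res.x

# For this conjugate model the mode has a closed form:
# x0 = (n*Sigma^-1 + Q)^-1 (Sigma^-1 @ sum_i y_i + Q @ mu)
x0_true = np.linalg.solve(n * Sigma_inv + Q, Sigma_inv @ y.sum(axis=0) + Q @ mu)
np.testing.assert_allclose(x0, x0_true, rtol=1e-5, atol=1e-5)

# Step 2: Laplace form of log(p(x | y, params)) expanded about x0, evaluated at a query point x
jac = Sigma_inv @ (y - x0).sum(axis=0)  # f'(x0)
hess = -n * Sigma_inv                   # f''(x0)
_, logdetQ = np.linalg.slogdet(Q)

def laplace_log_posterior(x):
    return -0.5 * x @ (-hess + Q) @ x + x @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ

print(x0, laplace_log_posterior(rng.random(d)))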
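Two caveats when relating this sketch to the patched implementation. First, in get_conditional_gaussian_approximation f_x = model.logp() is the full joint log-density, so in the model used by test_get_conditional_gaussian_approximation the prior on x enters both through f_x and through the explicit -0.5*(x - mu).T @ Q @ (x - mu) term; that double contribution is why the analytic mode in the test carries a factor of 2*Q_val where the sketch above has Q. Second, as the comment added in PATCH 15 notes, the compiled function uses the supplied value of x both as the query point for the returned log-density and as the initial guess for the minimizer, so query points far from the mode x0 may converge poorly.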