Commit a1292ea

WIP: Should find root of conditional_gaussian_approx not minimize nll
1 parent 40f27e0 commit a1292ea

2 files changed (+133, -72 lines)

pymc_extras/inference/laplace.py

Lines changed: 124 additions & 65 deletions
@@ -29,7 +29,7 @@
 import xarray as xr
 
 from arviz import dict_to_dataset
-from better_optimize.constants import minimize_method
+from better_optimize.constants import minimize_method, root_method
 from pymc import DictToArrayBijection
 from pymc.backends.arviz import (
     coords_and_dims_for_inferencedata,
@@ -41,7 +41,7 @@
 from pymc.model.transform.optimization import freeze_dims_and_data
 from pymc.util import get_default_varnames
 from pytensor.tensor import TensorVariable
-from pytensor.tensor.optimize import minimize
+from pytensor.tensor.optimize import root
 from scipy import stats
 
 from pymc_extras.inference.find_map import (
@@ -55,6 +55,128 @@
 _log = logging.getLogger(__name__)
 
 
+def find_mode_jac_hess(
+    x: TensorVariable,  # Should be a vector specifically
+    Q: TensorVariable,  # Matrix  # TODO: tensorinv doesn't have grad implemented yet
+    mu: TensorVariable,  # Vector
+    model: pm.Model | None = None,
+    method: root_method = "hybr",
+    use_jac: bool = True,
+    # use_hess: bool = False,
+    optimizer_kwargs: dict | None = None,
+) -> Callable:
+    """
+    Returns a function that estimates the mode of a model, together with the first and second
+    derivatives of its log-probability at that point, by root-finding on the conditional
+    Gaussian approximation. Wrapper for the (pytensor-native) scipy.optimize.root.
+
+    Parameters
+    ----------
+    x: TensorVariable
+        The parameter with respect to which to find the mode (that is, the mode in x).
+    Q: TensorVariable
+        Precision matrix of the Gaussian prior on x.
+    mu: TensorVariable
+        Mean of the Gaussian prior on x.
+    model: Model
+        PyMC model to use.
+    method: root_method
+        Which root-finding algorithm to use.
+    use_jac: bool
+        If true, the root finder will compute and store the Jacobian.
+    use_hess: bool
+        If true, the optimizer will compute and store the Hessian (note that the Hessian will
+        be computed explicitly even if this is False).
+    optimizer_kwargs: dict
+        Kwargs to pass to scipy.optimize.root.
+
+    Returns
+    -------
+    f: Callable
+        A function which accepts the values of the model RVs as args and returns
+        [mu, jac(mu), hess(mu)], where mu is the mode. The TensorVariable x is specified as an
+        initial guess for mu in args.
+    """
+    model = pm.modelcontext(model)
+
+    # f = log(p(y | x, params))
+    f = model.logp()
+    jac = pytensor.gradient.grad(f, x)
+    hess = pytensor.gradient.jacobian(jac.flatten(), x)
+
+    # Component of log(p(x | y, params)) which depends on x (for root-finding)
+    conditional_gaussian_approx = -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x)
+
+    x0, _ = root(
+        equations=pt.stack([conditional_gaussian_approx]),
+        variables=x,
+        method=method,
+        jac=use_jac,
+        optimizer_kwargs=optimizer_kwargs,
+    )
+
+    # Require f'(x0) and f''(x0) for the Laplace approximation
+    jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
+    hess = pytensor.graph.replace.graph_replace(
+        hess, {x: x0}
+    )  # Possibly unnecessary because jac already does this replace
+
+    # Full log(p(x | y, params))
+    _, logdetQ = pt.nlinalg.slogdet(Q)
+    conditional_gaussian_approx = (
+        -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
+    )  # TODO: does doing this change the graph inside root if it is changed before it's compiled?
+
+    args = model.continuous_value_vars + model.discrete_value_vars
+    return pytensor.function(
+        args, [x0, conditional_gaussian_approx]
+    )  # Currently x is passed in as an initial guess for x0 AND is also taken as the true value of x
+
+    # Minimise negative log likelihood
+    # nll = -model.logp()
+    # soln, _ = minimize(
+    #     objective=nll,
+    #     x=x,
+    #     method=method,
+    #     jac=use_jac,
+    #     hess=use_hess,
+    #     optimizer_kwargs=optimizer_kwargs,
+    # )
+
+    # TODO: Jesse suggested I use this graph_replace function, but it seems that "mode" here is a different type to soln:
+    #
+    # TypeError: Cannot convert Type Vector(float64, shape=(10,)) (of Variable MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0) into Type Scalar(float64, shape=()). You can try to manually convert MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0 into a Scalar(float64, shape=()).
+    #
+    # My understanding here is that for some function which evaluates the hessian at x, we're replacing "x" in the hess graph with the subgraph that computes "x" (i.e. soln)?
+
+    # Obtain the Hessian (re-use graph if already computed in minimize)
+    # if use_hess:
+    #     mode, _, hess = (
+    #         soln.owner.op.inner_outputs
+    #     )  # Note that this mode, _, hess will need to be slightly more elaborate for when use_jac is False (2 items to unpack instead of 3). Just a few if-blocks, but not implemented for now while we're debugging
+    #     hess = pytensor.graph.replace.graph_replace(hess, {mode: soln})
+    # else:
+    #     hess = pytensor.gradient.hessian(nll, x)
+
+    # Obtain the gradient and Hessian (re-use graphs if already computed in minimize)
+    # res = soln.owner.op.inner_outputs
+    # mode = res[0]
+
+    # print(res)
+
+    # if use_jac:
+    #     # jac = pytensor.gradient.grad(nll, x)
+    #     jac = res.pop(1)
+    # else:
+    #     jac = pytensor.gradient.grad(nll, x)
+    #     jac = pytensor.graph.replace.graph_replace(jac, {x: soln})
+
+    # print(x)
+    # # jac = pytensor.graph.replace.graph_replace(jac, {x: soln})
+
+    # jac = -jac  # We subsequently want the gradients wrt log(p(y | x)) rather than the negative of this (nll)
+
+    # if use_hess:
+    #     hess = res.pop(1)
+    # else:
+    #     hess = pytensor.gradient.jacobian(jac.flatten(), soln)
+    # # hess = pytensor.graph.replace.graph_replace(hess, {x: soln})
+
+    # args = model.continuous_value_vars + model.discrete_value_vars
+    # return pytensor.function(args, [soln, jac, hess])
+
+
 def laplace_draws_to_inferencedata(
     posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None
 ) -> az.InferenceData:
@@ -418,69 +540,6 @@ def sample_laplace_posterior(
     return idata
 
 
-def find_mode_and_hess(
-    x: TensorVariable,
-    model: pm.Model | None = None,
-    method: minimize_method = "BFGS",
-    use_jac: bool = True,
-    use_hess: bool = False,  # TODO Tbh we can probably just remove this arg and pass True to the minimizer all the time, but if this is the case, it will throw a warning when the hessian doesn't need to be computed for a particular optimisation routine.
-    optimizer_kwargs: dict | None = None,
-) -> Callable:
-    """
-    Returns a function to estimate the mode and hessian of a model by minimizing negative log likelihood. Wrapper for (pytensor-native) scipy.optimize.minimize.
-
-    Parameters
-    ----------
-    x: TensorVariable
-        The parameter with which to minimize wrt (that is, find the mode in x).
-    model: Model
-        PyMC model to use.
-    method: minimize_method
-        Which minimization algorithm to use.
-    use_jac: bool
-        If true, the minimizer will compute and store the Jacobian.
-    use_hess: bool
-        If true, the minimizer will compute and store the Hessian (note that the Hessian will be computed explicitely even if this is False).
-    optimizer_kwargs: dict
-        Kwargs to pass to scipy.optimize.minimize.
-
-    Returns
-    -------
-    f: Callable
-        A function which accepts the values of the model RVs as args and returns [mu, hess(mu)], where mu is the mode. The TensorVariable x is specified as an initial guess for mu in args.
-    """
-    model = pm.modelcontext(model)
-
-    # Minimise negative log likelihood
-    nll = -model.logp()
-    soln, _ = minimize(
-        objective=nll,
-        x=x,
-        method=method,
-        jac=use_jac,
-        hess=use_hess,
-        optimizer_kwargs=optimizer_kwargs,
-    )
-
-    # TODO: Jesse suggested I use this graph_replace function, but it seems that "mode" here is a different type to soln:
-    #
-    # TypeError: Cannot convert Type Vector(float64, shape=(10,)) (of Variable MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0) into Type Scalar(float64, shape=()). You can try to manually convert MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0 into a Scalar(float64, shape=()).
-    #
-    # My understanding here is that for some function which evaluates the hessian at x, we're replacing "x" in the hess graph with the subgraph that computes "x" (i.e. soln)?
-
-    # Obtain the Hessian (re-use graph if already computed in minimize)
-    if use_hess:
-        mode, _, hess = (
-            soln.owner.op.inner_outputs
-        )  # Note that this mode, _, hess will need to be slightly more elaborate for when use_jac is False (2 items to unpack instead of 3). Just a few if-blocks, but not implemented for now while we're debugging
-        hess = pytensor.graph.replace.graph_replace(hess, {mode: soln})
-    else:
-        hess = pytensor.gradient.hessian(nll, x)
-
-    args = model.continuous_value_vars + model.discrete_value_vars
-    return pytensor.function(args, [soln, hess])
-
-
 def fit_laplace(
     optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
     *,

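For context on the quadratic form that find_mode_jac_hess builds: assuming x carries a Gaussian prior with mean mu and precision Q (an assumption consistent with the Q @ mu and slogdet(Q) terms, though not stated in the commit), the conditional_gaussian_approx expression matches the standard second-order (Laplace / INLA-style) expansion of the conditional log-posterior around a point x_0:

\[
\log p(x \mid y) \;\approx\; -\tfrac{1}{2}\, x^\top (Q - H)\, x \;+\; x^\top \left( Q\mu + g - H x_0 \right) + \text{const},
\qquad g = \nabla f(x_0),\quad H = \nabla^2 f(x_0),\quad f(x) = \log p(y \mid x).
\]

Setting the gradient of this quadratic to zero gives the stationarity condition

\[
-(Q - H)\, x + Q\mu + g - H x_0 = 0
\quad\Longleftrightarrow\quad
\nabla f(x) - Q\,(x - \mu) = 0 \ \text{ at the fixed point } x_0 = x,
\]

which is presumably the root the commit title says should be found, rather than the scalar conditional_gaussian_approx expression currently passed to root.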
tests/test_laplace.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from pymc_extras.inference.find_map import GradientBackend, find_MAP
2323
from pymc_extras.inference.laplace import (
24-
find_mode_and_hess,
24+
find_mode_jac_hess,
2525
fit_laplace,
2626
fit_mvn_at_MAP,
2727
sample_laplace_posterior,
@@ -282,17 +282,19 @@ def test_laplace_scalar():
282282
np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1)
283283

284284

285-
def test_find_mode_and_hess():
285+
def test_find_mode_jac_hess():
286286
rng = np.random.default_rng(42)
287287
n = 100
288288
sigma_obs = rng.random()
289289
sigma_mu = rng.random()
290+
true_mu = rng.random()
291+
mu_val = rng.random()
290292

291293
coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(n)}
292294
with pm.Model(coords=coords) as model:
293-
obs_val = rng.normal(loc=3, scale=1.5, size=(n, 3))
295+
obs_val = rng.normal(loc=true_mu, scale=1.5, size=(n, 3))
294296

295-
mu = pm.Normal("mu", mu=1, sigma=sigma_mu, dims=["city"])
297+
mu = pm.Normal("mu", mu=mu_val, sigma=sigma_mu, dims=["city"])
296298
obs = pm.Normal(
297299
"obs",
298300
mu=mu,
@@ -301,14 +303,14 @@ def test_find_mode_and_hess():
301303
dims=["obs_idx", "city"],
302304
)
303305

304-
get_mode_and_hessian = find_mode_and_hess(
306+
get_mode_and_hessian = find_mode_jac_hess(
305307
use_hess=False, x=model.rvs_to_values[mu], method="BFGS", optimizer_kwargs={"tol": 1e-8}
306308
)
307309

308-
mode, hess = get_mode_and_hessian(**{"mu": [1, 1, 1]})
310+
mode, jac, hess = get_mode_and_hessian(mu=[1, 1, 1])
309311

310312
true_mode = obs_val.mean(axis=0)
311-
true_hess = np.diag((1 / sigma_mu**2 + n / sigma_obs**2) * np.ones(3))
313+
true_hess = -np.diag((1 / sigma_mu**2 + n / sigma_obs**2) * np.ones(3))
312314

313315
np.testing.assert_allclose(mode, true_mode, atol=0.1, rtol=0.1)
314316
np.testing.assert_allclose(hess, true_hess, atol=0.1, rtol=0.1)
