
Commit dac0096

added test case, removed inputs as required arg
1 parent: 23b4970

2 files changed: +213 −22 lines changed

pymc_extras/inference/laplace.py

Lines changed: 55 additions & 22 deletions
@@ -418,43 +418,76 @@ def sample_laplace_posterior(
 
 
 def find_mode(
-    inputs: list[TensorVariable],
-    params: dict,  # TODO Would be nice to automatically map this to inputs somehow: {k.name: ... for k in inputs}
+    x: TensorVariable,
+    args: dict,
+    inputs: list[TensorVariable] | None = None,
     x0: TensorVariable
     | None = None,  # TODO This isn't a TensorVariable, not sure what the general datatype for numeric arraylikes is
-    x: TensorVariable | None = None,
     model: pm.Model | None = None,
     method: minimize_method = "BFGS",
-    jac: bool = True,
-    hess: bool = False,
+    use_jac: bool = True,
+    use_hess: bool = False,
     optimizer_kwargs: dict | None = None,
 ):  # TODO Output type is list of same type as x0
     model = pm.modelcontext(model)
-    if x is None:
-        raise UserWarning(
-            "Latent Gaussian field x unspecified. Assuming it is the first entry in inputs. Specify which input to obtain the mode over using the input x."
-        )
-        x = inputs[0]
 
-    if x0 is None:
-        # Should return a random numpy array of the same shape as x0 - not sure how to get the shape of x0
-        raise NotImplementedError
+    # if x0 is None:
+    #     # TODO Issue with X not being an RV
+    #     print(model.initial_point())
+
+    #     from pymc.initial_point import make_initial_point_fn
+    #     frozen_model = freeze_dims_and_data(model)
+    #     ipfn = make_initial_point_fn(
+    #         model=model,
+    #         jitter_rvs=set(),  # (jitter_rvs),
+    #         return_transformed=True,
+    #         overrides=args,
+    #     )
+
+    #     random_seed = None
+    #     start_dict = ipfn(random_seed)
+    #     vars_dict = {var.name: var for var in frozen_model.continuous_value_vars}
+    #     initial_params = DictToArrayBijection.map(
+    #         {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
+    #     )
+    #     print(initial_params)
 
     # Minimise negative log likelihood
     nll = -model.logp()
     soln, _ = minimize(
-        objective=nll, x=x, method=method, jac=jac, hess=hess, optimizer_kwargs=optimizer_kwargs
+        objective=nll,
+        x=x,
+        method=method,
+        jac=use_jac,
+        hess=use_hess,
+        optimizer_kwargs=optimizer_kwargs,
     )
 
-    get_mode = pytensor.function(inputs, soln)
-    mode = get_mode(x0, **params)
-
-    # Calculate the value of the Hessian at the mode
-    # TODO check if we can't pull this out of the soln graph when jac or hess=True
-    hess_x = pytensor.gradient.hessian(nll, x)
-    hess = pytensor.function(inputs, hess_x)
+    # Get input variables
+    # TODO issue when this is nll
+    if inputs is None:
+        inputs = [
+            pytensor.graph.basic.get_var_by_name(model.basic_RVs[1], target_var_id=var)[0]
+            for var in args
+        ]
+        for i, var in enumerate(inputs):
+            try:
+                inputs[i] = model.rvs_to_values[var]
+            except KeyError:
+                pass
+    inputs.insert(0, x)
+
+    # Obtain the Hessian (re-use graph if already computed in minimize)
+    if use_hess:
+        hess = soln.owner.op.inner_outputs[-1]
+        hess = pytensor.graph.replace.graph_replace(
+            hess, {x: soln}
+        )  # TODO: x here is 'beta', soln is a MinimizeOp. There's no instance of MinimizeOp in the hessian graph
+    else:
+        hess = pytensor.gradient.hessian(nll, x)
 
-    return mode, hess(mode, **params)
+    get_mode_and_hessian = pytensor.function(inputs, [soln, hess])
+    return get_mode_and_hessian(x0, **args)
 
 
 def fit_laplace(

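For call sites written against the old signature, here is a minimal before/after sketch of the change (illustrative only; `beta_val`, `y_val`, `X`, `ynum`, `Xval`, and `x0` are placeholder names, not part of the commit):

    # Before: `inputs` and `params` were required, and omitting `x` raised a UserWarning.
    mode, hess_at_mode = find_mode(
        inputs=[beta_val, y_val, X],
        params={"y": ynum, "X": Xval},
        x0=x0,
        x=beta_val,
    )

    # After: the latent field `x` and the value dictionary `args` are required;
    # `inputs` is optional and is derived from the keys of `args` when omitted.
    # `jac` and `hess` are renamed to `use_jac` and `use_hess`.
    mode, hess_at_mode = find_mode(x=beta_val, x0=x0, args={"y": ynum, "X": Xval})
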
tests/test_laplace.py

Lines changed: 158 additions & 0 deletions
@@ -15,6 +15,7 @@
 
 import numpy as np
 import pymc as pm
+import pytensor.tensor as pt
 import pytest
 
 import pymc_extras as pmx
@@ -279,3 +280,160 @@ def test_laplace_scalar():
     assert idata_laplace.fit.covariance_matrix.shape == (1, 1)
 
     np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1)
+
+
+def test_find_mode():
+    k = 10
+    N = 10000
+    y = pt.vector("y", dtype="int64")
+    X = pt.matrix("X", shape=(N, k))
+
+    # Pre-commit did this. Quite ugly. Should compute hess in code rather than storing a hardcoded array.
+    true_hess = np.array(
+        [
+            [2.50100000e03, -1.78838742e00, 1.59484217e01, -9.78343803e00, 2.86125467e01, -7.38071788e00, -4.97729126e01, 3.53243810e01, 1.69071769e01, -1.30755942e01],
+            [-1.78838742e00, 2.54687995e03, 8.99456512e-02, -1.33603390e01, -2.37641179e01, 4.57780742e01, -1.22640681e01, 2.70879664e01, 4.04435512e01, 2.08826556e00],
+            [1.59484217e01, 8.99456512e-02, 2.46908384e03, -1.80358232e01, 1.14131535e01, 2.21632317e01, 1.25443469e00, 1.50344618e01, -3.59940488e01, -1.05191328e01],
+            [-9.78343803e00, -1.33603390e01, -1.80358232e01, 2.50546496e03, 3.27545028e01, -3.33517501e01, -2.68735672e01, -2.69114305e01, -1.20464337e01, 9.02338622e00],
+            [2.86125467e01, -2.37641179e01, 1.14131535e01, 3.27545028e01, 2.49959736e03, -3.98220135e00, -4.09495199e00, -1.51115257e01, -5.77436126e01, -2.98600447e00],
+            [-7.38071788e00, 4.57780742e01, 2.21632317e01, -3.33517501e01, -3.98220135e00, 2.48169432e03, -1.26885014e01, -3.53524089e01, 5.89656794e00, 1.67164400e01],
+            [-4.97729126e01, -1.22640681e01, 1.25443469e00, -2.68735672e01, -4.09495199e00, -1.26885014e01, 2.47216241e03, 8.16935659e00, -4.89399152e01, -1.11646138e01],
+            [3.53243810e01, 2.70879664e01, 1.50344618e01, -2.69114305e01, -1.51115257e01, -3.53524089e01, 8.16935659e00, 2.52940405e03, 3.07751540e00, -8.60023392e00],
+            [1.69071769e01, 4.04435512e01, -3.59940488e01, -1.20464337e01, -5.77436126e01, 5.89656794e00, -4.89399152e01, 3.07751540e00, 2.49452594e03, 6.06984410e01],
+            [-1.30755942e01, 2.08826556e00, -1.05191328e01, 9.02338622e00, -2.98600447e00, 1.67164400e01, -1.11646138e01, -8.60023392e00, 6.06984410e01, 2.49290175e03],
+        ]
+    )
+
+    with pm.Model() as model:
+        beta = pm.MvNormal("beta", mu=np.zeros(k), cov=np.identity(k), shape=(k,))
+        p = pm.math.invlogit(beta @ X.T)
+        y = pm.Bernoulli("y", p)
+
+    rng = np.random.default_rng(123)
+    Xval = rng.normal(size=(10000, 9))
+    Xval = np.c_[np.ones(10000), Xval]
+
+    true_beta = rng.normal(scale=0.1, size=(10,))
+    true_p = pm.math.invlogit(Xval @ true_beta).eval()
+    ynum = rng.binomial(1, true_p)
+
+    beta_val = model.rvs_to_values[beta]
+    x0 = np.zeros(k)
+    args = {"y": ynum, "X": Xval}
+
+    beta_mode, beta_hess = pmx.inference.laplace.find_mode(
+        x=beta_val, x0=x0, args=args, method="BFGS", optimizer_kwargs={"tol": 1e-8}
+    )
+
+    np.testing.assert_allclose(beta_mode, true_beta, atol=0.1, rtol=0.1)
+    np.testing.assert_allclose(beta_hess, true_hess, atol=0.1, rtol=0.1)
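
The comment inside the new test flags the hardcoded `true_hess` as a stopgap. A minimal sketch of the analytic alternative it hints at (an assumption about intent, not part of the commit): for Bernoulli observations with `p = invlogit(X @ beta)` and a standard-normal prior on `beta`, the Hessian of the negative log posterior is `X.T @ diag(p * (1 - p)) @ X + I`, so the expected matrix could be computed at any `beta` instead of stored:

    import numpy as np

    # Sketch only (hypothetical helper, not in the commit): analytic Hessian of the
    # negative log posterior for logistic regression with a N(0, I) prior.
    def expected_hessian(Xval: np.ndarray, beta: np.ndarray) -> np.ndarray:
        p = 1.0 / (1.0 + np.exp(-Xval @ beta))  # invlogit
        w = p * (1.0 - p)  # per-observation Bernoulli curvature
        return Xval.T @ (Xval * w[:, None]) + np.eye(Xval.shape[1])

As a sanity check: the simulated `true_beta` is small, so the fitted probabilities sit near 0.5 and the diagonal of this matrix is roughly N / 4 + 1 = 2501, which matches the hardcoded values above.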
