Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions pymc_extras/inference/laplace_approx/find_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,13 +335,20 @@ def find_MAP(
var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
}

idata = map_results_to_inference_data(optimized_point, frozen_model, include_transformed)
idata = add_fit_to_inference_data(idata, raveled_optimized, H_inv)
idata = map_results_to_inference_data(
map_point=optimized_point, model=frozen_model, include_transformed=include_transformed
)

idata = add_fit_to_inference_data(
idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
)

idata = add_optimizer_result_to_inference_data(
idata, optimizer_result, method, raveled_optimized, model
idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model
)

idata = add_data_to_inference_data(
idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
)

return idata
16 changes: 16 additions & 0 deletions tests/inference/laplace_approx/test_find_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,22 @@ def test_find_MAP(
assert not hasattr(idata, "unconstrained_posterior")


def test_find_map_outside_model_context():
"""
Test that find_MAP can be called outside of a model context.
"""
with pm.Model() as m:
mu = pm.Normal("mu", 0, 1)
sigma = pm.Exponential("sigma", 1)
y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=np.random.normal(size=10))

idata = find_MAP(model=m, method="L-BFGS-B", use_grad=True, progressbar=False)

assert hasattr(idata, "posterior")
assert hasattr(idata, "fit")
assert hasattr(idata, "optimizer_result")
Comment on lines +199 to +201

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it relevant / worth checking the presence of any data group? I see above a call like the following within find_MAP

    idata = add_data_to_inference_data(
        idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
    )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

crap yes it is, but I just clicked merge

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha, well another thing I just saw:

Is it necessary to compute deterministics within add_data_to_inference_data? See

if model.deterministics:
expand_dims = {}
if "chain" not in idata.posterior.coords:
expand_dims["chain"] = [0]
if "draw" not in idata.posterior.coords:
expand_dims["draw"] = [0]
idata.posterior = pm.compute_deterministics(
idata.posterior.expand_dims(expand_dims),
model=model,
merge_dataset=True,
progressbar=progressbar,
compile_kwargs=compile_kwargs,
)

I guess that function is not exposed to the user, but I just wanted to raise that potentially silent side effect.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not necessary, but that's what pm.sample does right?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh I see now why it's included



@pytest.mark.parametrize(
"backend, gradient_backend",
[("jax", "jax")],
Expand Down
22 changes: 22 additions & 0 deletions tests/inference/laplace_approx/test_laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,28 @@ def test_fit_laplace_basic(mode, gradient_backend: GradientBackend):
np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, rtol=1e-3, atol=1e-3)


def test_fit_laplace_outside_model_context():
with pm.Model() as m:
mu = pm.Normal("mu", 0, 1)
sigma = pm.Exponential("sigma", 1)
y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=np.random.normal(size=10))

idata = fit_laplace(
model=m,
optimize_method="L-BFGS-B",
use_grad=True,
progressbar=False,
chains=1,
draws=100,
)

assert hasattr(idata, "posterior")
assert hasattr(idata, "fit")
assert hasattr(idata, "optimizer_result")

assert idata.posterior["mu"].shape == (1, 100)


@pytest.mark.parametrize(
"include_transformed", [True, False], ids=["include_transformed", "no_transformed"]
)
Expand Down
Loading