Skip to content

Commit 03a5e36

Browse files
jessegrabowski and andreacate
authored and committed
Update and Refactor find_MAP and fit_laplace (pymc-devs#531)
* Move laplace and find_map to submodule
* Split idata utilities into `idata.py`
* Refactor find_MAP
* Refactor fit_laplace
* Update better-optimize version pin
* Handle labeling of non-scalar RVs without dims
* Add unconstrained posterior draws/points to unconstrained_posterior
1 parent d93c64d commit 03a5e36

File tree

10 files changed

+21
-1836
lines changed

10 files changed

+21
-1836
lines changed

pymc_extras/inference/find_map.py

Lines changed: 0 additions & 496 deletions
This file was deleted.

pymc_extras/inference/laplace.py

Lines changed: 0 additions & 685 deletions
This file was deleted.

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def find_MAP(
326326
)
327327

328328
raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
329-
unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
329+
unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
330330
unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")(
331331
DictToArrayBijection.rmap(raveled_optimized)
332332
)
@@ -335,20 +335,13 @@ def find_MAP(
335335
var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
336336
}
337337

338-
idata = map_results_to_inference_data(
339-
map_point=optimized_point, model=frozen_model, include_transformed=include_transformed
340-
)
341-
342-
idata = add_fit_to_inference_data(
343-
idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
344-
)
345-
338+
idata = map_results_to_inference_data(optimized_point, frozen_model)
339+
idata = add_fit_to_inference_data(idata, raveled_optimized, H_inv)
346340
idata = add_optimizer_result_to_inference_data(
347-
idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model
341+
idata, optimizer_result, method, raveled_optimized, model
348342
)
349-
350343
idata = add_data_to_inference_data(
351-
idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
344+
idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
352345
)
353346

354347
return idata

pymc_extras/inference/laplace_approx/idata.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]
5959
def map_results_to_inference_data(
6060
map_point: dict[str, float | int | np.ndarray],
6161
model: pm.Model | None = None,
62-
include_transformed: bool = True,
6362
):
6463
"""
6564
Add the MAP point to an InferenceData object in the posterior group.
@@ -69,13 +68,13 @@ def map_results_to_inference_data(
6968
7069
Parameters
7170
----------
71+
idata: az.InferenceData
72+
An InferenceData object to which the MAP point will be added.
7273
map_point: dict
7374
A dictionary containing the MAP point estimates for each variable. The keys should be the variable names, and
7475
the values should be the corresponding MAP estimates.
7576
model: Model, optional
7677
A PyMC model. If None, the model is taken from the current model context.
77-
include_transformed: bool
78-
Whether to return transformed (unconstrained) variables in the constrained_posterior group. Default is True.
7978
8079
Returns
8180
-------
@@ -119,7 +118,7 @@ def map_results_to_inference_data(
119118
dims=dims,
120119
)
121120

122-
if unconstrained_names and include_transformed:
121+
if unconstrained_names:
123122
unconstrained_posterior = az.from_dict(
124123
posterior={
125124
k: np.expand_dims(v, (0, 1))

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def fit_laplace(
302302
----------
303303
model : pm.Model
304304
The PyMC model to be fit. If None, the current model context is used.
305-
optimize_method : str
305+
method : str
306306
The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
307307
trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
308308
@@ -441,11 +441,9 @@ def fit_laplace(
441441
.rename({"temp_chain": "chain", "temp_draw": "draw"})
442442
)
443443

444-
if include_transformed:
445-
idata.unconstrained_posterior = unstack_laplace_draws(
446-
new_posterior.laplace_approximation.values, model, chains=chains, draws=draws
447-
)
448-
444+
idata.unconstrained_posterior = unstack_laplace_draws(
445+
new_posterior.laplace_approximation.values, model, chains=chains, draws=draws
446+
)
449447
idata.posterior = new_posterior.drop_vars(
450448
["laplace_approximation", "unpacked_variable_names"]
451449
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ dynamic = ["version"] # specify the version in the __init__.py file
3636
dependencies = [
3737
"pymc>=5.21.1",
3838
"scikit-learn",
39-
"better-optimize>=0.1.2",
39+
"better-optimize>=0.1.4",
4040
"pydantic>=2.0.0",
4141
]
4242

tests/inference/laplace_approx/test_find_map.py

Lines changed: 7 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -133,19 +133,12 @@ def compute_z(x):
133133
],
134134
)
135135
@pytest.mark.parametrize(
136-
"backend, gradient_backend, include_transformed",
137-
[("jax", "jax", True), ("jax", "pytensor", False)],
136+
"backend, gradient_backend",
137+
[("jax", "jax"), ("jax", "pytensor")],
138138
ids=str,
139139
)
140140
def test_find_MAP(
141-
method,
142-
use_grad,
143-
use_hess,
144-
use_hessp,
145-
backend,
146-
gradient_backend: GradientBackend,
147-
include_transformed,
148-
rng,
141+
method, use_grad, use_hess, use_hessp, backend, gradient_backend: GradientBackend, rng
149142
):
150143
pytest.importorskip("jax")
151144

@@ -161,12 +154,12 @@ def test_find_MAP(
161154
use_hessp=use_hessp,
162155
progressbar=False,
163156
gradient_backend=gradient_backend,
164-
include_transformed=include_transformed,
165157
compile_kwargs={"mode": backend.upper()},
166158
maxiter=5,
167159
)
168160

169161
assert hasattr(idata, "posterior")
162+
assert hasattr(idata, "unconstrained_posterior")
170163
assert hasattr(idata, "fit")
171164
assert hasattr(idata, "optimizer_result")
172165
assert hasattr(idata, "observed_data")
@@ -176,29 +169,9 @@ def test_find_MAP(
176169
assert posterior["mu"].shape == ()
177170
assert posterior["sigma"].shape == ()
178171

179-
if include_transformed:
180-
assert hasattr(idata, "unconstrained_posterior")
181-
unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"])
182-
assert "sigma_log__" in unconstrained_posterior
183-
assert unconstrained_posterior["sigma_log__"].shape == ()
184-
else:
185-
assert not hasattr(idata, "unconstrained_posterior")
186-
187-
188-
def test_find_map_outside_model_context():
189-
"""
190-
Test that find_MAP can be called outside of a model context.
191-
"""
192-
with pm.Model() as m:
193-
mu = pm.Normal("mu", 0, 1)
194-
sigma = pm.Exponential("sigma", 1)
195-
y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=np.random.normal(size=10))
196-
197-
idata = find_MAP(model=m, method="L-BFGS-B", use_grad=True, progressbar=False)
198-
199-
assert hasattr(idata, "posterior")
200-
assert hasattr(idata, "fit")
201-
assert hasattr(idata, "optimizer_result")
172+
unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"])
173+
assert "sigma_log__" in unconstrained_posterior
174+
assert unconstrained_posterior["sigma_log__"].shape == ()
202175

203176

204177
@pytest.mark.parametrize(

tests/inference/laplace_approx/test_laplace.py

Lines changed: 1 addition & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -83,32 +83,7 @@ def test_fit_laplace_basic(mode, gradient_backend: GradientBackend):
8383
np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, rtol=1e-3, atol=1e-3)
8484

8585

86-
def test_fit_laplace_outside_model_context():
87-
with pm.Model() as m:
88-
mu = pm.Normal("mu", 0, 1)
89-
sigma = pm.Exponential("sigma", 1)
90-
y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=np.random.normal(size=10))
91-
92-
idata = fit_laplace(
93-
model=m,
94-
optimize_method="L-BFGS-B",
95-
use_grad=True,
96-
progressbar=False,
97-
chains=1,
98-
draws=100,
99-
)
100-
101-
assert hasattr(idata, "posterior")
102-
assert hasattr(idata, "fit")
103-
assert hasattr(idata, "optimizer_result")
104-
105-
assert idata.posterior["mu"].shape == (1, 100)
106-
107-
108-
@pytest.mark.parametrize(
109-
"include_transformed", [True, False], ids=["include_transformed", "no_transformed"]
110-
)
111-
def test_fit_laplace_coords(include_transformed, rng):
86+
def test_fit_laplace_coords(rng):
11287
coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)}
11388
with pm.Model(coords=coords) as model:
11489
mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"])
@@ -127,7 +102,6 @@ def test_fit_laplace_coords(include_transformed, rng):
127102
chains=1,
128103
draws=1000,
129104
optimizer_kwargs=dict(tol=1e-20),
130-
include_transformed=include_transformed,
131105
)
132106

133107
np.testing.assert_allclose(
@@ -146,11 +120,6 @@ def test_fit_laplace_coords(include_transformed, rng):
146120
"sigma_log__[C]",
147121
]
148122

149-
assert hasattr(idata, "unconstrained_posterior") == include_transformed
150-
if include_transformed:
151-
assert "sigma_log__" in idata.unconstrained_posterior
152-
assert "city" in idata.unconstrained_posterior.coords
153-
154123

155124
def test_fit_laplace_ragged_coords(rng):
156125
coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
@@ -230,50 +199,6 @@ def test_model_with_nonstandard_dimensionality(rng):
230199
assert "class" in list(idata.unconstrained_posterior.sigma_log__.coords.keys())
231200

232201

233-
def test_laplace_nonstandard_dims_2d():
234-
true_P = np.array([[0.5, 0.3, 0.2], [0.1, 0.6, 0.3], [0.2, 0.4, 0.4]])
235-
y_obs = pm.draw(
236-
pmx.DiscreteMarkovChain.dist(
237-
P=true_P,
238-
init_dist=pm.Categorical.dist(
239-
logit_p=np.ones(
240-
3,
241-
)
242-
),
243-
shape=(100, 5),
244-
)
245-
)
246-
247-
with pm.Model(
248-
coords={
249-
"time": range(y_obs.shape[0]),
250-
"state": list("ABC"),
251-
"next_state": list("ABC"),
252-
"unit": [1, 2, 3, 4, 5],
253-
}
254-
) as model:
255-
y = pm.Data("y", y_obs, dims=["time", "unit"])
256-
init_dist = pm.Categorical.dist(
257-
logit_p=np.ones(
258-
3,
259-
)
260-
)
261-
P = pm.Dirichlet("P", a=np.eye(3) * 2 + 1, dims=["state", "next_state"])
262-
y_hat = pmx.DiscreteMarkovChain(
263-
"y_hat", P=P, init_dist=init_dist, dims=["time", "unit"], observed=y_obs
264-
)
265-
266-
idata = pmx.fit_laplace(progressbar=True)
267-
268-
# The simplex transform should drop from the right-most dimension, so the left dimension should be unmodified
269-
assert "state" in list(idata.unconstrained_posterior.P_simplex__.coords.keys())
270-
271-
# The mutated dimension should be unknown coords
272-
assert "P_simplex___dim_1" in list(idata.unconstrained_posterior.P_simplex__.coords.keys())
273-
274-
assert idata.unconstrained_posterior.P_simplex__.shape[-2:] == (3, 2)
275-
276-
277202
def test_laplace_nonscalar_rv_without_dims():
278203
with pm.Model(coords={"test": ["A", "B", "C"]}) as model:
279204
x_loc = pm.Normal("x_loc", mu=0, sigma=1, dims=["test"])

0 commit comments

Comments (0)