
Commit be1d790

Allow calling find_MAP inside model context without model argument
1 parent ad3abd9 commit be1d790
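
Previously the model had to be passed as the first positional argument; with this change, find_MAP behaves like most PyMC entry points and picks up the active model from the context stack. Below is a minimal sketch of the call pattern this enables; the toy model, method choice, and import path are illustrative assumptions, not taken from the diff.

import numpy as np
import pymc as pm

# Import path assumed to match the test module; adjust if the package layout differs.
from pymc_experimental.inference.jax_find_map import find_MAP

rng = np.random.default_rng(0)

with pm.Model():
    mu = pm.Normal("mu")
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100))

    # No model argument: find_MAP now resolves the model from the enclosing context.
    optimized_point = find_MAP(method="L-BFGS-B", use_grad=True, progressbar=False)

print(optimized_point["mu"], optimized_point["sigma_log__"])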

2 files changed: +27, -25 lines changed

pymc_experimental/inference/jax_find_map.py

Lines changed: 4 additions & 2 deletions
@@ -338,8 +338,9 @@ def f_hess_jax(x):
 
 
 def find_MAP(
-    model: pm.Model,
     method: minimize_method,
+    *,
+    model: pm.Model | None = None,
     use_grad: bool | None = None,
     use_hessp: bool | None = None,
     use_hess: bool | None = None,
@@ -357,7 +358,7 @@ def find_MAP(
     Parameters
     ----------
     model : pm.Model
-        The PyMC model to be fitted.
+        The PyMC model to be fit. If None, the current model context is used.
     method : str
         The optimization method to use. See scipy.optimize.minimize documentation for details.
     use_grad : bool | None, optional
@@ -391,6 +392,7 @@ def find_MAP(
         Dictionary with names of random variables as keys, and optimization results as values. If return_raw is True,
         also returns the object returned by ``scipy.optimize.minimize``.
     """
+    model = pm.modelcontext(model)
     frozen_model = freeze_dims_and_data(model)
 
     if jitter_rvs is None:
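
The resolution itself is delegated to pm.modelcontext, the same helper PyMC uses for pm.sample and friends: it returns the model that was passed in, or else the innermost model on the context stack, and raises if neither is available. The sketch below shows how the new keyword-only signature is expected to behave; the variable names are illustrative, and the exact exception type is from memory of current PyMC rather than from this diff.

import pymc as pm
from pymc_experimental.inference.jax_find_map import find_MAP  # import path assumed

with pm.Model() as m:
    mu = pm.Normal("mu")
    pm.Normal("y", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.5])

# Outside the context the model can still be supplied, but only as a keyword argument now.
point = find_MAP(method="BFGS", model=m, progressbar=False)

with m:
    # Inside the context, pm.modelcontext(None) finds `m`, so the argument can be dropped.
    point = find_MAP(method="BFGS", progressbar=False)

# With no active context and no model= keyword, pm.modelcontext raises an error
# (a TypeError in the PyMC versions I have seen): find_MAP(method="BFGS")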

tests/test_jax_find_map.py

Lines changed: 23 additions & 23 deletions
@@ -65,20 +65,20 @@ def compute_z(x):
     ],
 )
 def test_JAX_map(method, use_grad, use_hess, rng):
-    with pm.Model() as m:
-        mu = pm.Normal("mu")
-        sigma = pm.Exponential("sigma", 1)
-        pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100))
-
     extra_kwargs = {}
     if method == "dogleg":
         # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
         # where this is true
         extra_kwargs = {"initvals": {"mu": 2, "sigma_log__": 1}}
 
-    optimized_point = find_MAP(
-        m, method, **extra_kwargs, use_grad=use_grad, use_hess=use_hess, progressbar=False
-    )
+    with pm.Model() as m:
+        mu = pm.Normal("mu")
+        sigma = pm.Exponential("sigma", 1)
+        pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100))
+
+        optimized_point = find_MAP(
+            method=method, **extra_kwargs, use_grad=use_grad, use_hess=use_hess, progressbar=False
+        )
 
     mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"]
 
     assert np.isclose(mu_hat, 3, atol=0.5)
@@ -102,12 +102,12 @@ def test_fit_laplace_coords(rng, transform_samples):
             observed=rng.normal(loc=3, scale=1.5, size=(100, 3)),
             dims=["obs_idx", "city"],
         )
-        optimized_point = find_MAP(
-            model,
-            "Newton-CG",
-            use_grad=True,
-            progressbar=False,
-        )
+
+        optimized_point = find_MAP(
+            method="Newton-CG",
+            use_grad=True,
+            progressbar=False,
+        )
 
     for value in optimized_point.values():
         assert value.shape == (3,)
@@ -145,9 +145,9 @@ def test_fit_laplace_ragged_coords(rng):
             dims=["obs_idx", "city"],
         )
 
-        optimized_point, _ = find_MAP(
-            ragged_dim_model, "Newton-CG", use_grad=True, progressbar=False, return_raw=True
-        )
+        optimized_point, _ = find_MAP(
+            method="Newton-CG", use_grad=True, progressbar=False, return_raw=True
+        )
 
     idata = fit_laplace(optimized_point, ragged_dim_model, progressbar=False)

@@ -176,12 +176,12 @@ def test_fit_laplace(transform_samples):
             observed=np.random.default_rng().normal(loc=3, scale=1.5, size=(10000,)),
         )
 
-        optimized_point = find_MAP(
-            simp_model,
-            "Newton-CG",
-            use_grad=True,
-            progressbar=False,
-        )
+        optimized_point = find_MAP(
+            method="Newton-CG",
+            use_grad=True,
+            progressbar=False,
+        )
+
     idata = fit_laplace(
         optimized_point, simp_model, transform_samples=transform_samples, progressbar=False
     )
