
Commit b8d4048

Allow basinhopping when fitting DADVI
1 parent: b8164ff
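In short, the change lets users request SciPy-style basin-hopping global optimization through the existing fit_dadvi entry point. A hypothetical usage sketch (not part of the diff; the import path follows the file touched below, and the model and keyword values mirror the new test case):

import numpy as np
import pymc as pm

from pymc_extras.inference.dadvi.dadvi import fit_dadvi

rng = np.random.default_rng(42)

with pm.Model():
    mu = pm.Normal("mu")
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=10))

    # "basinhopping" selects better_optimize.basinhopping; minimizer_kwargs
    # configures the inner local optimizer, and niter is forwarded to the
    # basin-hopping loop itself.
    idata = fit_dadvi(
        optimizer_method="basinhopping",
        use_grad=True,
        use_hessp=True,
        minimizer_kwargs=dict(method="Newton-CG"),
        niter=1,
        n_draws=100,
    )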

2 files changed (+96, -16 lines)

pymc_extras/inference/dadvi/dadvi.py

Lines changed: 57 additions & 16 deletions
@@ -5,7 +5,7 @@
 import pytensor.tensor as pt
 import xarray
 
-from better_optimize import minimize
+from better_optimize import basinhopping, minimize
 from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
 from pymc.backends.arviz import (
@@ -31,7 +31,6 @@
 def fit_dadvi(
     model: Model | None = None,
     n_fixed_draws: int = 30,
-    random_seed: RandomSeed = None,
     n_draws: int = 1000,
     include_transformed: bool = False,
     optimizer_method: minimize_method = "trust-ncg",
@@ -40,7 +39,9 @@ def fit_dadvi(
     use_hess: bool | None = None,
     gradient_backend: str = "pytensor",
     compile_kwargs: dict | None = None,
-    **minimize_kwargs,
+    random_seed: RandomSeed = None,
+    progressbar: bool = True,
+    **optimizer_kwargs,
 ) -> az.InferenceData:
     """
     Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.
@@ -79,10 +80,6 @@ def fit_dadvi(
     compile_kwargs: dict, optional
         Additional keyword arguments to pass to `pytensor.function`
 
-    minimize_kwargs:
-        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
-        that function for details.
-
     use_grad: bool, optional
         If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).
@@ -93,6 +90,13 @@ def fit_dadvi(
         If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
         computation can be slow and memory-intensive if there are many parameters.
 
+    progressbar: bool
+        Whether or not to show a progress bar during optimization. Default is True.
+
+    optimizer_kwargs:
+        Additional keyword arguments passed to ``scipy.optimize.minimize``, or to ``scipy.optimize.basinhopping``
+        when ``optimizer_method="basinhopping"``. See the SciPy documentation for details.
+
     Returns
     -------
     :class:`~arviz.InferenceData`
@@ -105,6 +109,16 @@
     """
 
     model = pymc.modelcontext(model) if model is None else model
+    do_basinhopping = optimizer_method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = optimizer_method
 
     initial_point_dict = model.initial_point()
     initial_point = DictToArrayBijection.map(initial_point_dict)
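To make the routing above concrete, here is a standalone sketch (hypothetical user values, not library code) of how minimizer_kwargs and the inner-method default interact:

# A user passes basinhopping-level kwargs plus an inner-optimizer config.
optimizer_kwargs = {"minimizer_kwargs": {"method": "Newton-CG"}, "niter": 5}

minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})

# The inner method comes from minimizer_kwargs when given; otherwise L-BFGS-B.
optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
minimizer_kwargs["method"] = optimizer_method

assert minimizer_kwargs == {"method": "Newton-CG"}
assert optimizer_kwargs == {"niter": 5}  # the rest is forwarded to basinhopping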
@@ -145,14 +159,34 @@
     )
 
     dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
-
-    result = minimize(
-        f=f_fused,
-        x0=dadvi_initial_point.data,
-        method=optimizer_method,
-        hessp=f_hessp,
-        **minimize_kwargs,
-    )
+    args = optimizer_kwargs.pop("args", ())
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = optimizer_method
+
+        result = basinhopping(
+            func=f_fused,
+            x0=dadvi_initial_point.data,
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        result = minimize(
+            f=f_fused,
+            x0=dadvi_initial_point.data,
+            args=args,
+            method=optimizer_method,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            **optimizer_kwargs,
+        )
 
     raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
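For context, the new branch treats better_optimize.basinhopping as a drop-in for scipy.optimize.basinhopping (plus a progressbar flag), where minimizer_kwargs configures the local minimizer run at each hop. A minimal SciPy-only sketch, using the multimodal objective from the SciPy docs:

import numpy as np
from scipy.optimize import basinhopping

def f(x):
    # One-dimensional function with several local minima.
    return np.cos(14.5 * x[0] - 0.3) + (x[0] + 0.2) * x[0]

result = basinhopping(
    f,
    x0=np.array([1.0]),
    niter=50,
    minimizer_kwargs={"method": "L-BFGS-B"},
)
print(result.x, result.fun)  # typically the global minimum near x = -0.195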

@@ -166,7 +200,9 @@
     draws = opt_means + draws_raw * np.exp(opt_log_sds)
     draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)
 
-    idata = dadvi_result_to_idata(draws_arviz, model, include_transformed=include_transformed)
+    idata = dadvi_result_to_idata(
+        draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
+    )
 
     var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
     var_name_to_model_var.update(
@@ -253,6 +289,7 @@
     unstacked_draws: xarray.Dataset,
     model: Model,
     include_transformed: bool = False,
+    progressbar: bool = True,
 ):
     """
     Transforms the unconstrained draws back into the constrained space.
@@ -271,6 +308,9 @@
     include_transformed: bool
         Whether or not to keep the unconstrained variables in the output.
 
+    progressbar: bool
+        Whether or not to show a progress bar during the transformation. Default is True.
+
     Returns
     -------
     :class:`~arviz.InferenceData`
@@ -292,6 +332,7 @@
         output_var_names=[x.name for x in vars_to_sample],
         coords=coords,
         dims=dims,
+        progressbar=progressbar,
     )
 
     constrained_names = [

tests/inference/dadvi/test_dadvi.py

Lines changed: 39 additions & 0 deletions
@@ -140,3 +140,42 @@ def test_fit_dadvi_ragged_coords(rng):
     # strictly positive
     assert (idata["posterior"].beta.sel(feature=0).to_numpy() < 0).all()
     assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all()
+
+
+@pytest.mark.parametrize(
+    "method, use_grad, use_hess, use_hessp",
+    [
+        ("Newton-CG", True, True, False),
+        ("Newton-CG", True, False, True),
+    ],
+)
+def test_dadvi_basinhopping(method, use_grad, use_hess, use_hessp, rng):
+    pytest.importorskip("jax")
+
+    with pm.Model() as m:
+        mu = pm.Normal("mu")
+        sigma = pm.Exponential("sigma", 1)
+        pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=10))
+
+        idata = fit_dadvi(
+            optimizer_method="basinhopping",
+            use_grad=use_grad,
+            use_hess=use_hess,
+            use_hessp=use_hessp,
+            progressbar=False,
+            include_transformed=True,
+            minimizer_kwargs=dict(method=method),
+            niter=1,
+            n_draws=100,
+        )
+
+    assert hasattr(idata, "posterior")
+    assert hasattr(idata, "unconstrained_posterior")
+
+    posterior = idata.posterior
+    unconstrained_posterior = idata.unconstrained_posterior
+    assert "mu" in posterior
+    assert posterior["mu"].shape == (1, 100)
+
+    assert "sigma_log__" in unconstrained_posterior
+    assert unconstrained_posterior["sigma_log__"].shape == (1, 100)
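Note that the test pins niter=1 so the basin-hopping loop adds little runtime, and the two Newton-CG parametrizations exercise both curvature paths (use_hess and use_hessp) through the inner optimizer.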
