
Commit 19bc44d

Merge branch 'main' into implement-pmx.fit-option-for-INLA-+-marginalisation-routine
2 parents: b3a3351 + 07c6ab4

File tree

15 files changed: +4975 −25 lines

notebooks/DFM_Example_(Coincident_Index).ipynb

Lines changed: 2107 additions & 0 deletions
Large diffs are not rendered by default.

notebooks/deterministic_advi_example.ipynb

Lines changed: 975 additions & 0 deletions
Large diffs are not rendered by default.

pymc_extras/inference/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -12,10 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from pymc_extras.inference.dadvi.dadvi import fit_dadvi
 from pymc_extras.inference.fit import fit
 from pymc_extras.inference.INLA.inla import fit_INLA
 from pymc_extras.inference.laplace_approx.find_map import find_MAP
 from pymc_extras.inference.laplace_approx.laplace import fit_laplace
 from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

-__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP", "fit_INLA"]
+__all__ = [
+    "find_MAP",
+    "fit",
+    "fit_laplace",
+    "fit_pathfinder",
+    "fit_dadvi",
+    "fit_INLA"
+]

pymc_extras/inference/dadvi/__init__.py

Whitespace-only changes.
pymc_extras/inference/dadvi/dadvi.py

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
import arviz as az
import numpy as np
import pymc
import pytensor
import pytensor.tensor as pt
import xarray

from better_optimize import minimize
from better_optimize.constants import minimize_method
from pymc import DictToArrayBijection, Model, join_nonshared_inputs
from pymc.backends.arviz import (
    PointFunc,
    apply_function_over_dataset,
    coords_and_dims_for_inferencedata,
)
from pymc.util import RandomSeed, get_default_varnames
from pytensor.tensor.variable import TensorVariable

from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
from pymc_extras.inference.laplace_approx.scipy_interface import (
    _compile_functions_for_scipy_optimize,
)


def fit_dadvi(
    model: Model | None = None,
    n_fixed_draws: int = 30,
    random_seed: RandomSeed = None,
    n_draws: int = 1000,
    keep_untransformed: bool = False,
    optimizer_method: minimize_method = "trust-ncg",
    use_grad: bool = True,
    use_hessp: bool = True,
    use_hess: bool = False,
    **minimize_kwargs,
) -> az.InferenceData:
    """
    Performs inference using deterministic ADVI (automatic differentiation
    variational inference), DADVI for short.

    For full details see the paper cited in the references:
    https://www.jmlr.org/papers/v25/23-1015.html

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit. If None, the current model context is used.

    n_fixed_draws : int
        The number of fixed draws to use for the optimisation. More
        draws will result in more accurate estimates, but also
        increase inference time. Usually, the default of 30 is a good
        tradeoff between speed and accuracy.

    random_seed: int
        The random seed to use for the fixed draws. Running the optimisation
        twice with the same seed should arrive at the same result.

    n_draws: int
        The number of draws to return from the variational approximation.

    keep_untransformed: bool
        Whether or not to keep the unconstrained variables (such as
        logs of positive-constrained parameters) in the output.

    optimizer_method: str
        Which optimization method to use. The function calls
        ``scipy.optimize.minimize``, so any of the methods there can
        be used. The default is trust-ncg, which uses second-order
        information and is generally very reliable. Other methods such
        as L-BFGS-B might be faster but potentially more brittle and
        may not converge exactly to the optimum.

    use_grad:
        If True, pass the gradient function to
        ``scipy.optimize.minimize`` (where it is referred to as ``jac``).

    use_hessp:
        If True, pass the Hessian vector product to ``scipy.optimize.minimize``.

    use_hess:
        If True, pass the Hessian to ``scipy.optimize.minimize``. Note that
        this is generally not recommended since its computation can be slow
        and memory-intensive if there are many parameters.

    minimize_kwargs:
        Additional keyword arguments to pass to the
        ``scipy.optimize.minimize`` function. See the documentation of
        that function for details.

    Returns
    -------
    :class:`~arviz.InferenceData`
        The inference data containing the results of the DADVI algorithm.

    References
    ----------
    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
    Variational Inference with a Deterministic Objective: Faster, More
    Accurate, and Even More Black Box. Journal of Machine Learning
    Research, 25(18), 1–39.
    """

    model = pymc.modelcontext(model) if model is None else model

    initial_point_dict = model.initial_point()
    n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]

    var_params, objective = create_dadvi_graph(
        model,
        n_fixed_draws=n_fixed_draws,
        random_seed=random_seed,
        n_params=n_params,
    )

    f_fused, f_hessp = _compile_functions_for_scipy_optimize(
        objective,
        [var_params],
        compute_grad=use_grad,
        compute_hessp=use_hessp,
        compute_hess=use_hess,
    )

    derivative_kwargs = {}

    if use_grad:
        derivative_kwargs["jac"] = True
    if use_hessp:
        derivative_kwargs["hessp"] = f_hessp
    if use_hess:
        derivative_kwargs["hess"] = True

    result = minimize(
        f_fused,
        np.zeros(2 * n_params),
        method=optimizer_method,
        **derivative_kwargs,
        **minimize_kwargs,
    )

    opt_var_params = result.x
    opt_means, opt_log_sds = np.split(opt_var_params, 2)

    # Make the draws:
    generator = np.random.default_rng(seed=random_seed)
    draws_raw = generator.standard_normal(size=(n_draws, n_params))

    draws = opt_means + draws_raw * np.exp(opt_log_sds)
    draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)

    transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)

    return transformed_draws


def create_dadvi_graph(
    model: Model,
    n_params: int,
    n_fixed_draws: int = 30,
    random_seed: RandomSeed = None,
) -> tuple[TensorVariable, TensorVariable]:
    """
    Sets up the DADVI graph in pytensor and returns it.

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit.

    n_params: int
        The total number of parameters in the model.

    n_fixed_draws : int
        The number of fixed draws to use.

    random_seed: int
        The random seed to use for the fixed draws.

    Returns
    -------
    Tuple[TensorVariable, TensorVariable]
        A tuple whose first element contains the variational parameters,
        and whose second contains the DADVI objective.
    """

    # Make the fixed draws
    generator = np.random.default_rng(seed=random_seed)
    draws = generator.standard_normal(size=(n_fixed_draws, n_params))

    inputs = model.continuous_value_vars + model.discrete_value_vars
    initial_point_dict = model.initial_point()
    logp = model.logp()

    # Graph in terms of a flat input
    [logp], flat_input = join_nonshared_inputs(
        point=initial_point_dict, outputs=[logp], inputs=inputs
    )

    var_params = pt.vector(name="eta", shape=(2 * n_params,))

    means, log_sds = pt.split(var_params, axis=0, splits_size=[n_params, n_params], n_splits=2)

    draw_matrix = pt.constant(draws)
    samples = means + pt.exp(log_sds) * draw_matrix

    logp_vectorized_draws = pytensor.graph.vectorize_graph(logp, replace={flat_input: samples})

    mean_log_density = pt.mean(logp_vectorized_draws)
    entropy = pt.sum(log_sds)

    objective = -mean_log_density - entropy

    return var_params, objective


def transform_draws(
    unstacked_draws: xarray.Dataset,
    model: Model,
    keep_untransformed: bool = False,
):
    """
    Transforms the unconstrained draws back into the constrained space.

    Parameters
    ----------
    unstacked_draws : xarray.Dataset
        The draws to constrain back into the original space.

    model : Model
        The PyMC model the variables were derived from.

    keep_untransformed: bool
        Whether or not to keep the unconstrained variables in the output.

    Returns
    -------
    :class:`~arviz.InferenceData`
        Draws from the original constrained parameters.
    """

    filtered_var_names = model.unobserved_value_vars
    vars_to_sample = list(
        get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
    )
    fn = pytensor.function(model.value_vars, vars_to_sample)
    point_func = PointFunc(fn)

    coords, dims = coords_and_dims_for_inferencedata(model)

    transformed_result = apply_function_over_dataset(
        point_func,
        unstacked_draws,
        output_var_names=[x.name for x in vars_to_sample],
        coords=coords,
        dims=dims,
    )

    return transformed_result
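
For reference, the objective built by create_dadvi_graph is the fixed-draw estimate of the negative ELBO, up to an additive constant. Writing the variational means as μ, the log standard deviations as ω = log σ, and the K fixed standard-normal draws as z_1, …, z_K, the quantity being minimised is

    \mathcal{L}(\mu, \omega) = -\frac{1}{K} \sum_{k=1}^{K} \log p\left(\mu + e^{\omega} \odot z_k\right) - \sum_{d} \omega_d

where log p is the model's joint log-density on the unconstrained parameters. This matches ``objective = -mean_log_density - entropy`` above; the constant term of the Gaussian entropy is dropped because it does not affect the optimum.

A minimal usage sketch of the new entry point follows. The toy data and model are illustrative stand-ins, not part of this commit:

import numpy as np
import pymc as pm

from pymc_extras.inference import fit_dadvi

# Toy data and model, used only to show the call signature.
rng = np.random.default_rng(0)
y = rng.normal(loc=1.0, scale=2.0, size=200)

with pm.Model() as toy_model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

    # 30 fixed draws define the deterministic objective; 1000 draws are then
    # taken from the fitted mean-field Gaussian approximation.
    approx = fit_dadvi(n_fixed_draws=30, n_draws=1000, random_seed=1)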

pymc_extras/inference/fit.py

Lines changed: 5 additions & 0 deletions
@@ -45,6 +45,11 @@ def fit(method: str, **kwargs) -> az.InferenceData:
         from pymc_extras.inference.INLA import fit_INLA

         return fit_INLA(**kwargs)
+
+    elif method == "dadvi":
+        from pymc_extras.inference import fit_dadvi
+
+        return fit_dadvi(**kwargs)

     else:
         raise ValueError(
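
The same code path is reachable through the generic dispatcher, assuming fit stays re-exported at the package top level as pmx.fit. A minimal sketch, with a stand-in toy model and the remaining keyword arguments forwarded to fit_dadvi:

import numpy as np
import pymc as pm
import pymc_extras as pmx

with pm.Model() as toy_model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=np.random.default_rng(0).normal(size=50))

# "dadvi" is the new method string handled above; kwargs go straight to fit_dadvi.
idata = pmx.fit(method="dadvi", model=toy_model, n_draws=500, random_seed=1)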

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 16 additions & 8 deletions
@@ -198,6 +198,7 @@ def find_MAP(
     include_transformed: bool = True,
     gradient_backend: GradientBackend = "pytensor",
     compile_kwargs: dict | None = None,
+    compute_hessian: bool = False,
     **optimizer_kwargs,
 ) -> (
     dict[str, np.ndarray]
@@ -239,6 +240,10 @@ def find_MAP(
         Whether to include transformed variable values in the returned dictionary. Defaults to True.
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
+    compute_hessian: bool
+        If True, the inverse Hessian matrix at the optimum will be computed and included in the returned
+        InferenceData object. This is needed for the Laplace approximation, but can be computationally expensive for
+        high-dimensional problems. Defaults to False.
     compile_kwargs: dict, optional
         Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
     **optimizer_kwargs
@@ -316,14 +321,17 @@ def find_MAP(
         **optimizer_kwargs,
     )

-    H_inv = _compute_inverse_hessian(
-        optimizer_result=optimizer_result,
-        optimal_point=None,
-        f_fused=f_fused,
-        f_hessp=f_hessp,
-        use_hess=use_hess,
-        method=method,
-    )
+    if compute_hessian:
+        H_inv = _compute_inverse_hessian(
+            optimizer_result=optimizer_result,
+            optimal_point=None,
+            f_fused=f_fused,
+            f_hessp=f_hessp,
+            use_hess=use_hess,
+            method=method,
+        )
+    else:
+        H_inv = None

     raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
     unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
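
A short sketch of the new flag, assuming find_MAP's existing keyword interface and a stand-in toy model. Since compute_hessian defaults to False, a plain MAP fit now skips the inverse-Hessian computation; callers that need it (as fit_laplace does below) opt back in explicitly:

import numpy as np
import pymc as pm

from pymc_extras.inference import find_MAP

with pm.Model() as toy_model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=np.random.default_rng(0).normal(size=50))

    # Request the inverse Hessian at the optimum alongside the MAP point.
    map_fit = find_MAP(method="L-BFGS-B", compute_hessian=True)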

pymc_extras/inference/laplace_approx/idata.py

Lines changed: 5 additions & 2 deletions
@@ -136,7 +136,10 @@ def map_results_to_inference_data(


 def add_fit_to_inference_data(
-    idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None
+    idata: az.InferenceData,
+    mu: RaveledVars,
+    H_inv: np.ndarray | None,
+    model: pm.Model | None = None,
 ) -> az.InferenceData:
     """
     Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
@@ -147,7 +150,7 @@ def add_fit_to_inference_data(
         An InferenceData object containing the approximated posterior samples.
     mu: RaveledVars
         The MAP estimate of the model parameters.
-    H_inv: np.ndarray
+    H_inv: np.ndarray, optional
         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     model: Model, optional
         A PyMC model. If None, the model is taken from the current model context.

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 1 addition & 0 deletions
@@ -288,6 +288,7 @@ def fit_laplace(
         include_transformed=include_transformed,
         gradient_backend=gradient_backend,
         compile_kwargs=compile_kwargs,
+        compute_hessian=True,
         **optimizer_kwargs,
     )
