Skip to content

Commit 637fc3b

Browse files
author
Martin Ingram
committed
Add to API and add a docstring
1 parent a1afaf6 commit 637fc3b

File tree

3 files changed

+39
-2
lines changed

3 files changed

+39
-2
lines changed

pymc_extras/inference/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@
1616
from pymc_extras.inference.laplace_approx.find_map import find_MAP
1717
from pymc_extras.inference.laplace_approx.laplace import fit_laplace
1818
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
19+
from pymc_extras.inference.deterministic_advi.api import fit_deterministic_advi
1920

20-
__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
21+
__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP", "fit_deterministic_advi"]

pymc_extras/inference/deterministic_advi/api.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,39 @@ def compute_function_on_mean_field_draws(
9393

9494

9595
def fit_deterministic_advi(model=None, num_fixed_draws=30, seed=2):
96+
"""
97+
Does inference using deterministic ADVI (automatic differentiation
98+
variational inference).
99+
100+
For full details see the paper cited in the references:
101+
https://www.jmlr.org/papers/v25/23-1015.html
102+
103+
Parameters
104+
----------
105+
model : pm.Model
106+
The PyMC model to be fit. If None, the current model context is used.
107+
108+
num_fixed_draws : int
109+
The number of fixed draws to use for the optimisation. More
110+
draws will result in more accurate estimates, but also
111+
increase inference time. Usually, the default of 30 is a good
112+
tradeoff.between speed and accuracy.
113+
114+
seed: int
115+
The random seed to use for the fixed draws. Running the optimisation
116+
twice with the same seed should arrive at the same result.
117+
118+
Returns
119+
-------
120+
:class:`~arviz.InferenceData`
121+
The inference data containing the results of the DADVI algorithm.
122+
123+
References
124+
----------
125+
Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective: Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
126+
127+
128+
"""
96129

97130
model = pymc.modelcontext(model) if model is None else model
98131

pymc_extras/inference/fit.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,7 @@ def fit(method: str, **kwargs) -> az.InferenceData:
4141

4242
return fit_laplace(**kwargs)
4343

44-
# TODO Add determinstic ADVI
44+
if method == "deterministic_advi":
45+
from pymc_extras.inference import fit_deterministic_advi
46+
47+
return fit_deterministic_advi(**kwargs)

0 commit comments

Comments
 (0)