
Commit ad46b07

Author: Martin Ingram (committed)
Commit message: Implement suggestions
1 parent 3fcafb6 commit ad46b07

File tree: 4 files changed (+31, -9 lines)

pymc_extras/inference/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pymc_extras.inference.deterministic_advi.dadvi import fit_deterministic_advi
+from pymc_extras.inference.dadvi.dadvi import fit_dadvi
 from pymc_extras.inference.fit import fit
 from pymc_extras.inference.laplace_approx.find_map import find_MAP
 from pymc_extras.inference.laplace_approx.laplace import fit_laplace
@@ -23,5 +23,5 @@
     "fit",
     "fit_laplace",
     "fit_pathfinder",
-    "fit_deterministic_advi",
+    "fit_dadvi",
 ]
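
With this rename, downstream code imports the function under its new name. A minimal illustrative snippet (not part of the diff), showing only what the updated public export implies:

# Hypothetical downstream usage after this rename: the public entry point
# is now exposed as `fit_dadvi` rather than `fit_deterministic_advi`.
from pymc_extras.inference import fit_dadvi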
File renamed without changes.

pymc_extras/inference/deterministic_advi/dadvi.py renamed to pymc_extras/inference/dadvi/dadvi.py

Lines changed: 27 additions & 5 deletions
@@ -5,6 +5,7 @@
 import pytensor.tensor as pt
 import xarray
 
+from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
 from pymc.backends.arviz import (
     PointFunc,
@@ -21,16 +22,18 @@
 )
 
 
-def fit_deterministic_advi(
+def fit_dadvi(
     model: Model | None = None,
     n_fixed_draws: int = 30,
     random_seed: RandomSeed = None,
     n_draws: int = 1000,
     keep_untransformed: bool = False,
+    method: minimize_method = "trust-ncg",
+    **minimize_kwargs,
 ) -> az.InferenceData:
     """
     Does inference using deterministic ADVI (automatic differentiation
-    variational inference).
+    variational inference), DADVI for short.
 
     For full details see the paper cited in the references:
     https://www.jmlr.org/papers/v25/23-1015.html
@@ -57,6 +60,19 @@ def fit_deterministic_advi(
         Whether or not to keep the unconstrained variables (such as
         logs of positive-constrained parameters) in the output.
 
+    method: str
+        Which optimization method to use. The function calls
+        ``scipy.optimize.minimize``, so any of the methods there can
+        be used. The default is trust-ncg, which uses second-order
+        information and is generally very reliable. Other methods such
+        as L-BFGS-B might be faster but potentially more brittle and
+        may not converge exactly to the optimum.
+
+    minimize_kwargs:
+        Additional keyword arguments to pass to the
+        ``scipy.optimize.minimize`` function. See the documentation of
+        that function for details.
+
     Returns
     -------
     :class:`~arviz.InferenceData`
@@ -90,7 +106,14 @@ def fit_deterministic_advi(
         compute_hess=False,
     )
 
-    result = minimize(f_fused, np.zeros(2 * n_params), method="trust-ncg", jac=True, hessp=f_hessp)
+    result = minimize(
+        f_fused,
+        np.zeros(2 * n_params),
+        method=method,
+        jac=True,
+        hessp=f_hessp,
+        **minimize_kwargs,
+    )
 
     opt_var_params = result.x
     opt_means, opt_log_sds = np.split(opt_var_params, 2)
@@ -151,8 +174,7 @@ def create_dadvi_graph(
     )
 
     var_params = pt.vector(name="eta", shape=(2 * n_params,))
-
-    means , log_sds= pt.split(var_params, 2)
+    means, log_sds = var_params[:n_params], var_params[n_params:]
 
     draw_matrix = pt.constant(draws)
     samples = means + pt.exp(log_sds) * draw_matrix
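
As a usage sketch (not part of this commit's diff), the new ``method`` and ``**minimize_kwargs`` arguments could be exercised as below. The toy model, data, and option values are illustrative assumptions; the only facts taken from the diff are the ``fit_dadvi`` signature and that extra keyword arguments are forwarded to ``scipy.optimize.minimize``:

import numpy as np
import pymc as pm

from pymc_extras.inference import fit_dadvi

# Illustrative toy data and model (not from the commit).
rng = np.random.default_rng(0)
observed = rng.normal(loc=1.0, scale=2.0, size=200)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

# Default optimizer: trust-ncg, which uses Hessian-vector products.
idata_default = fit_dadvi(model=model, random_seed=1)

# Alternative optimizer: L-BFGS-B, with an extra keyword argument that is
# forwarded unchanged to scipy.optimize.minimize.
idata_lbfgs = fit_dadvi(
    model=model,
    method="L-BFGS-B",
    options={"maxiter": 500},
    random_seed=1,
)

Passing ``options`` here simply relies on ``scipy.optimize.minimize`` accepting an ``options`` dict; any other keyword that function accepts could be forwarded the same way.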

pymc_extras/inference/fit.py

Lines changed: 2 additions & 2 deletions
@@ -42,6 +42,6 @@ def fit(method: str, **kwargs) -> az.InferenceData:
         return fit_laplace(**kwargs)
 
     if method == "deterministic_advi":
-        from pymc_extras.inference import fit_deterministic_advi
+        from pymc_extras.inference import fit_dadvi
 
-        return fit_deterministic_advi(**kwargs)
+        return fit_dadvi(**kwargs)
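
For completeness, a small sketch of how the string-dispatch entry point relates to the renamed function. The one-variable model is a placeholder assumption; the equivalence follows from ``fit`` simply forwarding ``**kwargs``:

import pymc as pm

from pymc_extras.inference import fit, fit_dadvi

with pm.Model() as model:
    pm.Normal("x", 0.0, 1.0)

# The dispatch string is still "deterministic_advi"; only the function it
# resolves to has been renamed, so these two calls should be equivalent.
idata_via_fit = fit("deterministic_advi", model=model, random_seed=1)
idata_direct = fit_dadvi(model=model, random_seed=1)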
