Skip to content

Commit f46f1cd

Browse files
author
Martin Ingram
committed
Update API
1 parent 488bd9c commit f46f1cd

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

pymc_extras/inference/deterministic_advi/api.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ def get_posterior_draws_mean_field(
6969
if transform_draws:
7070

7171
dadvi_draws = transform_dadvi_draws(
72-
self.pymc_model, dadvi_draws_flat, self.unflattening_fun
72+
self.pymc_model,
73+
dadvi_draws_flat,
74+
self.unflattening_fun,
75+
add_chain_dim=True,
7376
)
7477

7578
else:
@@ -89,10 +92,13 @@ def compute_function_on_mean_field_draws(
8992
return vmap(function_to_run)(dadvi_dict)
9093

9194

92-
def fit_pymc_dadvi_with_jax(pymc_model, num_fixed_draws=30, seed=2):
95+
def fit_deterministic_advi(model=None, num_fixed_draws=30, seed=2):
96+
97+
model = pymc.modelcontext(model) if model is None else model
98+
9399
np.random.seed(seed)
94100

95-
jax_funs = get_jax_functions_from_pymc(pymc_model)
101+
jax_funs = get_jax_functions_from_pymc(model)
96102
dadvi_funs = build_dadvi_funs(jax_funs["log_posterior_fun"])
97103

98104
opt_callback_fun.opt_sequence = []
@@ -114,12 +120,11 @@ def fit_pymc_dadvi_with_jax(pymc_model, num_fixed_draws=30, seed=2):
114120
var_params=opt["opt_result"].x,
115121
unflattening_fun=jax_funs["unflatten_fun"],
116122
dadvi_funs=dadvi_funs,
117-
pymc_model=pymc_model,
123+
pymc_model=model,
118124
)
119125

120126
# Get draws and turn into arviz format expected
121127
draws = dadvi_result.get_posterior_draws_mean_field(transform_draws=True)
122-
123128
az_draws = az.convert_to_inference_data(draws)
124129

125130
return az_draws

pymc_extras/inference/fit.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,5 @@ def fit(method: str, **kwargs) -> az.InferenceData:
4040
from pymc_extras.inference import fit_laplace
4141

4242
return fit_laplace(**kwargs)
43+
44+
# TODO Add determinstic ADVI

0 commit comments

Comments
 (0)