Skip to content

Commit bdee446

Browse files
author
Martin Ingram
committed
Fix with pre commit checks
1 parent a8a53f3 commit bdee446

File tree

3 files changed

+21
-30
lines changed

3 files changed

+21
-30
lines changed

notebooks/deterministic_advi_example.ipynb

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
"rng = np.random.default_rng(RANDOM_SEED)\n",
6666
"\n",
6767
"%config InlineBackend.figure_format = 'retina'\n",
68-
"az.style.use(\"arviz-darkgrid\")\n"
68+
"az.style.use(\"arviz-darkgrid\")"
6969
]
7070
},
7171
{
@@ -916,14 +916,14 @@
916916
"\n",
917917
"f, ax = plt.subplots(3, 1)\n",
918918
"\n",
919-
"sns.kdeplot(idata.posterior.Intercept.values.reshape(-1), ax=ax[0], label='MCMC')\n",
920-
"sns.kdeplot(dadvi_res.Intercept.values.reshape(-1), ax=ax[0], label='DADVI')\n",
919+
"sns.kdeplot(idata.posterior.Intercept.values.reshape(-1), ax=ax[0], label=\"MCMC\")\n",
920+
"sns.kdeplot(dadvi_res.Intercept.values.reshape(-1), ax=ax[0], label=\"DADVI\")\n",
921921
"\n",
922-
"sns.kdeplot(idata.posterior.slope.values.reshape(-1), ax=ax[1], label='MCMC')\n",
923-
"sns.kdeplot(dadvi_res.slope.values.reshape(-1), ax=ax[1], label='DADVI')\n",
922+
"sns.kdeplot(idata.posterior.slope.values.reshape(-1), ax=ax[1], label=\"MCMC\")\n",
923+
"sns.kdeplot(dadvi_res.slope.values.reshape(-1), ax=ax[1], label=\"DADVI\")\n",
924924
"\n",
925-
"sns.kdeplot(idata.posterior.sigma.values.reshape(-1), ax=ax[2], label='MCMC')\n",
926-
"sns.kdeplot(dadvi_res.sigma.values.reshape(-1), ax=ax[2], label='DADVI')\n",
925+
"sns.kdeplot(idata.posterior.sigma.values.reshape(-1), ax=ax[2], label=\"MCMC\")\n",
926+
"sns.kdeplot(dadvi_res.sigma.values.reshape(-1), ax=ax[2], label=\"DADVI\")\n",
927927
"\n",
928928
"for cur_ax in ax:\n",
929929
" cur_ax.legend()\n",

pymc_extras/inference/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from pymc_extras.inference.deterministic_advi.dadvi import fit_deterministic_advi
1516
from pymc_extras.inference.fit import fit
1617
from pymc_extras.inference.laplace_approx.find_map import find_MAP
1718
from pymc_extras.inference.laplace_approx.laplace import fit_laplace
1819
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
19-
from pymc_extras.inference.deterministic_advi.dadvi import fit_deterministic_advi
2020

2121
__all__ = [
2222
"find_MAP",

pymc_extras/inference/deterministic_advi/dadvi.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,28 @@
1-
from collections import defaultdict
2-
from typing import Tuple, Optional
3-
4-
import pymc
5-
from pymc import Model
61
import arviz as az
72
import numpy as np
8-
from scipy.optimize import minimize
3+
import pymc
94
import pytensor
105
import pytensor.tensor as pt
11-
from pytensor.tensor.variable import TensorVariable
126
import xarray
137

14-
from pymc import join_nonshared_inputs, DictToArrayBijection
15-
from pymc.util import get_default_varnames, RandomSeed
8+
from pymc import DictToArrayBijection, Model, join_nonshared_inputs
169
from pymc.backends.arviz import (
17-
apply_function_over_dataset,
1810
PointFunc,
11+
apply_function_over_dataset,
1912
coords_and_dims_for_inferencedata,
2013
)
14+
from pymc.util import RandomSeed, get_default_varnames
15+
from pytensor.tensor.variable import TensorVariable
16+
from scipy.optimize import minimize
17+
18+
from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
2119
from pymc_extras.inference.laplace_approx.scipy_interface import (
2220
_compile_functions_for_scipy_optimize,
2321
)
24-
from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
2522

2623

2724
def fit_deterministic_advi(
28-
model: Optional[Model] = None,
25+
model: Model | None = None,
2926
n_fixed_draws: int = 30,
3027
random_seed: RandomSeed = None,
3128
n_draws: int = 1000,
@@ -93,9 +90,7 @@ def fit_deterministic_advi(
9390
compute_hess=False,
9491
)
9592

96-
result = minimize(
97-
f_fused, np.zeros(2 * n_params), method="trust-ncg", jac=True, hessp=f_hessp
98-
)
93+
result = minimize(f_fused, np.zeros(2 * n_params), method="trust-ncg", jac=True, hessp=f_hessp)
9994

10095
opt_var_params = result.x
10196
opt_means, opt_log_sds = np.split(opt_var_params, 2)
@@ -107,9 +102,7 @@ def fit_deterministic_advi(
107102
draws = opt_means + draws_raw * np.exp(opt_log_sds)
108103
draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)
109104

110-
transformed_draws = transform_draws(
111-
draws_arviz, model, keep_untransformed=keep_untransformed
112-
)
105+
transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
113106

114107
return transformed_draws
115108

@@ -119,7 +112,7 @@ def create_dadvi_graph(
119112
n_params: int,
120113
n_fixed_draws: int = 30,
121114
random_seed: RandomSeed = None,
122-
) -> Tuple[TensorVariable, TensorVariable]:
115+
) -> tuple[TensorVariable, TensorVariable]:
123116
"""
124117
Sets up the DADVI graph in pytensor and returns it.
125118
@@ -165,9 +158,7 @@ def create_dadvi_graph(
165158
draw_matrix = pt.constant(draws)
166159
samples = means + pt.exp(log_sds) * draw_matrix
167160

168-
logp_vectorized_draws = pytensor.graph.vectorize_graph(
169-
logp, replace={flat_input: samples}
170-
)
161+
logp_vectorized_draws = pytensor.graph.vectorize_graph(logp, replace={flat_input: samples})
171162

172163
mean_log_density = pt.mean(logp_vectorized_draws)
173164
entropy = pt.sum(log_sds)

0 commit comments

Comments
 (0)