1
- from collections import defaultdict
2
- from typing import Tuple , Optional
3
-
4
- import pymc
5
- from pymc import Model
6
1
import arviz as az
7
2
import numpy as np
8
- from scipy . optimize import minimize
3
+ import pymc
9
4
import pytensor
10
5
import pytensor .tensor as pt
11
- from pytensor .tensor .variable import TensorVariable
12
6
import xarray
13
7
14
- from pymc import join_nonshared_inputs , DictToArrayBijection
15
- from pymc .util import get_default_varnames , RandomSeed
8
+ from pymc import DictToArrayBijection , Model , join_nonshared_inputs
16
9
from pymc .backends .arviz import (
17
- apply_function_over_dataset ,
18
10
PointFunc ,
11
+ apply_function_over_dataset ,
19
12
coords_and_dims_for_inferencedata ,
20
13
)
14
+ from pymc .util import RandomSeed , get_default_varnames
15
+ from pytensor .tensor .variable import TensorVariable
16
+ from scipy .optimize import minimize
17
+
18
+ from pymc_extras .inference .laplace_approx .laplace import unstack_laplace_draws
21
19
from pymc_extras .inference .laplace_approx .scipy_interface import (
22
20
_compile_functions_for_scipy_optimize ,
23
21
)
24
- from pymc_extras .inference .laplace_approx .laplace import unstack_laplace_draws
25
22
26
23
27
24
def fit_deterministic_advi (
28
- model : Optional [ Model ] = None ,
25
+ model : Model | None = None ,
29
26
n_fixed_draws : int = 30 ,
30
27
random_seed : RandomSeed = None ,
31
28
n_draws : int = 1000 ,
@@ -93,9 +90,7 @@ def fit_deterministic_advi(
93
90
compute_hess = False ,
94
91
)
95
92
96
- result = minimize (
97
- f_fused , np .zeros (2 * n_params ), method = "trust-ncg" , jac = True , hessp = f_hessp
98
- )
93
+ result = minimize (f_fused , np .zeros (2 * n_params ), method = "trust-ncg" , jac = True , hessp = f_hessp )
99
94
100
95
opt_var_params = result .x
101
96
opt_means , opt_log_sds = np .split (opt_var_params , 2 )
@@ -107,9 +102,7 @@ def fit_deterministic_advi(
107
102
draws = opt_means + draws_raw * np .exp (opt_log_sds )
108
103
draws_arviz = unstack_laplace_draws (draws , model , chains = 1 , draws = n_draws )
109
104
110
- transformed_draws = transform_draws (
111
- draws_arviz , model , keep_untransformed = keep_untransformed
112
- )
105
+ transformed_draws = transform_draws (draws_arviz , model , keep_untransformed = keep_untransformed )
113
106
114
107
return transformed_draws
115
108
@@ -119,7 +112,7 @@ def create_dadvi_graph(
119
112
n_params : int ,
120
113
n_fixed_draws : int = 30 ,
121
114
random_seed : RandomSeed = None ,
122
- ) -> Tuple [TensorVariable , TensorVariable ]:
115
+ ) -> tuple [TensorVariable , TensorVariable ]:
123
116
"""
124
117
Sets up the DADVI graph in pytensor and returns it.
125
118
@@ -165,9 +158,7 @@ def create_dadvi_graph(
165
158
draw_matrix = pt .constant (draws )
166
159
samples = means + pt .exp (log_sds ) * draw_matrix
167
160
168
- logp_vectorized_draws = pytensor .graph .vectorize_graph (
169
- logp , replace = {flat_input : samples }
170
- )
161
+ logp_vectorized_draws = pytensor .graph .vectorize_graph (logp , replace = {flat_input : samples })
171
162
172
163
mean_log_density = pt .mean (logp_vectorized_draws )
173
164
entropy = pt .sum (log_sds )
0 commit comments